diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 00000000..6ff864e9 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,39 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +insert_final_newline = true +max_line_length = 100 + +[*.java] +indent_size = 4 +indent_style = space +tab_width = 4 +ij_continuation_indent_size = 8 +ij_java_binary_operation_sign_on_next_line = true +ij_java_binary_operation_wrap = normal +ij_java_call_parameters_new_line_after_left_paren = true +ij_java_call_parameters_wrap = on_every_item +ij_java_class_count_to_use_import_on_demand = 9999 +ij_java_doc_add_blank_line_after_param_comments = true +ij_java_doc_add_blank_line_after_return = true +ij_java_doc_align_exception_comments = false +ij_java_doc_align_param_comments = false +ij_java_doc_do_not_wrap_if_one_line = true +ij_java_doc_enable_formatting = true +ij_java_doc_indent_on_continuation = true +ij_java_doc_keep_empty_lines = true +ij_java_doc_preserve_line_breaks = false +ij_java_imports_layout = com.alibaba.flink.shuffle.**, |, org.apache.flink.**, |, org.apache.flink.shaded.**, |, *, |, javax.**, |, java.**, |, scala.**, |, $* +ij_java_layout_static_imports_separately = true +ij_java_method_call_chain_wrap = on_every_item +ij_java_method_parameters_new_line_after_left_paren = true +ij_java_method_parameters_wrap = on_every_item +ij_java_names_count_to_use_import_on_demand = 9999 +ij_java_variable_annotation_wrap = normal +ij_java_wrap_first_method_in_call_chain = true + +[*.xml] +indent_style = tab +indent_size = 4 diff --git a/.github/actions/rerun-tests/Dockerfile b/.github/actions/rerun-tests/Dockerfile new file mode 100644 index 00000000..a1a0cecc --- /dev/null +++ b/.github/actions/rerun-tests/Dockerfile @@ -0,0 +1,7 @@ +FROM alpine:3.10 + +RUN apk add --no-cache curl jq + +COPY entrypoint.sh /entrypoint.sh + +ENTRYPOINT ["/entrypoint.sh"] diff --git a/.github/actions/rerun-tests/action.yaml b/.github/actions/rerun-tests/action.yaml new file mode 100644 index 00000000..e828b516 --- /dev/null +++ b/.github/actions/rerun-tests/action.yaml @@ -0,0 +1,11 @@ +name: 'Re-Test' +description: 'Re-Runs the last workflow for a PR' +inputs: + token: + description: 'GitHub Token' + required: true +runs: + using: 'docker' + image: 'Dockerfile' + env: + GITHUB_TOKEN: ${{ inputs.token }} diff --git a/.github/actions/rerun-tests/entrypoint.sh b/.github/actions/rerun-tests/entrypoint.sh new file mode 100755 index 00000000..409b153f --- /dev/null +++ b/.github/actions/rerun-tests/entrypoint.sh @@ -0,0 +1,35 @@ +#!/bin/sh + +set -ex + +if ! jq -e '.issue.pull_request' ${GITHUB_EVENT_PATH}; then + echo "Not a PR... Exiting." + exit 0 +fi + +if [ "$(jq -r '.comment.body' ${GITHUB_EVENT_PATH})" != "/retest" ]; then + echo "Nothing to do... Exiting." 
+ exit 0 +fi + +PR_URL=$(jq -r '.issue.pull_request.url' ${GITHUB_EVENT_PATH}) + +curl --request GET \ + --url "${PR_URL}" \ + --header "authorization: Bearer ${GITHUB_TOKEN}" \ + --header "content-type: application/json" > pr.json + +ACTOR=$(jq -r '.user.login' pr.json) +BRANCH=$(jq -r '.head.ref' pr.json) + +curl --request GET \ + --url "https://api.github.com/repos/${GITHUB_REPOSITORY}/actions/runs?event=pull_request&actor=${ACTOR}&branch=${BRANCH}" \ + --header "authorization: Bearer ${GITHUB_TOKEN}" \ + --header "content-type: application/json" | jq '.workflow_runs | max_by(.run_number)' > run.json + +RERUN_URL=$(jq -r '.rerun_url' run.json) + +curl --request POST \ + --url "${RERUN_URL}" \ + --header "authorization: Bearer ${GITHUB_TOKEN}" \ + --header "content-type: application/json" diff --git a/.github/workflows/command.yaml b/.github/workflows/command.yaml new file mode 100644 index 00000000..ab5d08e1 --- /dev/null +++ b/.github/workflows/command.yaml @@ -0,0 +1,16 @@ +name: commands +on: + issue_comment: + types: [created] + +jobs: + retest: + runs-on: ubuntu-latest + steps: + - name: Check out code + uses: actions/checkout@v2 + + - name: Re-Test Action + uses: ./.github/actions/rerun-tests + with: + token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/flink-remote-shuffle-tests.yaml b/.github/workflows/flink-remote-shuffle-tests.yaml new file mode 100644 index 00000000..4bc29a1a --- /dev/null +++ b/.github/workflows/flink-remote-shuffle-tests.yaml @@ -0,0 +1,24 @@ +name: Remote Shuffle Service +on: + push: + branches: [main] + pull_request: + branches: [main] +jobs: + build-and-test: + runs-on: ubuntu-18.04 + steps: + - uses: actions/checkout@v2 + - name: Set up JDK 8 + uses: actions/setup-java@v2 + with: + java-version: "8" + distribution: "adopt" + - name: Cache Maven packages + uses: actions/cache@v2 + with: + path: ~/.m2 + key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }} + restore-keys: ${{ runner.os }}-m2 + - name: Build and test with maven + run: mvn clean install -DskipTests; mvn -PincludeE2E -B clean verify diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..cb34e38e --- /dev/null +++ b/.gitignore @@ -0,0 +1,13 @@ +.idea +target +tmp +*.class +*.iml +*.swp +*.jar +*.zip +*.log +.DS_Store +build-target +!flink-rpc-akka.jar +!shaded-flink-rpc-akka.jar diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..42d79a07 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,30 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################ + +FROM openjdk:8-jre +ARG REMOTE_SHUFFLE_VERSION + +RUN set -e && mkdir -p /flink-remote-shuffle + +RUN ln -s /flink-remote-shuffle /flink-shuffle + +COPY ./shuffle-dist/target/flink-remote-shuffle-${REMOTE_SHUFFLE_VERSION}-bin/flink-remote-shuffle-${REMOTE_SHUFFLE_VERSION} /flink-remote-shuffle + +ENV LANG en_US.UTF-8 +ENV LANGUAGE en_US.UTF-8 +ENV LC_ALL en_US.UTF-8 diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..92d420b8 --- /dev/null +++ b/LICENSE @@ -0,0 +1,203 @@ +Copyright 2021 The Alibaba Group Holding Ltd. All rights reserved. + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2021, The Alibaba Group Holding Ltd. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/NOTICE b/NOTICE new file mode 100644 index 00000000..3632a7df --- /dev/null +++ b/NOTICE @@ -0,0 +1,28 @@ +Remote Shuffle Service for Flink +Copyright 2021 The Alibaba Group Holding Ltd + +This product includes software developed at The Alibaba Group Holding Ltd +(https://www.alibaba.com/). + +This project bundles the following dependencies under the Apache Software License 2.0 (http://www.apache.org/licenses/LICENSE-2.0.txt): + +- org.apache.flink:flink-core:1.14.0 +- org.apache.flink:flink-runtime:1.14.0 +- org.apache.flink:flink-rpc:1.14.0 +- commons-cli:commons-cli:1.3.1 +- org.apache.commons:commons-lang3:3.3.2 +- io.netty:netty:4.1.49.Final +- org.apache.zookeeper:zookeeper:3.4.14 +- org.apache.curator:curator-framework:4.2.0 +- io.fabric8:kubernetes-client:5.2.1 +- org.apache.logging.log4j:log4j-api:2.12.1 +- org.apache.logging.log4j:log4j-core:2.12.1 +- org.apache.logging.log4j:log4j-slf4j-impl:2.12.1 +- org.apache.logging.log4j:log4j-1.2-api:2.12.1 +- com.alibaba.middleware:metrics-core-api:2.0.6 +- com.alibaba.middleware:metrics-core-impl:2.0.6 +- com.alibaba.middleware:metrics-integration:2.0.6 +- com.alibaba.middleware:metrics-rest:2.0.6 +- com.alibaba.middleware:metrics-reporter:2.0.6 + + diff --git a/README.md b/README.md new file mode 100644 index 00000000..9348e0f1 --- /dev/null +++ b/README.md @@ -0,0 +1,185 @@ +# Remote Shuffle Service for Flink + +- [Overview](#overview) +- [Supported Flink Version](#supported-flink-version) +- [Document](#document) +- [Building from Source](#building-from-source) +- [Example](#example) +- [How to Contribute](#how-to-contribute) +- [Support](#support) +- [Acknowledge](#acknowledge) + +## Overview + +This project implements a remote shuffle service for batch data processing +of [Flink](https://flink.apache.org/). By adopting the storage and compute separation architecture, +it brings several important benefits: + +1. The scale up/down of computing resources and storage resources is now decoupled which means you + can scale each up/down on demand freely. + +2. Compute and storage stability never influence each other anymore. The remote shuffle service is + free of user-code which can improve shuffle stability. For example, the termination + of `TaskExecutor`s will not lead to data loss and the termination of remote `ShuffleWorker`s is + tolerable. + +3. 
By offloading the data shuffle work to the remote shuffle service, computation resources can
+   be released immediately after the upstream map tasks finish.
+
+In addition, the remote shuffle implementation borrows some good designs from Flink which benefit
+both stability and performance, for example:
+
+1. Managed memory is preferred. Both the storage and network memory are managed, which largely
+   avoids OutOfMemory issues.
+
+2. The credit-based backpressure mechanism is adopted, which is good for both network stability and
+   performance.
+
+3. Zero-copy network data transmission is implemented, which saves memory and is also good for
+   stability and performance.
+
+Besides, there are other important optimizations like load balancing and better sequential IO
+(benefiting from the centralized service per node), TCP connection reuse, shuffle data compression,
+adaptive execution (together with FLIP-187), etc.
+
+Before going open source, this project was used widely in production and behaves well in terms of
+both stability and performance. We hope you enjoy it.
+
+## Supported Flink Version
+
+The remote shuffle service works together with Flink 1.14+. Some patches need to be applied to
+Flink to support lower Flink versions. If you need any help with that, please let us know; we can
+help prepare the patches for the Flink version you use.
+
+## Document
+
+The remote shuffle service supports standalone, YARN, and Kubernetes deployment. You can find the
+full user guide [here](./docs/user_guide.md). In the future, more internal implementation
+specifications will be added.
+
+## Building from Source
+
+To build the remote shuffle project from source, first clone the project:
+
+```bash
+git clone git@github.com:flink-extended/flink-remote-shuffle.git
+```
+
+Then you can build the project using Maven (Maven and Java 8 required):
+
+```bash
+cd flink-remote-shuffle # switch to the remote shuffle project home directory
+mvn package -DskipTests
+```
+
+After the build finishes, you can find the target distribution in the build-target folder. Note
+that if you want to run tests locally, we suggest running `mvn install -DskipTests` first to avoid
+potential failures.
+
+For Kubernetes deployment, you can run the following command to build the docker image (Docker required):
+
+```bash
+cd flink-remote-shuffle # switch to the remote shuffle project home directory
+sh ./tools/build_docker_image.sh
+```
+
+You can also publish the docker image by running the following command. The script that publishes
+the docker image takes three arguments: the first is the registry address (default value
+'docker.io'), the second is the namespace (default value 'flinkremoteshuffle'), and the third
+is the repository name (default value 'flink-remote-shuffle').
+
+```bash
+cd flink-remote-shuffle # switch to the remote shuffle project home directory
+sh ./tools/publish_docker_image.sh REGISTRY NAMESPACE REPOSITORY
+```
+
+## Example
+
+After building the code from source, you can start and run a demo Flink batch job using the remote
+shuffle service locally (Flink 1.14+ required).
+
+As the first step, you can download the Flink distribution from Flink's
+[download page](https://flink.apache.org/downloads.html#apache-flink-1140), for example,
+Apache Flink 1.14.0 for Scala 2.11:
+
+```bash
+wget https://dlcdn.apache.org/flink/flink-1.14.0/flink-1.14.0-bin-scala_2.11.tgz
+tar zxvf flink-1.14.0-bin-scala_2.11.tgz
+```
+
+Then, after building the remote shuffle project from source, you can copy the shuffle plugin jar
+file from the build-target/lib directory (for example, build-target/lib/shuffle-plugin-1.0-SNAPSHOT.jar)
+to the Flink lib directory and copy the built-in example job jar file to the Flink home directory
+(flink-1.14.0):
+
+```bash
+cp flink-remote-shuffle/build-target/lib/shuffle-plugin-1.0-SNAPSHOT.jar flink-1.14.0/lib/
+cp flink-remote-shuffle/build-target/examples/BatchJobDemo.jar flink-1.14.0/
+```
+
+After that, you can start a local remote shuffle cluster by running the following command:
+
+```bash
+cd flink-remote-shuffle # switch to the remote shuffle project home directory
+cd build-target # run after building from source
+./bin/start-cluster.sh -D remote-shuffle.storage.local-data-dirs="[HDD]/tmp/" -D remote-shuffle.memory.data-writing-size=256m -D remote-shuffle.memory.data-reading-size=256m
+```
+
+Then you can start a local Flink cluster and configure Flink to use the remote shuffle service by
+running the following command:
+
+```bash
+cd flink-1.14.0 # switch to the flink home directory
+./bin/start-cluster.sh -D shuffle-service-factory.class=com.alibaba.flink.shuffle.plugin.RemoteShuffleServiceFactory -D remote-shuffle.manager.rpc-address=127.0.0.1
+```
+
+Finally, you can run the demo batch job:
+
+```bash
+cd flink-1.14.0 # switch to the flink home directory
+bin/flink run -c com.alibaba.flink.shuffle.examples.BatchJobDemo ./BatchJobDemo.jar
+```
+
+To stop the local clusters, you can just run the stop-cluster.sh script in the bin directory of each:
+
+```bash
+cd flink-1.14.0 # switch to the flink home directory
+bin/stop-cluster.sh
+```
+
+```bash
+cd flink-remote-shuffle # switch to the remote shuffle project home directory
+bin/stop-cluster.sh
+```
+
+## How to Contribute
+
+Any feedback on this project is highly appreciated. You can report a bug by opening an issue on
+GitHub. You can also contribute any new features or improvements. See
+the [contribution guide](./docs/contribution.md) for more information.
+
+## Support
+
+We provide free support for users of this project. You can scan the following QR code to join
+the [DingTalk](https://www.dingtalk.com/) user support group for further help and collaboration:
+
+English:
+
+*(DingTalk user group QR code image)*
+
+Chinese:
+
+*(DingTalk user group QR code image)*
+
+Another way is to join the Slack channel by clicking
+this [invitation](https://join.slack.com/t/slack-5xu7894/shared_invite/zt-ykp807ok-1JXMcE6HS~NCplRp2T31fQ).
+
+## Acknowledge
+
+This is a Flink ecosystem project. Apache Flink is an excellent unified stateful data processing
+engine. This project borrows some good designs (e.g. the credit-based backpressure) and building
+blocks (e.g. rpc and high availability) from Flink.
diff --git a/docs/configuration.md b/docs/configuration.md
new file mode 100644
index 00000000..0178adfd
--- /dev/null
+++ b/docs/configuration.md
@@ -0,0 +1,254 @@
+
+
+# Configuration
+
+- [Options for Flink Cluster](#options-for-flink-cluster)
+    - [Data Transmission Related (Client)](#data-transmission-related-client)
+    - [ShuffleMaster Related](#shufflemaster-related)
+    - [High Availability Related (Client)](#high-availability-related-client)
+- [Options for Shuffle Cluster](#options-for-shuffle-cluster)
+    - [High Availability Related (Server)](#high-availability-related-server)
+    - [RPC & Heartbeat Related](#rpc--heartbeat-related)
+    - [ShuffleWorker Related](#shuffleworker-related)
+    - [ShuffleManager Related](#shufflemanager-related)
+    - [Data Transmission Related (Server)](#data-transmission-related-server)
+    - [Metric Related](#metric-related)
+- [Options for Deployment](#options-for-deployment)
+    - [K8s Deployment Related](#k8s-deployment-related)
+    - [Yarn Deployment Related](#yarn-deployment-related)
+    - [Standalone Deployment Related](#standalone-deployment-related)
+
+This section presents all valid config options that can be used by the remote shuffle cluster
+together with the corresponding Flink cluster (jobs using the remote shuffle). Among these config
+options, some are required, which means you must give a config value; others are optional, and the
+default values are usually good enough for most cases.
+
+For the configuration used by the Flink cluster, you can just put it in the Flink configuration
+file; please refer to the Flink documentation for more information.
+
+For the configuration used by the remote shuffle cluster, there are different ways to configure it
+depending on the deployment type:
+
+1. For standalone and local deployment, you should put the customized configuration in the remote
+   shuffle configuration file conf/remote-shuffle-conf.yaml with the format key: value.
+
+2. For yarn deployment, the `ShuffleManager` configuration should be put in the remote shuffle
+   configuration file conf/remote-shuffle-conf.yaml and the `ShuffleWorker` configuration should be
+   put in the yarn-site.xml.
+
+3. For k8s deployment, you should put the customized configuration in the k8s deployment yaml file.
+
+Please read the following part and refer to the deployment
+guide ([Standalone](./deploy_standalone_mode.md), [Yarn](./deploy_on_yarn.md),
+[Kubernetes](./deploy_on_kubernetes.md)) for more information.
+
+**Important:** To use the remote shuffle service in Flink, you must put the following configuration
+in the Flink configuration file:
+
+```yaml
+shuffle-service-factory.class: com.alibaba.flink.shuffle.plugin.RemoteShuffleServiceFactory
+```
+
+## Options for Flink Cluster
+
+This section presents the valid config options that can be used by the Flink cluster; they should
+be put in the Flink configuration file.
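+
+For example, a minimal set of Flink-side options could look like the following sketch in the Flink
+configuration file (the manager address is a hypothetical placeholder for your actual deployment;
+the individual keys are described in the tables below):
+
+```yaml
+shuffle-service-factory.class: com.alibaba.flink.shuffle.plugin.RemoteShuffleServiceFactory
+remote-shuffle.manager.rpc-address: 10.0.0.1   # hypothetical ShuffleManager address
+remote-shuffle.job.memory-per-partition: 64m
+remote-shuffle.job.memory-per-gate: 32m
+remote-shuffle.job.enable-data-compression: true
+```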
+
+### Data Transmission Related (Client)
+
+| Key | Value Type | Default Value | Version | Required | Description |
+| --- | ---------- | ------------- | ------- | -------- | ----------- |
+| `remote-shuffle.job.memory-per-partition` | MemorySize | `64m` | 1.0.0 | false | The size of network buffers required per result partition. The minimum valid value is `8m`. Usually, several hundred megabytes of memory are enough for large-scale batch jobs. |
+| `remote-shuffle.job.memory-per-gate` | MemorySize | `32m` | 1.0.0 | false | The size of network buffers required per input gate. The minimum valid value is `8m`. Usually, several hundred megabytes of memory are enough for large-scale batch jobs. |
+| `remote-shuffle.job.enable-data-compression` | Bool | `true` | 1.0.0 | false | Whether to enable shuffle data compression. Usually, enabling data compression can save storage space and achieve better performance. |
+| `remote-shuffle.job.data-partition-factory-name` | String | `com.alibaba.flink.shuffle.storage.partition.LocalFileMapPartitionFactory` | 1.0.0 | false | Defines the factory used to create new data partitions. According to the data partition factory specified by the client side, the `ShuffleManager` will return corresponding resources and the `ShuffleWorker` will create the corresponding partitions. All supported data partition factories can be found in the [data storage](./user_guide.md#data-storage) section. |
+| `remote-shuffle.job.concurrent-readings-per-gate` | Integer | `2147483647` | 1.0.0 | false | The maximum number of remote shuffle channels to open and read concurrently per input gate. |
+| `remote-shuffle.transfer.server.data-port` | Integer | `10086` | 1.0.0 | false | Data port to write shuffle data to and read shuffle data from `ShuffleWorker`s. This port must be accessible from the Flink cluster. |
+| `remote-shuffle.transfer.client.num-threads` | Integer | `-1` | 1.0.0 | false | The number of Netty threads to be used at the client (Flink job) side. The default `-1` means that `2 * (the number of slots)` will be used. |
+| `remote-shuffle.transfer.client.connect-timeout` | Duration | `2min` | 1.0.0 | false | The TCP connection setup timeout of the Netty client. |
+| `remote-shuffle.transfer.client.connect-retries` | Integer | `3` | 1.0.0 | false | Number of retries when connecting to the remote `ShuffleWorker` fails. |
+| `remote-shuffle.transfer.client.connect-retry-wait` | Duration | `3s` | 1.0.0 | false | Time to wait between two consecutive connection retries. |
+| `remote-shuffle.transfer.transport-type` | String | `auto` | 1.0.0 | false | The Netty transport type, either `nio` or `epoll`. `auto` means selecting the proper mode automatically based on the platform. Note that the `epoll` mode can get better performance and less GC, and has more advanced features, which are only available on modern Linux. |
+| `remote-shuffle.transfer.send-receive-buffer-size` | MemorySize | `0b` | 1.0.0 | false | The Netty send and receive buffer size. The default `0b` means the system buffer size (cat /proc/sys/net/ipv4/tcp_[rw]mem), which is 4 MiB in modern Linux. |
+| `remote-shuffle.transfer.heartbeat.interval` | Duration | `1min` | 1.0.0 | false | The time interval at which to send heartbeats between the Netty server and Netty client. |
+| `remote-shuffle.transfer.heartbeat.timeout` | Duration | `5min` | 1.0.0 | false | Heartbeat timeout used to detect broken Netty connections. |
+
+### ShuffleMaster Related
+
+| Key | Value Type | Default Value | Version | Required | Description |
+| --- | ---------- | ------------- | ------- | -------- | ----------- |
+| `remote-shuffle.worker.max-recovery-time` | Duration | `3min` | 1.0.0 | false | Maximum time to wait before reproducing the data stored in the lost worker (heartbeat timeout). The lost worker may become available again within this timeout. |
+| `remote-shuffle.client.heartbeat.interval` | Duration | `10s` | 1.0.0 | false | Time interval for the `ShuffleClient` (running in `ShuffleMaster`) to request heartbeats from the `ShuffleManager`. |
+| `remote-shuffle.client.heartbeat.timeout` | Duration | `120s` | 1.0.0 | false | Timeout for the `ShuffleClient` (running in `ShuffleMaster`) and `ShuffleManager` to request and receive heartbeats. |
+| `remote-shuffle.rpc.timeout` | Duration | `30s` | 1.0.0 | false | Timeout for `ShuffleClient` (running in `ShuffleMaster`) <-> `ShuffleManager` RPC calls. |
+| `remote-shuffle.rpc.akka-frame-size` | String | `10485760b` | 1.0.0 | false | Maximum size of messages that can be sent through RPC calls. |
+| `remote-shuffle.cluster.registration.timeout` | Duration | `5min` | 1.0.0 | false | Defines the timeout for the registration of the `ShuffleClient` (running in `ShuffleMaster`) to the `ShuffleManager`. If the duration is exceeded without a successful registration, the `ShuffleClient` terminates, which will lead to the termination of the Flink AM. |
+| `remote-shuffle.cluster.registration.error-delay` | Duration | `10s` | 1.0.0 | false | The pause made after a registration attempt caused an exception (other than a timeout). |
+| `remote-shuffle.cluster.registration.refused-delay` | Duration | `30s` | 1.0.0 | false | The pause made after a registration attempt was refused. |
+
+### High Availability Related (Client)
+
+| Key | Value Type | Default Value | Version | Required | Description |
+| --- | ---------- | ------------- | ------- | -------- | ----------- |
+| `remote-shuffle.high-availability.mode` | String | `NONE` | 1.0.0 | false (Must be set if you want to enable HA) | Defines the high-availability mode used for the cluster execution. To enable high availability, set this mode to `ZOOKEEPER` or specify the FQN of a factory class. |
+| `remote-shuffle.ha.zookeeper.quorum` | String | `null` | 1.0.0 | false (Must be set if high-availability mode is ZOOKEEPER) | The ZooKeeper quorum to use when running the remote shuffle cluster in a high-availability mode with ZooKeeper. |
+| `remote-shuffle.ha.zookeeper.root-path` | String | `flink-remote-shuffle` | 1.0.0 | false | The root path in ZooKeeper under which the remote shuffle cluster stores its entries. Different remote shuffle clusters will be distinguished by the cluster id. This config must be consistent between the Flink cluster side and the shuffle cluster side. |
+| `remote-shuffle.ha.zookeeper.session-timeout` | Duration | `60s` | 1.0.0 | false | Defines the session timeout for the ZooKeeper session. |
+| `remote-shuffle.ha.zookeeper.connection-timeout` | Duration | `15s` | 1.0.0 | false | Defines the connection timeout for the ZooKeeper client. |
+| `remote-shuffle.ha.zookeeper.retry-wait` | Duration | `5s` | 1.0.0 | false | Defines the pause between consecutive connection retries. |
+| `remote-shuffle.ha.zookeeper.max-retry-attempts` | Integer | `3` | 1.0.0 | false | Defines the number of connection retries before the client gives up. |
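+
+As a minimal sketch, enabling ZooKeeper-based high availability on the Flink side could look like
+the following in the Flink configuration file (the quorum addresses below are hypothetical
+placeholders for your actual ZooKeeper cluster):
+
+```yaml
+remote-shuffle.high-availability.mode: ZOOKEEPER
+# Hypothetical quorum; replace with your own ZooKeeper servers.
+remote-shuffle.ha.zookeeper.quorum: zk1:2181,zk2:2181,zk3:2181
+remote-shuffle.ha.zookeeper.root-path: flink-remote-shuffle
+```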
+
+## Options for Shuffle Cluster
+
+This section presents the valid config options that can be used by the shuffle cluster. Note:
+Where to put the customized configuration depends on the deployment type. For k8s deployment, you
+should put these in the k8s deployment yaml file. For yarn deployment, you should put these in the
+yarn-site.xml for the `ShuffleWorker`s and in the conf/remote-shuffle-conf.yaml file for
+the `ShuffleManager`. For standalone and local deployment, you should put these in the
+conf/remote-shuffle-conf.yaml file.
+
+### High Availability Related (Server)
+
+| Key | Value Type | Default Value | Version | Required | Description |
+| --- | ---------- | ------------- | ------- | -------- | ----------- |
+| `remote-shuffle.cluster.id` | String | `/default-cluster` | 1.0.0 | false | The unique ID of the remote shuffle cluster used for high availability. Different shuffle clusters sharing the same ZooKeeper instance must be configured with different cluster IDs. This config must be consistent between the `ShuffleManager` and `ShuffleWorker`s. |
+| `remote-shuffle.high-availability.mode` | String | `NONE` | 1.0.0 | false (Must be set if you want to enable HA) | Defines the high-availability mode used for the cluster execution. To enable high availability, set this mode to `ZOOKEEPER` or specify the FQN of a factory class. |
+| `remote-shuffle.ha.zookeeper.quorum` | String | `null` | 1.0.0 | false (Must be set if high-availability mode is ZOOKEEPER) | The ZooKeeper quorum to use when running the remote shuffle cluster in a high-availability mode with ZooKeeper. |
+| `remote-shuffle.ha.zookeeper.root-path` | String | `flink-remote-shuffle` | 1.0.0 | false | The root path in ZooKeeper under which the remote shuffle cluster stores its entries. Different remote shuffle clusters will be distinguished by the cluster id. This config must be consistent between the Flink cluster side and the shuffle cluster side. |
+| `remote-shuffle.ha.zookeeper.session-timeout` | Duration | `60s` | 1.0.0 | false | Defines the session timeout for the ZooKeeper session. |
+| `remote-shuffle.ha.zookeeper.connection-timeout` | Duration | `15s` | 1.0.0 | false | Defines the connection timeout for the ZooKeeper client. |
+| `remote-shuffle.ha.zookeeper.retry-wait` | Duration | `5s` | 1.0.0 | false | Defines the pause between consecutive connection retries. |
+| `remote-shuffle.ha.zookeeper.max-retry-attempts` | Integer | `3` | 1.0.0 | false | Defines the number of connection retries before the client gives up. |
+
+### RPC & Heartbeat Related
+
+| Key | Value Type | Default Value | Version | Required | Description |
+| --- | ---------- | ------------- | ------- | -------- | ----------- |
+| `remote-shuffle.worker.heartbeat.interval` | Duration | `10s` | 1.0.0 | false | Time interval for the `ShuffleManager` to request heartbeats from `ShuffleWorker`s. |
+| `remote-shuffle.worker.heartbeat.timeout` | Duration | `60s` | 1.0.0 | false | Timeout for the `ShuffleManager` and `ShuffleWorker` to request and receive heartbeats. |
+| `remote-shuffle.rpc.timeout` | Duration | `30s` | 1.0.0 | false | Timeout for `ShuffleWorker` <-> `ShuffleManager` RPC calls. |
+| `remote-shuffle.rpc.akka-frame-size` | String | `10485760b` | 1.0.0 | false | Maximum size of messages that can be sent through RPC calls. |
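+
+For example, a minimal server-side conf/remote-shuffle-conf.yaml enabling high availability might
+look like the following sketch (the cluster id and quorum are illustrative placeholders):
+
+```yaml
+# conf/remote-shuffle-conf.yaml (illustrative values)
+remote-shuffle.cluster.id: /my-shuffle-cluster
+remote-shuffle.high-availability.mode: ZOOKEEPER
+remote-shuffle.ha.zookeeper.quorum: zk1:2181,zk2:2181,zk3:2181
+remote-shuffle.worker.heartbeat.timeout: 60s
+```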
+
+### ShuffleWorker Related
+
+| Key | Value Type | Default Value | Version | Required | Description |
+| --- | ---------- | ------------- | ------- | -------- | ----------- |
+| `remote-shuffle.storage.local-data-dirs` | String | `null` | 1.0.0 | true | Local file system directories to persist partitioned data to. Multiple directories can be configured, separated by a comma (,). Each configured directory can be attached with an optional label which indicates the disk type. The valid disk types include `SSD` and `HDD`. If no label is offered, the default type is `HDD`. Here is a simple valid configuration example: **`[SSD]/dir1/,[HDD]/dir2/,/dir3/`**. This option must be configured and the configured directories must exist. |
+| `remote-shuffle.storage.enable-data-checksum` | Bool | `false` | 1.0.0 | false | Whether to enable data checksums for data integrity verification or not. |
+| `remote-shuffle.memory.data-writing-size` | MemorySize | `4g` | 1.0.0 | false | Size of memory to be allocated for data writing. A larger value means more direct memory consumption, which may lead to better performance. The configured value must be no smaller than `64m` and the buffer size configured by `remote-shuffle.memory.buffer-size`, otherwise an exception will be thrown. |
+| `remote-shuffle.memory.data-reading-size` | MemorySize | `4g` | 1.0.0 | false | Size of memory to be allocated for data reading. A larger value means more direct memory consumption, which may lead to better performance. The configured value must be no smaller than `64m` and the buffer size configured by `remote-shuffle.memory.buffer-size`, otherwise an exception will be thrown. |
+| `remote-shuffle.memory.buffer-size` | MemorySize | `32k` | 1.0.0 | false | Size of the buffers to be allocated. The allocated buffers will be used by both network and storage for data transmission, data writing and data reading. |
+| `remote-shuffle.storage.preferred-disk-type` | String | `SSD` | 1.0.0 | false | Preferred disk type to use for data storage. The valid types include `SSD` and `HDD`. If there are disks of the preferred type, only those disks will be used. However, this is not a strict restriction, which means if there is no disk of the preferred type, disks of other types will also be used. |
+| `remote-shuffle.storage.hdd.num-executor-threads` | Integer | `8` | 1.0.0 | false | Number of threads to be used by the data store for data partition processing of each HDD. The actual number of threads per disk will be `min[configured value, 4 * (number of processors)]`. |
+| `remote-shuffle.storage.ssd.num-executor-threads` | Integer | `2147483647` | 1.0.0 | false | Number of threads to be used by the data store for data partition processing of each SSD. The actual number of threads per disk will be `min[configured value, 4 * (number of processors)]`. |
+| `remote-shuffle.storage.partition.max-writing-memory` | MemorySize | `128m` | 1.0.0 | false | Maximum memory size to use for the data writing of each data partition. Note that if the configured value is smaller than `16m`, the minimum `16m` will be used. |
+| `remote-shuffle.storage.partition.max-reading-memory` | MemorySize | `128m` | 1.0.0 | false | Maximum memory size to use for the data reading of each data partition. Note that if the configured value is smaller than `16m`, the minimum `16m` will be used. |
+| `remote-shuffle.storage.file-tolerable-failures` | Integer | `2147483647` | 1.0.0 | false | Maximum number of tolerable failures before marking a data partition as corrupted, which will trigger the reproduction of the corresponding data. |
+| `remote-shuffle.cluster.registration.timeout` | Duration | `5min` | 1.0.0 | false | Defines the timeout for the `ShuffleWorker` registration to the `ShuffleManager`. If the duration is exceeded without a successful registration, the `ShuffleWorker` terminates. |
+| `remote-shuffle.cluster.registration.error-delay` | Duration | `10s` | 1.0.0 | false | The pause made after a registration attempt caused an exception (other than a timeout). |
+| `remote-shuffle.cluster.registration.refused-delay` | Duration | `30s` | 1.0.0 | false | The pause made after a registration attempt was refused. |
+| `remote-shuffle.worker.host` | String | `null` | 1.0.0 | false | The external address of the network interface where the `ShuffleWorker` is exposed. If not set, it will be determined automatically. Note: Different workers may need different values for this option; usually it can be specified in a non-shared `ShuffleWorker`-specific configuration file. |
+| `remote-shuffle.worker.bind-policy` | String | `ip` | 1.0.0 | false | The automatic address binding policy used by the `ShuffleWorker` if `remote-shuffle.worker.host` is not set. The valid types include `name` and `ip`: `name` means using the hostname as the binding address, `ip` means using the host's IP address as the binding address. |
+| `remote-shuffle.worker.bind-host` | String | `0.0.0.0` | 1.0.0 | false | The local address of the network interface that the `ShuffleWorker` binds to. |
+| `remote-shuffle.worker.rpc-port` | String | `0` | 1.0.0 | false | Defines the network port range on which the `ShuffleWorker` expects incoming RPC connections. Accepts a list of ports ("50100,50101"), ranges ("50100-50200") or a combination of both. The default `0` means that the `ShuffleWorker` will search for a free port itself. |
+| `remote-shuffle.worker.rpc-bind-port` | Integer | `null` | 1.0.0 | false | The local network port that the `ShuffleWorker` binds to. If not configured, the external port (configured by `remote-shuffle.worker.rpc-port`) will be used. |
+
+### ShuffleManager Related
+
+| Key | Value Type | Default Value | Version | Required | Description |
+| --- | ---------- | ------------- | ------- | -------- | ----------- |
+| `remote-shuffle.manager.rpc-address` | String | `null` | 1.0.0 | false | Defines the external network address to connect to for communication with the `ShuffleManager`. |
+| `remote-shuffle.manager.rpc-bind-address` | String | `null` | 1.0.0 | false | The local address of the network interface that the `ShuffleManager` binds to. |
+| `remote-shuffle.manager.rpc-port` | Integer | `23123` | 1.0.0 | false | Defines the external network port to connect to for communication with the `ShuffleManager`. |
+| `remote-shuffle.manager.rpc-bind-port` | Integer | `null` | 1.0.0 | false | The local network port that the `ShuffleManager` binds to. If not configured, the external port (configured by `remote-shuffle.manager.rpc-port`) will be used. |
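+
+Putting the worker and manager options together, a bare-bones sketch of the server-side
+configuration could look like this (the address and directory paths below are hypothetical):
+
+```yaml
+# conf/remote-shuffle-conf.yaml (illustrative values)
+remote-shuffle.manager.rpc-address: 10.0.0.1   # hypothetical ShuffleManager address
+remote-shuffle.storage.local-data-dirs: '[SSD]/mnt/disk1/,[HDD]/mnt/disk2/'   # hypothetical dirs
+remote-shuffle.memory.data-writing-size: 4g
+remote-shuffle.memory.data-reading-size: 4g
+```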
+
+### Data Transmission Related (Server)
+
+| Key | Value Type | Default Value | Version | Required | Description |
+| --- | ---------- | ------------- | ------- | -------- | ----------- |
+| `remote-shuffle.transfer.transport-type` | String | `auto` | 1.0.0 | false | The Netty transport type, either `nio` or `epoll`. `auto` means selecting the proper mode automatically based on the platform. Note that the `epoll` mode can get better performance and less GC, and has more advanced features, which are only available on modern Linux. |
+| `remote-shuffle.transfer.send-receive-buffer-size` | MemorySize | `0b` | 1.0.0 | false | The Netty send and receive buffer size. The default `0b` means the system buffer size (cat /proc/sys/net/ipv4/tcp_[rw]mem), which is 4 MiB in modern Linux. |
+| `remote-shuffle.transfer.heartbeat.interval` | Duration | `1min` | 1.0.0 | false | The time interval at which to send heartbeats between the Netty server and Netty client. |
+| `remote-shuffle.transfer.heartbeat.timeout` | Duration | `5min` | 1.0.0 | false | Heartbeat timeout used to detect broken Netty connections. |
+
+### Metric Related
+
+| Key | Value Type | Default Value | Version | Required | Description |
+| --- | ---------- | ------------- | ------- | -------- | ----------- |
+| `remote-shuffle.metrics.enabled-http-server` | Bool | `true` | 1.0.0 | false | Whether the HTTP server for requesting metrics is enabled. |
+| `remote-shuffle.metrics.bind-host` | String | `0.0.0.0` | 1.0.0 | false | The local address of the network interface that the HTTP metric server binds to. |
+| `remote-shuffle.metrics.manager.bind-port` | Integer | `23101` | 1.0.0 | false | `ShuffleManager` HTTP metric server bind port. |
+| `remote-shuffle.metrics.worker.bind-port` | Integer | `23103` | 1.0.0 | false | `ShuffleWorker` HTTP metric server bind port. |
+
+## Options for Deployment
+
+### K8s Deployment Related
+
+This section presents the valid config options that can be used by k8s deployment; they should be
+put in the k8s deployment yaml file.
+
+| Key | Value Type | Default Value | Version | Required | Description |
+| --- | ---------- | ------------- | ------- | -------- | ----------- |
+| `remote-shuffle.kubernetes.container.image` | String | `null` | 1.0.0 | false (Must be set if running in k8s environment) | Image to use for the remote `ShuffleManager` and worker containers. |
+| `remote-shuffle.kubernetes.container.image.pull-policy` | String | `IfNotPresent` | 1.0.0 | false | The Kubernetes container image pull policy (`IfNotPresent` or `Always` or `Never`). The default policy is `IfNotPresent` to avoid putting pressure on the image repository. |
+| `remote-shuffle.kubernetes.host-network.enabled` | Bool | `true` | 1.0.0 | false | Whether to enable host network for the pod. Generally, host network is faster. |
+| `remote-shuffle.kubernetes.manager.cpu` | Double | `1.0` (It is better to increase this value for production usage) | 1.0.0 | false | The number of cpu used by the `ShuffleManager`. |
+| `remote-shuffle.kubernetes.manager.env-vars` | String | `''` | 1.0.0 | true | Env vars for the `ShuffleManager`. Specified as key:value pairs separated by commas. You need to specify the right timezone through this config option, for example, set the timezone as `TZ:Asia/Shanghai`. |
+| `remote-shuffle.kubernetes.manager.labels` | String | `''` | 1.0.0 | false | The user-specified labels to be set for the `ShuffleManager` pod. Specified as key:value pairs separated by commas. For example, `version:alphav1,deploy:test`. |
+| `remote-shuffle.kubernetes.manager.node-selector` | String | `''` | 1.0.0 | false | The user-specified node selector to be set for the `ShuffleManager` pod. Specified as key:value pairs separated by commas. For example, `environment:production,disk:ssd`. |
+| `remote-shuffle.kubernetes.manager.tolerations` | String | `''` | 1.0.0 | false | The user-specified tolerations to be set for the `ShuffleManager` pod. The value should be in the form of `key:key1,operator:Equal,value:value1,effect:NoSchedule;key:key2,operator:Exists,effect:NoExecute,tolerationSeconds:6000`. |
+| `remote-shuffle.kubernetes.worker.cpu` | Double | `1.0` (It is better to increase this value for production usage) | 1.0.0 | false | The number of cpu used by the `ShuffleWorker`. |
+| `remote-shuffle.kubernetes.worker.env-vars` | String | `''` | 1.0.0 | true | Env vars for the `ShuffleWorker`. Specified as key:value pairs separated by commas. You need to specify the right timezone through this config option, for example, set the timezone as `TZ:Asia/Shanghai`. |
+| `remote-shuffle.kubernetes.worker.volume.empty-dirs` | String | `''` | 1.0.0 | false | Specify the Kubernetes emptyDir volumes that will be mounted into the `ShuffleWorker` container. The value should be in the form of `name:disk1,sizeLimit:5Gi,mountPath:/opt/disk1;name:disk2,sizeLimit:5Gi,mountPath:/opt/disk2`. More specifically, `name` is the name of the volume, `sizeLimit` is the size limit of the volume and `mountPath` is the mount path in the container. |
+| `remote-shuffle.kubernetes.worker.volume.host-paths` | String | `''` | 1.0.0 | false (Either this or `remote-shuffle.kubernetes.worker.volume.empty-dirs` must be configured for k8s deployment) | Specify the Kubernetes hostPath volumes that will be mounted into the `ShuffleWorker` container. The value should be in the form of `name:disk1,path:/dump/1,mountPath:/opt/disk1;name:disk2,path:/dump/2,mountPath:/opt/disk2`. More specifically, `name` is the name of the volume, `path` is the directory location on the host and `mountPath` is the mount path in the container. |
+| `remote-shuffle.kubernetes.worker.labels` | String | `''` | 1.0.0 | false | The user-specified labels to be set for the `ShuffleWorker` pods. Specified as key:value pairs separated by commas. For example, `version:alphav1,deploy:test`. |
+| `remote-shuffle.kubernetes.worker.node-selector` | String | `''` | 1.0.0 | false | The user-specified node selector to be set for the `ShuffleWorker` pods. Specified as key:value pairs separated by commas. For example, `environment:production,disk:ssd`. |
+| `remote-shuffle.kubernetes.worker.tolerations` | String | `''` | 1.0.0 | false | The user-specified tolerations to be set for the `ShuffleWorker` pods. The value should be in the form of `key:key1,operator:Equal,value:value1,effect:NoSchedule;key:key2,operator:Exists,effect:NoExecute,tolerationSeconds:6000`. |
+| `remote-shuffle.kubernetes.manager.limit-factor.RESOURCE` | Integer | `1` | 1.0.0 | false | Kubernetes resource overuse limit factor for the `ShuffleManager`. It should not be less than 1. The `RESOURCE` could be cpu, memory, ephemeral-storage and all other types supported by Kubernetes. For example, `remote-shuffle.kubernetes.manager.limit-factor.cpu: 8`. |
+| `remote-shuffle.kubernetes.worker.limit-factor.RESOURCE` | Integer | `1` | 1.0.0 | false | Kubernetes resource overuse limit factor for the `ShuffleWorker`. It should not be less than 1. The `RESOURCE` could be cpu, memory, ephemeral-storage and all other types supported by Kubernetes. For example, `remote-shuffle.kubernetes.worker.limit-factor.cpu: 8`. |
+| `remote-shuffle.manager.memory.heap-size` | MemorySize | `4g` | 1.0.0 | false | Heap memory size to be used by the shuffle manager. |
+| `remote-shuffle.manager.memory.off-heap-size` | MemorySize | `128m` | 1.0.0 | false | Off-heap memory size to be used by the shuffle manager. |
+| `remote-shuffle.manager.jvm-opts` | String | `''` | 1.0.0 | false | Java options to start the JVM of the shuffle manager with. |
+| `remote-shuffle.worker.memory.heap-size` | MemorySize | `1g` | 1.0.0 | false | Heap memory size to be used by the shuffle worker. |
+| `remote-shuffle.worker.memory.off-heap-size` | MemorySize | `128m` | 1.0.0 | false | Off-heap memory size to be used by the shuffle worker. |
+| `remote-shuffle.worker.jvm-opts` | String | `''` | 1.0.0 | false | Java options to start the JVM of the shuffle worker with. |
+
+### Yarn Deployment Related
+
+This section presents the valid config options that can be used by yarn deployment; they should be
+put in the yarn-site.xml file.
+
+| Key | Value Type | Default Value | Version | Required | Description |
+| --- | ---------- | ------------- | ------- | -------- | ----------- |
+| `remote-shuffle.yarn.worker-stop-on-failure` | Bool | `false` | 1.0.0 | false | Flag indicating whether to throw the encountered exceptions to the upper Yarn service. If it is set to true, the upper Yarn service may be stopped because of the exceptions from the `ShuffleWorker`. Note: This parameter needs to be configured in yarn-site.xml. |
+
+### Standalone Deployment Related
+
+| Key | Value Type | Default Value | Version | Required | Description |
+| --- | ---------- | ------------- | ------- | -------- | ----------- |
+| `remote-shuffle.manager.memory.heap-size` | MemorySize | `4g` | 1.0.0 | false | Heap memory size to be used by the shuffle manager. |
+| `remote-shuffle.manager.memory.off-heap-size` | MemorySize | `128m` | 1.0.0 | false | Off-heap memory size to be used by the shuffle manager. |
+| `remote-shuffle.manager.jvm-opts` | String | `''` | 1.0.0 | false | Java options to start the JVM of the shuffle manager with. |
+| `remote-shuffle.worker.memory.heap-size` | MemorySize | `1g` | 1.0.0 | false | Heap memory size to be used by the shuffle worker. |
+| `remote-shuffle.worker.memory.off-heap-size` | MemorySize | `128m` | 1.0.0 | false | Off-heap memory size to be used by the shuffle worker. |
+| `remote-shuffle.worker.jvm-opts` | String | `''` | 1.0.0 | false | Java options to start the JVM of the shuffle worker with. |
diff --git a/docs/contribution.md b/docs/contribution.md
new file mode 100644
index 00000000..40f2644c
--- /dev/null
+++ b/docs/contribution.md
@@ -0,0 +1,64 @@
+
+
+# Contribution Guide
+
+This project will be improved continuously. We welcome any feedback and contributions to this
+project.
+
+## Code Style
+
+This project adopts a code style similar to Flink's. You can run the following command to format
+the code after you have made some changes.
+
+```bash
+mvn spotless:apply
+```
+
+## How to Contribute
+
+For collaboration, feel free to [contact us](../README.md#support). To report a bug, you can just
+open an issue on GitHub and attach the exceptions and your analysis, if any. For other
+improvements, you can contact us or open an issue first and describe what improvement you would
+like to make. After reaching a consensus, you can open a pull request, and your pull request will
+get merged after it has been reviewed.
+
+## Improvements on the Schedule
+
+There are already some further improvements on the schedule, and you are welcome to contact us for
+collaboration:
+
+1. Introduce a web UI and add more metrics for better usability.
+
+2. Implement ReducePartition.
+
+3. Support more storage backends.
+
+4. More graceful upgrading.
+
+5. Further performance improvements.
+
+6. Production-ready standalone deployment.
+
+7. Isolation and security enhancement.
+
+8. Support adaptive execution.
+
+and so on.
diff --git a/docs/deploy_on_kubernetes.md b/docs/deploy_on_kubernetes.md
new file mode 100644
index 00000000..b9c5a285
--- /dev/null
+++ b/docs/deploy_on_kubernetes.md
@@ -0,0 +1,253 @@
+
+
+# Running Remote Shuffle Service on Kubernetes
+
+- [Getting Started](#getting-started)
+    - [Introduction](#introduction)
+    - [Preparation](#preparation)
+- [Deploying Remote Shuffle Service Cluster](#deploying-remote-shuffle-service-cluster)
+    - [Deploying Remote Shuffle Operator](#deploying-remote-shuffle-operator)
+    - [Deploying Remote Shuffle Cluster](#deploying-remote-shuffle-cluster)
+- [Submitting a Flink Job](#submitting-a-flink-job)
+- [Logging & Configuration](#logging--configuration)
+
+## Getting Started
+This page describes how to deploy the remote shuffle service on Kubernetes. You can use the released image directly: docker.io/flinkremoteshuffle/flink-remote-shuffle:VERSION. Note that you need to replace the 'VERSION' field with the actual version you want to use, for example, 1.0.0.
+
+### Introduction
+Kubernetes is a popular container-orchestration system for automating application deployment, scaling, and management. The remote shuffle service allows you to directly deploy its services on a running Kubernetes cluster.
+
+### Preparation
+The `Getting Started` section assumes that your environment fulfills the following requirements:
+- A functional Kubernetes cluster (Kubernetes >= 1.13).
+
+- Make sure a valid ZooKeeper cluster is ready, or refer to [setting up a ZooKeeper cluster](https://zookeeper.apache.org/doc/current/zookeeperStarted.html) to start a ZooKeeper cluster manually.
+
+- [Download the latest binary release](https://github.com/flink-extended/flink-remote-shuffle/releases) or [build remote shuffle service yourself](https://github.com/flink-extended/flink-remote-shuffle#building-from-source).
+
+If you have problems setting up a Kubernetes cluster, take a look at [how to set up a Kubernetes cluster](https://kubernetes.io/docs/setup/).
+
+## Deploying Remote Shuffle Service Cluster
+The remote shuffle service cluster contains a `ShuffleManager` and multiple `ShuffleWorker`s. The `ShuffleManager` runs as a Kubernetes [Deployment](https://kubernetes.io/docs/concepts/workloads/controllers/deployment/) (the number of replicas is 1), and the shuffle workers run as a Kubernetes [DaemonSet](https://kubernetes.io/docs/concepts/workloads/controllers/daemonset/), which means the number of `ShuffleWorker`s is the same as the number of machines in the Kubernetes cluster. The following two points should be noted here:
+
+1. Currently, we only support host network for network communication.
+
+2. The shuffle workers use a [hostPath](https://kubernetes.io/docs/concepts/storage/volumes/#hostpath) volume (specified by `remote-shuffle.kubernetes.worker.volume.host-paths`) or an [emptyDir](https://kubernetes.io/docs/concepts/storage/volumes/#emptydir) volume (specified by `remote-shuffle.kubernetes.worker.volume.empty-dirs`) for shuffle data storage, as sketched in the example below.
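+
+For example, the following dynamic configs (a sketch with a hypothetical host path) mount a host
+directory into each `ShuffleWorker` and point the storage at it:
+
+```yaml
+# Illustrative only; adjust the host path and disk type to your environment.
+remote-shuffle.kubernetes.worker.volume.host-paths: name:disk,path:/mnt/disk1,mountPath:/data
+remote-shuffle.storage.local-data-dirs: '[SSD]/data'
+```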
+
+Additionally, to make deployment on a Kubernetes cluster easier, we provide a
+Kubernetes [Operator](https://kubernetes.io/docs/concepts/extend-kubernetes/operator/) for the
+remote shuffle service, which controls the life cycle of remote shuffle service instances,
+including creation, deletion, and upgrade.
+
+### Deploying Remote Shuffle Operator
+Once your Kubernetes cluster is ready and `kubectl` is configured to point to it, you can launch
+the operator via:
+
+```sh
+# Note: You must configure the docker image to be used by modifying the template file before running this command.
+
+kubectl apply -f kubernetes-shuffle-operator-template.yaml
+```
+
+The template file `kubernetes-shuffle-operator-template.yaml` is located in the `conf/` directory,
+and its contents are as follows.
+
+```yaml
+---
+apiVersion: rbac.authorization.k8s.io/v1
+kind: ClusterRole
+metadata:
+  name: flink-rss-cr
+rules:
+- apiGroups: ["apiextensions.k8s.io"]
+  resources:
+  - customresourcedefinitions
+  verbs:
+  - '*'
+- apiGroups: ["shuffleoperator.alibaba.com"]
+  resources:
+  - remoteshuffles
+  verbs:
+  - '*'
+- apiGroups: ["shuffleoperator.alibaba.com"]
+  resources:
+  - remoteshuffles/status
+  verbs:
+  - update
+- apiGroups: ["apps"]
+  resources:
+  - deployments
+  - daemonsets
+  verbs:
+  - '*'
+- apiGroups: [""]
+  resources:
+  - configmaps
+  verbs:
+  - '*'
+---
+apiVersion: rbac.authorization.k8s.io/v1
+kind: ClusterRoleBinding
+metadata:
+  name: flink-rss-crb
+roleRef:
+  apiGroup: rbac.authorization.k8s.io
+  kind: ClusterRole
+  name: flink-rss-cr
+subjects:
+- kind: ServiceAccount
+  name: flink-rss-sa
+  namespace: flink-system-rss
+---
+apiVersion: v1
+kind: ServiceAccount
+metadata:
+  name: flink-rss-sa
+  namespace: flink-system-rss
+---
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+  namespace: flink-system-rss
+  name: flink-remote-shuffle-operator
+spec:
+  replicas: 1
+  selector:
+    matchLabels:
+      app: flink-remote-shuffle-operator
+  template:
+    metadata:
+      labels:
+        app: flink-remote-shuffle-operator
+    spec:
+      serviceAccountName: flink-rss-sa
+      containers:
+        - name: flink-remote-shuffle-operator
+          image: # You need to configure the docker image to be used here.
+          imagePullPolicy: Always
+          command:
+            - bash
+          args:
+            - -c
+            - $JAVA_HOME/bin/java -classpath '/flink-remote-shuffle/opt/*' -Dlog4j.configurationFile=file:/flink-remote-shuffle/conf/log4j2-operator.properties -Dlog.file=/flink-remote-shuffle/log/operator.log com.alibaba.flink.shuffle.kubernetes.operator.RemoteShuffleApplicationOperatorEntrypoint
+```
+
+### Deploying Remote Shuffle Cluster
+Then you can start the `ShuffleManager` and `ShuffleWorker`s via:
+
+```sh
+# Note: You must fill in the template file before running this command.
+
+kubectl apply -f kubernetes-shuffle-cluster-template.yaml
+```
+
+The template file `kubernetes-shuffle-cluster-template.yaml` is located in the `conf/` directory,
+and its contents are as follows.
+
+```yaml
+apiVersion: shuffleoperator.alibaba.com/v1
+kind: RemoteShuffle
+metadata:
+  namespace: flink-system-rss
+  name: flink-remote-shuffle
+spec:
+  shuffleDynamicConfigs:
+    remote-shuffle.manager.jvm-opts: -verbose:gc -Xloggc:/flink-remote-shuffle/log/gc.log
+    remote-shuffle.worker.jvm-opts: -verbose:gc -Xloggc:/flink-remote-shuffle/log/gc.log
+    remote-shuffle.kubernetes.manager.cpu: 4
+    remote-shuffle.kubernetes.worker.cpu: 4
+    remote-shuffle.kubernetes.worker.limit-factor.cpu: 8
+    remote-shuffle.kubernetes.container.image:
+    remote-shuffle.kubernetes.worker.volume.host-paths: name:disk,path:,mountPath:/data
+    remote-shuffle.storage.local-data-dirs: '[SSD]/data'
+    remote-shuffle.high-availability.mode: ZOOKEEPER
+    remote-shuffle.ha.zookeeper.quorum:
+    remote-shuffle.kubernetes.manager.env-vars: # You need to configure your time zone here, for example, TZ:Asia/Shanghai.
+    remote-shuffle.kubernetes.worker.env-vars: # You need to configure your time zone here, for example, TZ:Asia/Shanghai.
+
+  shuffleFileConfigs:
+    log4j2.properties: |
+      monitorInterval=30
+
+      rootLogger.level = INFO
+      rootLogger.appenderRef.console.ref = ConsoleAppender
+      rootLogger.appenderRef.rolling.ref = RollingFileAppender
+
+      # Log all info to the console
+      appender.console.name = ConsoleAppender
+      appender.console.type = CONSOLE
+      appender.console.layout.type = PatternLayout
+      appender.console.layout.pattern = %d{yyyy-MM-dd HH:mm:ss,SSS} %-5p [%t] %-60c %x - %m%n
+
+      # Log all info in the given rolling file
+      appender.rolling.name = RollingFileAppender
+      appender.rolling.type = RollingFile
+      appender.rolling.append = true
+      appender.rolling.fileName = ${sys:log.file}
+      appender.rolling.filePattern = ${sys:log.file}.%i
+      appender.rolling.layout.type = PatternLayout
+      appender.rolling.layout.pattern = %d{yyyy-MM-dd HH:mm:ss,SSS} %-5p [%t] %-60c %x - %m%n
+      appender.rolling.policies.type = Policies
+      appender.rolling.policies.size.type = SizeBasedTriggeringPolicy
+      appender.rolling.policies.size.size=256MB
+      appender.rolling.policies.startup.type = OnStartupTriggeringPolicy
+      appender.rolling.strategy.type = DefaultRolloverStrategy
+      appender.rolling.strategy.max = ${env:MAX_LOG_FILE_NUMBER:-10}
+```
+
+- Note that `remote-shuffle.ha.zookeeper.quorum` should be filled in according to your actual
+  environment.
+
+- Note that `remote-shuffle.kubernetes.container.image` should be filled in with the shuffle
+  service image built from the source code.
+
+- Note that `remote-shuffle.kubernetes.worker.volume.host-paths` should be filled in with the
+  actual storage path on the host to be used.
+
+- Note that `remote-shuffle.kubernetes.manager.env-vars`
+  and `remote-shuffle.kubernetes.worker.env-vars` should be filled in to specify the right time
+  zone.
+
+If you want to build a new image by yourself, please refer
+to [preparing docker environment](https://docs.docker.com/get-docker/)
+and [building from source](https://github.com/flink-extended/flink-remote-shuffle#building-from-source).
+
+After successfully running the above command `kubectl apply -f XXX`, a new shuffle service cluster
+will be started.
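+
+As a quick sanity check (a sketch; `kubectl` access and the `flink-system-rss` namespace from the
+templates above are assumed), you can verify that the operator and the shuffle cluster pods are
+running:
+
+```sh
+# List the operator, ShuffleManager and ShuffleWorker pods.
+kubectl get pods -n flink-system-rss
+# Inspect the RemoteShuffle custom resource created from the cluster template.
+kubectl get remoteshuffles -n flink-system-rss
+```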
+
+## Submitting a Flink Job
+
+To submit a Flink job, please refer to [starting a Flink cluster](./quick_start.md#starting-a-flink-cluster) and [submitting a Flink job](./quick_start.md#submitting-a-flink-job).
+
+If you would like to run Flink jobs on Kubernetes, you need to follow the steps below:
+
+1. First of all, you need to build a new Flink docker image which contains the remote shuffle
+   plugin JAR file. Please refer
+   to [how to customize the Flink Docker image](https://nightlies.apache.org/flink/flink-docs-release-1.14/docs/deployment/resource-providers/standalone/docker/#advanced-customization)
+   for more information. The following is a simple customized Flink Dockerfile example:
+
+```dockerfile
+FROM flink
+
+# The shuffle plugin JAR should be taken from the lib/ directory of the remote shuffle
+# distribution; replace the path below with the real path in your environment.
+COPY <remote-shuffle-home>/lib/shuffle-plugin-<version>.jar /opt/flink/lib/
+```
+
+2. Then you should add the following configurations to `conf/flink-conf.yaml` in the extracted
+   Flink directory to configure Flink to use the remote shuffle service. Please note that the
+   configuration of `remote-shuffle.ha.zookeeper.quorum` should be exactly the same as that
+   in `kubernetes-shuffle-cluster-template.yaml`.
+
+```yaml
+shuffle-service-factory.class: com.alibaba.flink.shuffle.plugin.RemoteShuffleServiceFactory
+remote-shuffle.high-availability.mode: ZOOKEEPER
+remote-shuffle.ha.zookeeper.quorum: zk1.host:2181,zk2.host:2181,zk3.host:2181
+```
+
+3. Finally, you can start a Flink cluster on Kubernetes and submit a Flink job. Please refer
+   to [start a Flink cluster on Kubernetes](https://nightlies.apache.org/flink/flink-docs-release-1.14/docs/deployment/resource-providers/standalone/kubernetes/)
+   or [Flink natively on Kubernetes](https://nightlies.apache.org/flink/flink-docs-release-1.14/docs/deployment/resource-providers/native_kubernetes/)
+   for more information.
+
+## Logging & Configuration
+
+The above YAML templates show how to configure the remote shuffle service on Kubernetes.
+
+Kubernetes operator related options and the log output file are specified
+in `kubernetes-shuffle-operator-template.yaml`.
+
+Any configurations in the [configuration page](./configuration.md), as well as the log output
+format and log appender options of `ShuffleManager` and `ShuffleWorker`, are configured
+in `kubernetes-shuffle-cluster-template.yaml`.
+
diff --git a/docs/deploy_on_yarn.md b/docs/deploy_on_yarn.md
new file mode 100644
index 00000000..85866764
--- /dev/null
+++ b/docs/deploy_on_yarn.md
@@ -0,0 +1,160 @@
+
+
+# Running Remote Shuffle Service on YARN
+
+- [Getting Started](#getting-started)
+  - [Introduction](#introduction)
+  - [Preparation](#preparation)
+  - [Starting a ShuffleManager on YARN](#starting-a-shufflemanager-on-yarn)
+  - [Starting ShuffleWorkers on YARN](#starting-shuffleworkers-on-yarn)
+- [Submitting a Flink Job](#submitting-a-flink-job)
+- [Logging](#logging)
+- [Supported Hadoop versions](#supported-hadoop-versions)
+
+## Getting Started
+This *Getting Started* section guides you through setting up a fully functional Flink remote
+shuffle service on YARN.
+
+### Introduction
+[YARN](https://hadoop.apache.org/docs/current/hadoop-yarn/hadoop-yarn-site/YARN.html) is a very
+popular resource management framework. The remote shuffle service for Flink can be deployed on
+YARN. `ShuffleManager` and `ShuffleWorker`s are started on YARN in different ways:
+the `ShuffleManager` runs in the ApplicationMaster container of a special YARN application, while
+the `ShuffleWorker`s run
+as [Auxiliary Services](https://hadoop.apache.org/docs/stable/hadoop-mapreduce-client/hadoop-mapreduce-client-core/PluggableShuffleAndPluggableSort.html)
+on `NodeManager`s.
+
+### Preparation
+This *Getting Started* section assumes a functional YARN environment (>= 2.4.1).
+YARN environments are provided most conveniently through services such as Alibaba Cloud, Google
+Cloud DataProc, Amazon EMR, or products like Cloudera, etc. You can also refer
+to [setting up a YARN environment locally](https://hadoop.apache.org/docs/current/hadoop-project-dist/hadoop-common/SingleCluster.html)
+or [setting up a YARN cluster](https://hadoop.apache.org/docs/current/hadoop-project-dist/hadoop-common/ClusterSetup.html)
+to set up a YARN environment manually.
+
+In addition, a Zookeeper cluster is used for high availability. When the manager is started in a
+YARN container, its address information is written to Zookeeper, from which the `ShuffleWorker`s
+and the `ShuffleClient` can obtain the manager's address.
+
+So the prerequisites to deploy the remote shuffle service on YARN are as follows.
+
+- Make sure your YARN cluster is ready for accepting YARN applications.
+- [Download the latest binary release](https://github.com/flink-extended/flink-remote-shuffle/releases) or [build remote shuffle service yourself](https://github.com/flink-extended/flink-remote-shuffle#building-from-source).
+- **Important** Make sure the `HADOOP_YARN_HOME` environment variable is set up. Use the
+  command `yarn` or `$HADOOP_YARN_HOME/bin/yarn` to check whether the environment variable is set
+  successfully; the output should not display any errors.
+- Make sure a valid Zookeeper cluster is ready, or refer
+  to [setting up a Zookeeper cluster](https://zookeeper.apache.org/doc/current/zookeeperStarted.html)
+  to start a Zookeeper cluster manually.
+
+### Starting a ShuffleManager on YARN
+Once the above prerequisites are ready, you can start a `ShuffleManager` on YARN:
+
+```sh
+# We assume to be in the root directory of the unzipped distribution
+
+# Start a ShuffleManager on YARN
+./bin/yarn-shufflemanager.sh start --am-mem-mb 4096 -q root.default -D remote-shuffle.high-availability.mode=ZOOKEEPER -D remote-shuffle.ha.zookeeper.quorum=zk1.host:2181,zk2.host:2181,zk3.host:2181
+
+# Stop the ShuffleManager on YARN
+./bin/yarn-shufflemanager.sh stop
+
+# Use the following command to display detailed usage.
+./bin/yarn-shufflemanager.sh -h
+```
+
+The `ShuffleManager` can be configured through the `./bin/yarn-shufflemanager.sh` script using
+dynamic parameters. Any configuration in the [configuration page](./configuration.md) supported by
+the remote shuffle service can be passed in as a parameter of the script, such
+as `-D <key>=<value>`. You can also put these configurations in `conf/remote-shuffle-conf.yaml`.
+
+By querying the metric server or checking the container output log, you can **check whether the
+manager is started** successfully. The default metric server address
+is `http://<manager-ip>:23101/metrics/`.
+
+You have now successfully started a `ShuffleManager` on YARN. If any exception occurs, the
+current `ShuffleManager` may stop and a new `ShuffleManager` will start in a new container, which
+will not affect the running Flink batch jobs.
+
+After the `ShuffleManager` is started, you can start `ShuffleWorker`s on the `NodeManager`s; they
+will register with the started `ShuffleManager` based on the address obtained from Zookeeper.
+
+### Starting ShuffleWorkers on YARN
+Each `ShuffleWorker` starts as an auxiliary service on a `NodeManager`.
+To start the `ShuffleWorker` on each `NodeManager`, follow these instructions:
+
+1. Locate the `shuffle-dist-<version>.jar`. If you compiled the project manually, it should be
+   under the `./build-target/lib/` directory. If you use the downloaded distribution, it should be
+   under the `./lib/` directory.
+
+2. Add the `shuffle-dist-<version>.jar` to the classpath of each `NodeManager` in your YARN
+   cluster. To achieve this, you can either add the following command
+   to `<hadoop-home>/conf/yarn-env.sh` or move the `shuffle-dist-<version>.jar`
+   to `<hadoop-home>/share/hadoop/yarn/`.
+
+```sh
+export HADOOP_CLASSPATH=$HADOOP_CLASSPATH:<path-to>/shuffle-dist-<version>.jar
+```
+
+3. In the `etc/hadoop/yarn-site.xml` on each `NodeManager`, add the following configurations.
+   Please note that the Zookeeper address `remote-shuffle.ha.zookeeper.quorum` should be the same
+   as the address used by the `ShuffleManager` at
+   startup. `remote-shuffle.storage.local-data-dirs` specifies the local file system directories
+   to persist partitioned data to. To prevent the `ShuffleWorker` from shutting down because of a
+   registration timeout while no `ShuffleManager` is available, set the timeout to a very large
+   value via `remote-shuffle.cluster.registration.timeout`, for example, 1 year. Any configuration
+   in the [configuration page](./configuration.md) can be added to **`yarn-site.xml`** to control
+   the behavior of the `ShuffleWorker`.
+
+```xml
+<property>
+    <name>remote-shuffle.high-availability.mode</name>
+    <value>ZOOKEEPER</value>
+</property>
+
+<property>
+    <name>remote-shuffle.ha.zookeeper.quorum</name>
+    <value>zk1.host:2181,zk2.host:2181,zk3.host:2181</value>
+</property>
+
+<property>
+    <name>remote-shuffle.storage.local-data-dirs</name>
+    <value>[SSD]/PATH1/rss_data_dir1/,[HDD]/PATH2/rss_data_dir2/</value>
+</property>
+
+<!-- Set the registration timeout to a very large value, for example, 1 year. -->
+<property>
+    <name>remote-shuffle.cluster.registration.timeout</name>
+    <value>31536000000</value>
+</property>
+```
+
+4. Restart all `NodeManager`s in your cluster. Note that the heap and direct memory size of
+   the `NodeManager` should be increased to avoid `out of memory` problems. The heap memory used
+   by the `ShuffleWorker` is mainly controlled by `remote-shuffle.worker.memory.heap-size`, which
+   is 1g by default. The direct memory used
+   includes `remote-shuffle.worker.memory.off-heap-size` (128m by
+   default), `remote-shuffle.memory.data-writing-size` (4g by default)
+   and `remote-shuffle.memory.data-reading-size` (4g by default). In total, please increase
+   the `NodeManager` memory by at least 1g of heap and 8.2g of direct memory by default. In your
+   production environment, you can adjust these configurations to change the memory usage of
+   the `ShuffleWorker`.
+
+Alternatively, when starting a local standalone or YARN cluster on your laptop, you can
+reduce `remote-shuffle.memory.data-reading-size` or `remote-shuffle.memory.data-writing-size` to
+decrease the memory usage of the `ShuffleWorker`, for example, to 128m. For more `ShuffleWorker`
+configurations, please refer to the [configuration page](./configuration.md).
+
+By querying the metric server or checking the container output log, you can **check whether a
+worker is started** successfully. The default metric server address
+is `http://<worker-ip>:23103/metrics/`.
+
+Now you have successfully started a remote shuffle service cluster on YARN.
+
+## Submitting a Flink Job
+
+To submit a Flink job, please refer to [starting a Flink cluster](./quick_start.md#starting-a-flink-cluster) and [submitting a Flink job](./quick_start.md#submitting-a-flink-job).
+
+## Logging
+**ShuffleManager Log**
+
+For a `ShuffleManager` running on YARN, the log is output to the container log. You can
+use `conf/log4j2.properties` to modify the log level, log output format, etc., for example as
+sketched below.
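+
+A minimal sketch (it assumes the default `rootLogger.level = INFO` entry used by the
+bundled `log4j2.properties` files in this project):
+
+```sh
+# Switch the ShuffleManager root log level from INFO to DEBUG in place.
+sed -i 's/rootLogger.level = INFO/rootLogger.level = DEBUG/' conf/log4j2.properties
+```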
+
+To enable the GC log or modify other JVM GC options for the `ShuffleManager`,
+add `remote-shuffle.yarn.manager-am-jvm-options` to `conf/remote-shuffle-conf.yaml`. The following
+is a simple example:
+
+```yaml
+remote-shuffle.yarn.manager-am-jvm-options: -verbose:gc -XX:NewRatio=3 -XX:+PrintGCDetails -XX:+PrintGCDateStamps -XX:ParallelGCThreads=4 -XX:+UseGCLogFileRotation
+```
+
+**ShuffleWorker Log**
+
+For a `ShuffleWorker` running on YARN as an auxiliary service of the `NodeManager`, its log is
+output to the `NodeManager` log by default.
+
+If you want to separate its logs from the `NodeManager` logs by directory or modify the log level,
+you can modify `conf/log4j.properties` in the **Hadoop** directory and restart all `NodeManager`s.
+Here is an example.
+
+```properties
+# Flink Remote Shuffle Service Logs
+flink.shuffle.logger=INFO,FLINKRSS
+flink.shuffle.log.maxfilesize=512MB
+flink.shuffle.log.maxbackupindex=20
+log4j.logger.com.alibaba.flink.shuffle=${flink.shuffle.logger}
+log4j.additivity.com.alibaba.flink.shuffle=false
+log4j.appender.FLINKRSS=org.apache.log4j.RollingFileAppender
+log4j.appender.FLINKRSS.File=/flink-remote-shuffle/log/flink-remote-shuffle.log
+log4j.appender.FLINKRSS.layout=org.apache.log4j.PatternLayout
+log4j.appender.FLINKRSS.layout.ConversionPattern=%d{ISO8601} %p %c{2}: %m%n
+log4j.appender.FLINKRSS.MaxFileSize=${flink.shuffle.log.maxfilesize}
+log4j.appender.FLINKRSS.MaxBackupIndex=${flink.shuffle.log.maxbackupindex}
+```
+
+In this way, `ShuffleWorker` logs will be separated from `NodeManager` logs. All logs of the
+classes starting with `com.alibaba.flink.shuffle` will be output
+to `/flink-remote-shuffle/log/flink-remote-shuffle.log`.
+
+## Supported Hadoop versions
+The remote shuffle service for Flink on YARN is compiled against Hadoop 2.4.1, and all Hadoop
+versions >= 2.4.1 are supported, including Hadoop 3.x.
+
diff --git a/docs/deploy_standalone_mode.md b/docs/deploy_standalone_mode.md
new file mode 100644
index 00000000..30eb32b9
--- /dev/null
+++ b/docs/deploy_standalone_mode.md
@@ -0,0 +1,164 @@
+
+
+# Remote Shuffle Service Standalone Mode
+
+- [Introduction](#introduction)
+- [Preparation](#preparation)
+- [Cluster Quick Start Script](#cluster-quick-start-script)
+- [Single Component Management Script](#single-component-management-script)
+- [Submitting a Flink Job](#submitting-a-flink-job)
+- [Logging & JVM arguments](#logging--jvm-arguments)
+
+## Introduction
+In addition to running on the YARN and Kubernetes cluster managers, the remote shuffle service
+also provides a simple standalone deploy mode. In this mode, you have to take care of restarting
+failed processes and allocating resources during operation yourself.
+
+This page mainly introduces two methods to start a standalone cluster:
+1. Start a cluster with management scripts. With this method, cluster management becomes easier.
+   After simple configuration, you can easily start the `ShuffleManager` and `ShuffleWorker`s with
+   one line of command.
+2. Start the `ShuffleManager` and `ShuffleWorker`s one by one with separate commands, which
+   facilitates independent management of individual components.
+
+The following are two important config options to be used:
+
+| Key | Meaning |
+| -------- | ------- |
+|`remote-shuffle.manager.rpc-address` | The network IP address to connect to for communication with the shuffle manager. Only the IP address, without a port.|
+|`remote-shuffle.storage.local-data-dirs` | Local file system directories to persist partitioned data to. Multiple directories can be configured, separated by a comma ','. For example, in `[SSD]/PATH1/rss_data_dir1/,[HDD]/PATH2/rss_data_dir2/`, the prefixes `[HDD]` and `[SSD]` indicate the disk type.|
+
+## Preparation
+The remote shuffle service runs on all UNIX-like environments, e.g. Linux and Mac OS X. Before you
+start the standalone cluster, make sure your system fulfills the following requirements.
+
+- Java 1.8.x or higher installed.
+- [Download the latest binary release](https://github.com/flink-extended/flink-remote-shuffle/releases) or [build remote shuffle service yourself](https://github.com/flink-extended/flink-remote-shuffle#building-from-source).
+
+## Cluster Quick Start Script
+**Starting and Stopping a cluster**
+
+`bin/start-cluster.sh` and `bin/stop-cluster.sh` rely on `conf/workers` to determine the number of
+cluster component instances. Note that you should only start one `ShuffleWorker` instance per
+physical node.
+
+If password-less SSH access to the listed machines is configured, and they share the same
+directory structure, the scripts can start and stop instances remotely.
+
+***Example: Start a distributed shuffle service cluster with 2 ShuffleWorkers***
+
+At present, only one manager is supported in standalone mode.
+
+Contents of `conf/managers`:
+
+```
+manager1
+```
+
+`manager1` is the **actual IP address** where the `ShuffleManager` is started. Only one address is
+required. If multiple addresses are filled in, the first one will be used. The `ShuffleManager`
+RPC port can be configured by `remote-shuffle.manager.rpc-port`
+in `conf/remote-shuffle-conf.yaml`; if not configured, port 23123 will be used by default.
+
+Note that if you want to start `ShuffleWorker`s on multiple machines, you need to replace the
+default `127.0.0.1` in `conf/managers` with the actual IP address. Otherwise, workers started on
+other machines cannot get the right `ShuffleManager` IP address and cannot connect to the manager.
+
+If you want to start only one `ShuffleWorker` and the `ShuffleWorker` is on the same machine where
+the `ShuffleManager` is started, the default `127.0.0.1` in `conf/managers` is sufficient.
+
+Contents of `conf/workers`:
+
+```
+worker1
+worker2
+```
+
+Note that two workers means that at least two physical nodes are needed; if you only have one
+physical node, please configure only one worker here.
+
+Then you can execute the following command to start the cluster:
+
+```sh
+./bin/start-cluster.sh -D remote-shuffle.storage.local-data-dirs="[SSD]/PATH/rss_data_dir/"
+```
+
+When executing `bin/start-cluster.sh`, please make sure the following requirements are met.
+- Password-less SSH access to the listed machines is configured.
+- The same directory structure of the remote shuffle distribution exists on each listed machine.
+- A shuffle data directory (`[SSD]/PATH/rss_data_dir/`, matching the command above) has been
+  created on each `ShuffleWorker` and the directory permissions allow access.
+
+You can use "-D" to pass any configuration option in the [configuration page](./configuration.md)
+as a parameter of `./bin/start-cluster.sh` to control the behavior of the `ShuffleManager`
+and `ShuffleWorker`s, for example, `-D <key>=<value>`. These configurations can also be set
+in `conf/remote-shuffle-conf.yaml`.
+
+After running `bin/start-cluster.sh`, the output log is as follows.
+
+```sh
+Starting cluster.
+Starting shufflemanager daemon on host your-host.
+Starting shuffleworker daemon on host your-host.
+```
+
+If `bin/start-cluster.sh` is executed successfully, you have started a remote shuffle service
+cluster with one manager and 2 workers.
+The `ShuffleManager` log and `ShuffleWorker` logs will be output to the `log/` directory by
+default.
+
+You can stop the cluster with the command:
+
+```sh
+./bin/stop-cluster.sh
+```
+
+The above illustrates a convenient way to manage a cluster.
+
+## Single Component Management Script
+
+This section describes another way to start a cluster.
+
+**Starting a ShuffleManager**
+
+You can start a standalone `ShuffleManager` by executing:
+
+```sh
+./bin/shufflemanager.sh start -D remote-shuffle.manager.rpc-address=<manager-ip-address>
+```
+
+Note that `<manager-ip-address>` must be a real address that can be connected to from the outside,
+not `127.0.0.1` or `localhost`.
+
+You can use "-D" to pass any option in the [configuration page](./configuration.md) to
+the `ShuffleManager`, for example, `-D <key>=<value>`. These configurations can also be set
+in `conf/remote-shuffle-conf.yaml`.
+
+By querying the metric server or checking the output log, you can **check whether the manager is
+started** successfully. The default metric server address is `http://<manager-ip>:23101/metrics/`.
+
+Before starting a `ShuffleWorker`, you need to create a directory to store shuffle data files;
+this is a mandatory configuration option. For example, the directory created
+is `[SSD]/PATH/rss_data_dir/`, which is on an SSD type disk.
+
+**Starting ShuffleWorkers**
+
+Similarly, you can start one or more workers and connect them to the manager via:
+
+```sh
+./bin/shuffleworker.sh start -D remote-shuffle.manager.rpc-address=<manager-ip-address> -D remote-shuffle.storage.local-data-dirs="[SSD]/PATH/rss_data_dir/"
+```
+
+You can use "-D" to pass any option in the [configuration page](./configuration.md) to
+the `ShuffleWorker`, for example, `-D <key>=<value>`. These configurations can also be set
+in `conf/remote-shuffle-conf.yaml`.
+
+By querying the metric server or checking the output log, you can **check whether a worker is
+started** successfully. The default metric server address is `http://<worker-ip>:23103/metrics/`.
+
+## Submitting a Flink Job
+To submit a Flink job, please refer to [starting a Flink cluster](./quick_start.md#starting-a-flink-cluster) and [submitting a Flink job](./quick_start.md#submitting-a-flink-job).
+
+## Logging & JVM Arguments
+**Configuring Log4j**
+
+You can modify `log4j2.properties` to control the log output format, log level, etc. For example,
+change `rootLogger.level=INFO` to `rootLogger.level=DEBUG` to enable debug logging.
+
+**Configuring Log Options and JVM Arguments**
+
+Add configuration options to `conf/remote-shuffle-conf.yaml` to configure the log directory,
+memory size, JVM GC log, JVM GC options, etc. Here is a general example.
+
+```yaml
+remote-shuffle.manager.jvm-opts: -verbose:gc -Dlog.file=/flink-remote-shuffle/log/shufflemanager.log -Xloggc:/flink-remote-shuffle/log/shufflemanager.gc.log -XX:NewRatio=3 -XX:+PrintGCDetails -XX:+PrintGCDateStamps -XX:ParallelGCThreads=4 -XX:+UseGCLogFileRotation -XX:NumberOfGCLogFiles=2 -XX:GCLogFileSize=256M
+
+remote-shuffle.worker.jvm-opts: -verbose:gc -Dlog.file=/flink-remote-shuffle/log/shuffleworker.log -Xloggc:/flink-remote-shuffle/log/shuffleworker.gc.log -XX:NewRatio=3 -XX:+PrintGCDetails -XX:+PrintGCDateStamps -XX:ParallelGCThreads=4 -XX:+UseGCLogFileRotation -XX:NumberOfGCLogFiles=2 -XX:GCLogFileSize=256M
+```
+
+According to this example configuration, logs and GC logs will be output
+to `/flink-remote-shuffle/log/`.
+By default, no GC log is output and the logs are written to the `log/` directory of the
+distribution. Please modify or add parameters to control the log output and JVM GC according to
+your production environment.
+
diff --git a/docs/imgs/map-partition.png b/docs/imgs/map-partition.png
new file mode 100644
index 00000000..928aad3e
Binary files /dev/null and b/docs/imgs/map-partition.png differ
diff --git a/docs/imgs/reduce-partition.png b/docs/imgs/reduce-partition.png
new file mode 100644
index 00000000..6a05c16d
Binary files /dev/null and b/docs/imgs/reduce-partition.png differ
diff --git a/docs/imgs/remote-shuffle.png b/docs/imgs/remote-shuffle.png
new file mode 100644
index 00000000..ad44ffdd
Binary files /dev/null and b/docs/imgs/remote-shuffle.png differ
diff --git a/docs/imgs/support-en.jpeg b/docs/imgs/support-en.jpeg
new file mode 100644
index 00000000..8e32c8c0
Binary files /dev/null and b/docs/imgs/support-en.jpeg differ
diff --git a/docs/imgs/support-zh.jpeg b/docs/imgs/support-zh.jpeg
new file mode 100644
index 00000000..dcc699c8
Binary files /dev/null and b/docs/imgs/support-zh.jpeg differ
diff --git a/docs/quick_start.md b/docs/quick_start.md
new file mode 100644
index 00000000..b5f6f775
--- /dev/null
+++ b/docs/quick_start.md
@@ -0,0 +1,140 @@
+
+
+# Quick Start
+
+- [Introduction](#introduction)
+- [Build & Download](#build---download)
+- [Browsing the Project Directory](#browsing-the-project-directory)
+- [Starting Clusters](#starting-clusters)
+  * [Starting a Standalone Remote Shuffle Cluster](#starting-a-standalone-remote-shuffle-cluster)
+  * [Starting a Flink Cluster](#starting-a-flink-cluster)
+- [Submitting a Flink Job](#submitting-a-flink-job)
+- [Where to Go from Here](#where-to-go-from-here)
+
+
+This tutorial provides a quick introduction to using the remote shuffle service for Flink. This
+short guide will show you how to download the latest stable version of the remote shuffle service,
+install it and run it. You will also run an example Flink job using the remote shuffle service and
+observe it in the Flink web UI.
+
+## Introduction
+Following the steps in this quick start guide, you will run a simple Flink batch job which uses
+the remote shuffle service:
+1. Download the remote shuffle service and Flink binary releases.
+2. Start a standalone remote shuffle service cluster.
+3. Start a Flink local cluster that uses the remote shuffle service for data shuffling.
+4. Submit a simple Flink batch job to the started cluster.
+
+## Build & Download
+The remote shuffle service runs on all UNIX-like environments, e.g. Linux and Mac OS X. You need
+to have Java 8 installed. To check the Java version installed, type in your terminal:
+
+```sh
+java -version
+```
+
+1. [Download the latest binary release](https://github.com/flink-extended/flink-remote-shuffle/releases) or [build remote shuffle service yourself](https://github.com/flink-extended/flink-remote-shuffle#building-from-source).
+2. Next, [download the latest binary release](https://flink.apache.org/downloads.html) of Flink,
+   then extract the archive:
+
+```sh
+tar -xzf flink-*.tgz
+```
+
+For the steps of downloading or installing Flink, you can also refer
+to [Flink first steps](https://nightlies.apache.org/flink/flink-docs-release-1.14//docs/try-flink/local_installation/).
+
+## Browsing the Project Directory
+Navigate to the extracted shuffle service directory and list the contents by issuing:
+
+```sh
+cd flink-remote-shuffle* && ls -l
+```
+
+You should see some directories as follows.
+
+| Directory | Meaning |
+|--|--|
+|`bin/` | Directory containing several bash scripts of the remote shuffle service that manage the `ShuffleManager` and `ShuffleWorker`s.|
+|`conf/` | Directory containing configuration files, including `remote-shuffle-conf.yaml`, `log4j2.properties`, etc.|
+|`lib/` | Directory containing the compiled remote shuffle service JARs, including `shuffle-dist-*.jar`, log4j JARs, etc.|
+|`log/` | Log directory, initially empty. When running a standalone shuffle service cluster, the logs of the `ShuffleManager` and `ShuffleWorker`s will be stored in this directory by default.|
+|`opt/` | Directory containing the optional JARs used in some special environments, for example, `shuffle-kubernetes-operator-*.jar` is used when deploying on Kubernetes.|
+|`examples/` | Directory containing several demo example JARs. |
+
+## Starting Clusters
+### Starting a Standalone Remote Shuffle Cluster
+Please refer to [how to start a standalone remote shuffle cluster](./deploy_standalone_mode.md#cluster-quick-start-script).
+
+### Starting a Flink Cluster
+Before starting a Flink local cluster,
+1. Make sure that a valid remote shuffle service cluster has been successfully started.
+2. You need to copy the shuffle plugin JAR from the remote shuffle `lib` directory (for
+   example, `lib/shuffle-plugin-*.jar`) to the Flink `lib` directory, as sketched right after this
+   list.
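+
+A sketch of step 2 (the Flink directory is a placeholder to be replaced with your own path):
+
+```sh
+# Run from the root of the extracted remote shuffle distribution:
+# copy the shuffle plugin JAR into Flink's lib/ directory.
+cp lib/shuffle-plugin-*.jar <flink-dir>/lib/
+```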
+
+For different startup modes of the remote shuffle service, the Flink job configurations are
+different; the details are as follows.
+
+- For the standalone remote shuffle service, please add the following configurations
+  to `conf/flink-conf.yaml` in the extracted Flink directory to use the remote shuffle service
+  when running a Flink batch job. The argument `<manager-ip-address>` is the IP address of
+  the `ShuffleManager` (for a local remote shuffle cluster, it should be 127.0.0.1).
+
+```yaml
+shuffle-service-factory.class: com.alibaba.flink.shuffle.plugin.RemoteShuffleServiceFactory
+remote-shuffle.manager.rpc-address: <manager-ip-address>
+```
+
+- For the remote shuffle service on YARN or Kubernetes, please add the following configurations
+  to `conf/flink-conf.yaml` in the extracted Flink directory to use the remote shuffle service
+  when running a Flink batch job. `remote-shuffle.ha.zookeeper.quorum` is the Zookeeper address of
+  the `ShuffleManager` when high availability is enabled.
+
+```yaml
+shuffle-service-factory.class: com.alibaba.flink.shuffle.plugin.RemoteShuffleServiceFactory
+remote-shuffle.high-availability.mode: ZOOKEEPER
+remote-shuffle.ha.zookeeper.quorum: zk1.host:2181,zk2.host:2181,zk3.host:2181
+```
+
+Please refer to the following links for the different deployment modes of Flink:
+- For a standalone Flink cluster, please refer to [Flink standalone mode](https://nightlies.apache.org/flink/flink-docs-release-1.14/docs/deployment/resource-providers/standalone/overview/).
+- For a Flink cluster on YARN, please refer to [Flink on YARN](https://nightlies.apache.org/flink/flink-docs-release-1.14/docs/deployment/resource-providers/yarn/).
+- For a Flink cluster on Kubernetes, please refer to [Flink on Kubernetes](https://nightlies.apache.org/flink/flink-docs-release-1.14/docs/deployment/resource-providers/standalone/kubernetes/) or [natively on Kubernetes](https://nightlies.apache.org/flink/flink-docs-release-1.14/docs/deployment/resource-providers/native_kubernetes/).
+
+Usually, starting a local Flink cluster by running the following command is enough for this quick
+start guide:
+
+```sh
+# We assume to be in the root directory of the Flink extracted distribution
+
+./bin/start-cluster.sh
+```
+
+You should be able to navigate to the web UI at `http://<host>:8081` to view the Flink dashboard
+and see that the cluster is up and running.
+
+Because the shuffle-related configurations have been modified, all jobs submitted to this Flink
+cluster will use the remote shuffle service to shuffle data.
+
+## Submitting a Flink Job
+
+After starting the Flink cluster successfully, you can submit a simple Flink batch demo job.
+
+The example source code is in the `shuffle-examples` module. `BatchJobDemo` is a simple Flink
+batch job. You need to copy the compiled demo JAR `examples/BatchJobDemo.jar` to the extracted
+Flink directory. Please run the following command to submit the example batch job.
+
+```sh
+# First, copy the example JAR to the Flink directory:
+# cp examples/BatchJobDemo.jar <flink-dir>
+
+# We assume to be in the root directory of the Flink extracted distribution
+
+./bin/flink run ./BatchJobDemo.jar
+```
+
+You have successfully run a Flink batch job using the remote shuffle service.
+
+## Where to Go from Here
+Congratulations on running your first Flink application using the remote shuffle service!
+
+- For an in-depth overview of the remote shuffle service, start with the [user guide](./user_guide.md).
+- For running applications on a cluster, head to the [deployment overview](./user_guide.md#deployment).
+- For detailed configurations, refer to the [configuration page](./configuration.md).
diff --git a/docs/user_guide.md b/docs/user_guide.md
new file mode 100644
index 00000000..77838cf3
--- /dev/null
+++ b/docs/user_guide.md
@@ -0,0 +1,400 @@
+
+
+# User Guide
+
+- [Architecture](#architecture)
+  - [Components](#components)
+  - [High Availability](#high-availability)
+  - [Data Storage](#data-storage)
+  - [Data Transmission](#data-transmission)
+- [How to Use](#how-to-use)
+  - [Deployment](#deployment)
+  - [Configuration](#configuration)
+  - [Operations](#operations)
+  - [Fault Tolerance](#fault-tolerance)
+  - [Best Practices](#best-practices)
+  - [Checklist](#checklist)
+
+This document presents the basic architecture and some fundamental concepts of the remote shuffle
+system. At the same time, it gives guidelines about how to deploy and run Flink batch jobs using
+the remote shuffle service.
+
+## Architecture
+
+![The Architecture of Remote Shuffle Service](./imgs/remote-shuffle.png)
+
+### Components
+
+The remote shuffle process involves the interaction of several important components:
+
++ **ShuffleMaster:** `ShuffleMaster`, as an important part of Flink's pluggable shuffle
+  architecture, is the intermediate result partition registry used by Flink's `JobMaster`.
+
++ **ShuffleManager:** `ShuffleManager` is a centralized shuffle cluster supervisor which is
+  responsible for assigning shuffle resources (shuffle data storage) to jobs using the remote
+  shuffle service. Each shuffle cluster has one active `ShuffleManager`.
+
++ **ShuffleWorker:** `ShuffleWorker` is the storage of shuffle data and is managed by
+  the `ShuffleManager`.
+
++ **ShuffleClient:** `ShuffleClient` runs in the `ShuffleMaster` and communicates with
+  the `ShuffleManager`. It acts as an agent between the Flink job and the remote shuffle service
+  to allocate and release shuffle resources.
+
++ **WriteClient:** `WriteClient` is responsible for writing the shuffle data to the
+  remote `ShuffleWorker` for the corresponding Flink task.
+
++ **ReadClient:** `ReadClient` is responsible for reading the shuffle data from the
+  remote `ShuffleWorker` for the corresponding Flink task.
+
++ **DataPartition:** `DataPartition` is the data unit managed by the `ShuffleWorker`. There are
+  two types of `DataPartition`. See the [data storage](#data-storage) section for more
+  information.
+
+The overall remote shuffle process is as follows:
+
+1. The Flink job scheduler allocates shuffle resources for a task's intermediate output from
+   the `ShuffleMaster`. The `ShuffleMaster` sends a request to the `ShuffleManager` through
+   the `ShuffleClient`, and the `ShuffleManager` assigns the requested shuffle resource. After
+   that, a `ShuffleDescriptor` containing the shuffle resource information (the
+   target `ShuffleWorker` address) will be returned to the Flink job scheduler.
+
+2. The Flink job scheduler passes the `ShuffleDescriptor` information when deploying tasks. The
+   deployed tasks then write their shuffle data to the target `ShuffleWorker` contained in
+   the `ShuffleDescriptor`.
+
+3. For the downstream consumer tasks, the Flink job scheduler passes the `ShuffleDescriptor`s of
+   the upstream tasks when scheduling them. Each downstream consumer task then knows where to read
+   shuffle data from and consumes that data.
+
+### High Availability
+
+The remote shuffle service relies on [Zookeeper](https://zookeeper.apache.org/) for high
+availability. The `ShuffleManager` writes its communication address to Zookeeper, and
+the `ShuffleClient` and `ShuffleWorker`s get this information and communicate with
+the `ShuffleManager` through RPC. For Yarn and Kubernetes deployment, if the `ShuffleManager` goes
+down for some reason, such as its node going offline, a new `ShuffleManager` instance will be
+started, and the `ShuffleClient` and `ShuffleWorker`s will be notified of the
+new `ShuffleManager`'s communication address (see the [deployment](#deployment) section for more
+information).
+
+### Data Storage
+
+Splitting a large dataset into multiple smaller partitions for parallel processing is the basis of
+the MapReduce computation model. The remote shuffle system can be seen as a MapReduce-aware
+storage for this partitioned data. Each `ShuffleWorker` manages a portion of the shuffle data,
+and `DataPartition` is the smallest data unit. There are two types of `DataPartition`:
+
+**MapPartition:** A `MapPartition` contains all data produced by an upstream `MapTask`. It may
+contain data to be consumed by multiple `ReduceTask`s.
If the `MapTask` has multiple outputs, each +will be a `MapPartition`. The following picture shows the storage structure of a `MapPartition` and +its relationship with `MapTask` and `ReduceTask`: + +
+![The Storage Structure of MapPartition](./imgs/map-partition.png)
+
+**ReducePartition:** A `ReducePartition` contains all data that is produced by all
+upstream `MapTask`s and will be consumed by one `ReduceTask`. If the `ReduceTask` has multiple
+inputs, each will be a `ReducePartition`. The following picture shows the storage structure of
+a `ReducePartition` and its relationship with `MapTask` and `ReduceTask`:
+
+![The Storage Structure of ReducePartition](./imgs/reduce-partition.png)
+
+**Note:** `ReducePartition` is not completely implemented yet and will be implemented in a future
+version. The `MapPartition` implementation already achieves competitive performance.
+
+There are several built-in `DataPartition` implementations that can be used, and you can choose
+one by changing the config value of `remote-shuffle.job.data-partition-factory-name` at the Flink
+job side. All supported `DataPartition` implementations are as follows:
+
+| Factory Name | Version | Description |
+| ------------ | ------- | ----------- |
+| `com.alibaba.flink.shuffle.storage.partition.LocalFileMapPartitionFactory` | 1.0.0 | A type of data partition which writes map tasks' output data to local files. It can use both HDD and SSD; by default only the SSD will be used if there is any, because the default value of `remote-shuffle.storage.preferred-disk-type` is `SSD`. However, if there is no SSD, the configured HDD will be used. |
+| `com.alibaba.flink.shuffle.storage.partition.HDDOnlyLocalFileMapPartitionFactory` | 1.0.0 | Similar to `com.alibaba.flink.shuffle.storage.partition.LocalFileMapPartitionFactory` but it will always use HDD. |
+| `com.alibaba.flink.shuffle.storage.partition.SSDOnlyLocalFileMapPartitionFactory` | 1.0.0 | Similar to `com.alibaba.flink.shuffle.storage.partition.LocalFileMapPartitionFactory` but it will always use SSD. |
+
+### Data Transmission
+
+Flink tasks write data to and read data from the remote `ShuffleWorker`s. A task's output data is
+split into multiple `DataRegion`s. A `DataRegion` is the basic data unit to be transferred and
+written between Flink tasks and `ShuffleWorker`s, and it can contain multiple data buffers. Each
+data region is a piece of data that can be consumed independently, which means a data region can
+not contain any partial records and data compression never spans multiple data regions. As a
+result, the `ShuffleWorker` can freely rearrange the data regions consumed by the same data
+consumer (for `ReducePartition`).
+
+Before being written to the remote `ShuffleWorker`, the shuffle data is compressed at the Flink
+task side at per-buffer granularity, and the shuffle data is decompressed at the data consumer
+task side after being read from the remote `ShuffleWorker`.
+
+In addition, the remote shuffle system adopts Flink's credit-based backpressure mechanism, all
+memory for data transmission & storage is managed, and the TCP connection from the
+same `TaskExecutor` to the same remote `ShuffleWorker` is reused. All of this improves system
+stability.
+
+## How to Use
+
+Basically, there are three steps to use the remote shuffle service:
+
+1. Deploy a remote shuffle cluster. See the [deployment](#deployment) section for more
+   information.
+
+2. Configure Flink to use the remote shuffle service and run your batch jobs. See
+   the [configuration](#configuration) section for more information.
+
+3. Monitor and operate the remote shuffle cluster. See the [operations](#operations) section for
+   more information.
+
+Furthermore, there are some [best practices](#best-practices) and a [checklist](#checklist) that
+may help if you are using the remote shuffle system in production. For a quick start, please refer
+to the [quick start guide](./quick_start.md).
+
+### Deployment
+
+The remote shuffle system supports three different deployment modes: standalone, Yarn and
+Kubernetes. It relies on Zookeeper for high availability.
+If you want to enable high availability, you must have an available Zookeeper service first
+(deploy one or reuse an existing one).
+
+For the standalone deployment mode, you can either enable or disable high availability (it is
+disabled by default). If high availability is not enabled, you must configure the `ShuffleManager`
+RPC address explicitly. See the [standalone deployment guide](./deploy_standalone_mode.md) for
+more information. One weakness of the standalone mode is that it cannot tolerate the loss of
+the `ShuffleManager` node, so the standalone deployment mode is currently not suggested for
+production. In the future, a standby `ShuffleManager` may be introduced to solve this problem.
+
+For the Yarn deployment mode, you must enable high availability, and a Yarn environment is
+required. The `ShuffleWorker`s run in the Yarn `NodeManager`s as auxiliary services, and
+the `ShuffleManager` runs as an independent Yarn application. The Yarn deployment mode can
+tolerate `ShuffleManager` and `ShuffleWorker` crashes. See
+the [Yarn deployment guide](./deploy_on_yarn.md) for more information.
+
+For the Kubernetes deployment mode, like the Yarn deployment mode, you must enable high
+availability, and a Kubernetes environment is required. The `ShuffleManager`
+and `ShuffleWorker`s run as Kubernetes applications. The Kubernetes deployment mode can also
+tolerate `ShuffleManager` and `ShuffleWorker` crashes. See
+the [Kubernetes deployment guide](./deploy_on_kubernetes.md) for more information.
+
+### Configuration
+
+Before deploying the remote shuffle cluster, you need to configure it properly. For Kubernetes and
+Yarn deployment, you must enable the high availability service:
+
+```yaml
+remote-shuffle.high-availability.mode: ZOOKEEPER
+remote-shuffle.ha.zookeeper.quorum: XXX
+```
+
+For standalone deployment, if high availability is disabled, you must configure
+the `ShuffleManager` RPC address (the `ShuffleManager` IP address):
+
+```yaml
+remote-shuffle.manager.rpc-address: XXX
+```
+
+For shuffle data storage, you must configure the storage directories, and the configured
+directories must exist on all the `ShuffleWorker` nodes:
+
+```yaml
+remote-shuffle.storage.local-data-dirs: [SSD]/dir1,[HDD]/dir2
+```
+
+If you are using the Kubernetes deployment, aside from `remote-shuffle.storage.local-data-dirs`
+you should also configure the directory on the host machine to be mounted into
+the `ShuffleWorker` pod:
+
+```yaml
+remote-shuffle.kubernetes.worker.volume.host-paths: name:disk1,path:/data/1,mountPath:/opt/disk1
+```
+
+At the Flink side, you first need to add the shuffle plugin JAR (usually named
+shuffle-plugin-XXX.jar, found in the shuffle-plugin/target directory of the remote shuffle
+project) to Flink's classpath by copying it to Flink's lib directory.
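+
+A sketch of this copy step (the Flink home path is a placeholder; with a binary distribution, the
+plugin JAR can also be taken from its `lib/` directory instead):
+
+```sh
+# Copy the shuffle plugin JAR into Flink's classpath.
+cp shuffle-plugin/target/shuffle-plugin-*.jar <flink-home>/lib/
+```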
+Then you need to configure Flink to use the remote shuffle service by adding the following
+configuration to Flink's configuration file:
+
+```yaml
+shuffle-service-factory.class: com.alibaba.flink.shuffle.plugin.RemoteShuffleServiceFactory
+```
+
+At the same time, if high availability is enabled, you need to add the following high availability
+options to the Flink configuration file:
+
+```yaml
+remote-shuffle.high-availability.mode: ZOOKEEPER
+remote-shuffle.ha.zookeeper.quorum: XXX
+```
+
+If high availability is not enabled, you need to add the `ShuffleManager` RPC address
+(the `ShuffleManager` IP address) to the Flink configuration file:
+
+```yaml
+remote-shuffle.manager.rpc-address: XXX
+```
+
+For the Kubernetes deployment mode, you need to configure the docker image to be used:
+
+```yaml
+remote-shuffle.kubernetes.container.image: XXX
+```
+
+For more information about the configuration, please refer to
+the [configuration document](./configuration.md).
+
+### Operations
+
+**Metrics:** There are some important metrics that may help you to monitor the cluster, and more
+metrics will be added in the future. You can get the `ShuffleManager` and `ShuffleWorker` metrics
+by requesting the metrics server, for example, http://IP:PORT/metrics/ (see the example after the
+table). The IP should be the `ShuffleManager` or `ShuffleWorker` external IP address, and the PORT
+should be the corresponding metric server port (`23101` for the `ShuffleManager` and `23103` for
+the `ShuffleWorker` by default). All supported metrics are as follows:
+
+| Metric Key | Type | Version | Description |
+| ---------- | ---- | ------- | ----------- |
+| `remote-shuffle.cluster.num_shuffle_workers` | Integer | 1.0.0 | Number of available shuffle workers in the remote shuffle cluster. |
+| `remote-shuffle.cluster.num_jobs_serving` | Integer | 1.0.0 | Number of jobs under serving in the remote shuffle cluster. |
+| `remote-shuffle.storage.num_data_partitions` | Integer | 1.0.0 | Number of `DataPartition`s stored in each `ShuffleWorker`. |
+| `remote-shuffle.storage.num_available_writing_buffers` | Integer | 1.0.0 | Number of available memory buffers for data writing in each `ShuffleWorker`. |
+| `remote-shuffle.storage.num_available_reading_buffers` | Integer | 1.0.0 | Number of available memory buffers for data reading in each `ShuffleWorker`. |
+| `remote-shuffle.network.num_writing_connections` | Integer | 1.0.0 | Number of tcp connections currently used for data writing in each `ShuffleWorker`. |
+| `remote-shuffle.network.num_reading_connections` | Integer | 1.0.0 | Number of tcp connections currently used for data reading in each `ShuffleWorker`. |
+| `remote-shuffle.network.num_writing_flows` | Integer | 1.0.0 | Number of data writing channels currently used for data writing in each `ShuffleWorker`. Multiple writing channels may multiplex the same tcp connection. |
+| `remote-shuffle.network.num_reading_flows` | Integer | 1.0.0 | Number of data reading channels currently used for data reading in each `ShuffleWorker`. Multiple reading channels may multiplex the same tcp connection. |
+| `remote-shuffle.network.writing_throughput_bytes` | Double | 1.0.0 | Shuffle data writing throughput in bytes of each `ShuffleWorker` (including 1min, 5min and 15min). |
+| `remote-shuffle.network.reading_throughput_bytes` | Double | 1.0.0 | Shuffle data reading throughput in bytes of each `ShuffleWorker` (including 1min, 5min and 15min). |
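+
+For example, with the default ports, the metrics can be fetched via plain HTTP (the addresses are
+placeholders for the external IPs of your instances):
+
+```sh
+# Fetch ShuffleManager metrics (default metric port 23101).
+curl http://<manager-ip>:23101/metrics/
+# Fetch ShuffleWorker metrics (default metric port 23103).
+curl http://<worker-ip>:23103/metrics/
+```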
+
+**Upgrading:** Currently, to upgrade the remote shuffle service, you need to first stop the
+previous cluster and start a new one. Note: this can lead to the failover of running jobs which
+are using the remote shuffle service. Hot upgrade without influencing the running jobs will be
+implemented in a future version.
+
+### Fault Tolerance
+
+The remote shuffle system can tolerate exceptions at runtime and recover by itself in most cases.
+The following table lists important exceptions you may encounter, together with the expected
+influence and how the remote shuffle system handles these exceptions.
+
+| Exception | Influence | Handling |
+| --------- | --------- | -------- |
+| Flink task failure (including failure caused by network issues) | Task failover | For writing, all relevant resources including data will be cleaned up. For reading, data will be kept and other resources will be cleaned up |
+| Shuffle resource allocation failure | The task is not deployed yet and will be re-scheduled | No resource to clean up because the task is not running yet |
+| Exceptions encountered while the remote shuffle system is handling data (including send, receive, write and read, etc.) | Lead to failover of the relevant Flink tasks | Similar to a single Flink task failure, all resources of the relevant tasks are cleaned up |
+| Loss or corruption of shuffle data | Leads to failure of the data consumer tasks and re-run of the data producer tasks | Throw a special exception to Flink which triggers failure of the data consumer tasks and re-run of the data producer tasks. At the same time, relevant resources including the corrupted data will be cleaned up |
+| Crash of a Flink `TaskManager` | Relevant Flink task failover | Same as the failure of multiple Flink tasks; the relevant resources will be cleaned up |
+| Crash of a Flink `JobMaster` | Failover and re-run of the entire job | Clean up all relevant resources of the job, including data |
+| Crash of a `ShuffleWorker` | Leads to failure of the Flink tasks which are writing to or reading from the `ShuffleWorker` | The remote shuffle system relies on the external resource management system (YARN, Kubernetes, etc.) to start a new `ShuffleWorker` instance. Data already produced will be taken over, and the relevant resources of the failed tasks will be cleaned up. If the new instance can not be started in time (by default 3min), data stored in the failed `ShuffleWorker` will be reproduced |
+| Crash of the `ShuffleManager` | Flink jobs can not allocate any new shuffle resources for a while, which leads to the re-scheduling of the corresponding tasks | The remote shuffle system relies on the external resource management system (YARN, Kubernetes, etc.) to start a new `ShuffleManager` instance. All `ShuffleWorker`s will register and report their information to the new `ShuffleManager` instance. If the new `ShuffleManager` instance can not be started, remote shuffle will become unavailable (note that this is virtually impossible, because the `ShuffleManager` instance is a plain Java process and can be started on any node). |
+| `ShuffleClient` registration failure | Leads to restart of the Flink `JobMaster` and failover of the whole job | Same as a crash of the `JobMaster`; the relevant resources will be cleaned up |
+| Startup issue caused by an illegal configuration, etc. | Remote shuffle will be unavailable | It can not recover by itself; you need to check the log, find out the root cause and correct it |
+| Unavailability of `Zookeeper` | Remote shuffle will be unavailable | It can not recover by itself; you need to recover `Zookeeper` manually as soon as possible |
+
+
+### Best Practices
+
+There are some best practices that may help you to improve performance and stability.
+
+1. By default, the remote shuffle service uses the host network for data communication in the
+   Kubernetes deployment, which should give better performance. Disabling the host network is not
+   suggested.
+
+2. The default configuration should be good enough for medium-scale batch jobs. If you are running
+   large-scale batch jobs, increasing `remote-shuffle.job.memory-per-partition`
+   and `remote-shuffle.job.memory-per-gate` may help to increase performance.
+
+3. For the Kubernetes deployment, we suggest increasing the CPU of the `ShuffleManager`
+   and `ShuffleWorker`s. At the same time, you can increase the CPU limit, for
+   example: `remote-shuffle.kubernetes.manager.cpu: 8.0`
+   , `remote-shuffle.kubernetes.worker.cpu: 4.0`
+   , `remote-shuffle.kubernetes.manager.limit-factor.cpu: 4`
+   , `remote-shuffle.kubernetes.worker.limit-factor.cpu: 4`.
+
+4. For large-scale jobs, if there are not enough resources to run all tasks of the same stage
+   concurrently, it is suggested to decrease the parallelism, which can lead to better
+   performance.
+
+5. You should always specify the data storage disk type (SSD or HDD) explicitly for better
+   performance.
+
+6. If you are using SSD for data storage, you can
+   decrease `remote-shuffle.job.concurrent-readings-per-gate` to reduce the network stress on
+   the `ShuffleWorker`s. If you are using HDD, it is suggested to keep the default config value
+   for better performance.
+
+7. If your network is slow, increasing `remote-shuffle.transfer.send-receive-buffer-size` may help
+   to solve the network write timeout issue of shuffle data transmission.
+
+8. You can add all disks (including SSD and HDD) at the shuffle cluster side. Then you can freely
+   choose to use SSD or HDD only by configuring `remote-shuffle.job.data-partition-factory-name`
+   to either `com.alibaba.flink.shuffle.storage.partition.HDDOnlyLocalFileMapPartitionFactory`
+   or `com.alibaba.flink.shuffle.storage.partition.SSDOnlyLocalFileMapPartitionFactory` at the
+   Flink job side, as shown in the sketch after this list.
+
+9. The default memory configuration for the `ShuffleWorker` should be good enough for most use
+   cases. But if there are a large number of jobs and the `ShuffleWorker`s are under pressure,
+   increasing `remote-shuffle.memory.data-writing-size`
+   and `remote-shuffle.memory.data-reading-size` can help increase the `ShuffleWorker`s' service
+   capacity. If the `ShuffleWorker`s' service capacity is enough and you would like to reduce the
+   memory consumption, you can also decrease these values.
+
+10. The default memory configuration for the `ShuffleManager` should be good enough for clusters
+    of several thousands of nodes. If your cluster is larger than that, it is better to increase
+    the heap size of the `ShuffleManager` by increasing `remote-shuffle.manager.memory.heap-size`
+    for the Kubernetes deployment mode or XXX for the Yarn deployment mode.
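+
+For instance, point 8 could look like this at the Flink side (a sketch; pick the factory that
+matches your disks):
+
+```yaml
+# In Flink's conf/flink-conf.yaml: store shuffle data on SSD only.
+remote-shuffle.job.data-partition-factory-name: com.alibaba.flink.shuffle.storage.partition.SSDOnlyLocalFileMapPartitionFactory
+```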
+
+### Checklist
+
+Here is a list of items that you may need to check when you are deploying and using the remote
+shuffle service.
+
+1. To use the remote shuffle service, you must
+   add `shuffle-service-factory.class: com.alibaba.flink.shuffle.plugin.RemoteShuffleServiceFactory`
+   to the Flink configuration.
+
+2. If high availability is not enabled, you must configure the shuffle manager RPC address by
+   setting `remote-shuffle.manager.rpc-address` for both the remote shuffle cluster and the Flink
+   cluster.
+
+3. For the high availability deployment mode, you must
+   configure `remote-shuffle.ha.zookeeper.quorum` for both the remote shuffle cluster and the
+   Flink cluster.
+
+4. For the Kubernetes deployment mode, you must configure the docker image by
+   setting `remote-shuffle.kubernetes.container.image`, and at the same time you must configure
+   either `remote-shuffle.kubernetes.worker.volume.empty-dirs`
+   or `remote-shuffle.kubernetes.worker.volume.host-paths` to set up the data storage.
+
+5. For all deployment modes, you must configure the data storage directory by
+   setting `remote-shuffle.storage.local-data-dirs`.
+
+6. There are some default ports which may lead to the `address already in use` exception, in
+   which case you may need to change the default values (see the sketch after this
+   list): `remote-shuffle.manager.rpc-port: 23123`
+   , `remote-shuffle.metrics.manager.bind-port: 23101`
+   , `remote-shuffle.metrics.worker.bind-port: 23103`
+   , `remote-shuffle.transfer.server.data-port: 10086`.
+
+7. For Kubernetes and Yarn deployment, you must enable the high-availability mode.
+
+8. For the Kubernetes deployment mode, you need to set the configuration in the Kubernetes
+   deployment YAML file. For the Yarn deployment mode, you need to set the `ShuffleWorker`
+   configuration in the yarn-site.xml file and the `ShuffleManager` configuration in the remote
+   shuffle configuration file (conf/remote-shuffle-conf.yaml). For the standalone deployment mode,
+   you need to set the configuration in the remote shuffle configuration file
+   (conf/remote-shuffle-conf.yaml).
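+
+As referenced in item 6, a sketch of overriding the default ports
+in `conf/remote-shuffle-conf.yaml` (the new port numbers below are arbitrary free-port examples):
+
+```yaml
+remote-shuffle.manager.rpc-port: 23124
+remote-shuffle.metrics.manager.bind-port: 24101
+remote-shuffle.metrics.worker.bind-port: 24103
+remote-shuffle.transfer.server.data-port: 10087
+```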
diff --git a/pom.xml b/pom.xml new file mode 100644 index 00000000..87babb39 --- /dev/null +++ b/pom.xml @@ -0,0 +1,534 @@ + + + + + 4.0.0 + + com.alibaba.flink.shuffle + flink-shuffle-parent + pom + 1.0-SNAPSHOT + + + + The Apache Software License, Version 2.0 + https://www.apache.org/licenses/LICENSE-2.0.txt + repo + + + + + 2.11 + 13.0 + 2.12.1 + 4.13 + 2.21.0 + 1.3 + 1.3.2 + 11.0.2 + 1.7.15 + 2.12.1 + 2.4.2 + 1.8 + ${target.java.version} + ${target.java.version} + 2.12.1 + 1.27 + 2.12.0 + 2.4.1 + 2.0.6 + 1.3.1 + + 3.4.14-14.0 + 1.14.0 + + + + shuffle-common + shuffle-coordinator + shuffle-core + shuffle-kubernetes + shuffle-plugin + shuffle-dist + shuffle-e2e-tests + shuffle-storage + shuffle-transfer + shuffle-kubernetes-operator + shuffle-metrics + shuffle-yarn + shuffle-rpc + shuffle-examples + + + + + + + org.slf4j + slf4j-api + + + + + com.google.code.findbugs + jsr305 + + + + + + junit + junit + jar + test + + + + org.hamcrest + hamcrest-all + ${hamcrest.version} + jar + test + + + + + + org.apache.logging.log4j + log4j-slf4j-impl + test + + + + org.apache.logging.log4j + log4j-api + test + + + + org.apache.logging.log4j + log4j-core + test + + + + + + + + javax.annotation + javax.annotation-api + ${javax.annotation-api.version} + + + + + com.google.code.findbugs + jsr305 + 1.3.9 + + + + junit + junit + ${junit.version} + + + + org.slf4j + slf4j-api + ${slf4j.version} + + + + org.apache.logging.log4j + log4j-slf4j-impl + ${log4j.version} + + + + org.apache.logging.log4j + log4j-api + ${log4j.version} + + + + org.apache.logging.log4j + log4j-core + ${log4j.version} + + + + org.apache.logging.log4j + log4j-1.2-api + ${log4j.version} + + + + org.mockito + mockito-core + ${mockito.version} + jar + test + + + + + org.yaml + snakeyaml + ${snakeyaml.version} + + + + com.fasterxml.jackson + jackson-bom + pom + import + ${jackson.version} + + + + + + + default + + true + + + **/*E2ETest.java + + + + includeE2E + + + + + + + + + + + org.apache.rat + apache-rat-plugin + 0.12 + false + + + verify + + check + + + + + false + 0 + + + + AL2 + Apache License 2.0 + + + Licensed to the Apache Software Foundation (ASF) under + one + + + + + + + Apache License 2.0 + + + + **/.*/** + **/*.log + + **/README.md + + **/*.iml + + build-target/** + **/target/** + + **/shuffle-bin/conf/workers + **/shuffle-bin/conf/managers + **/resources/data_for_storage_compatibility_test/** + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + 2.22.1 + + + org.apache.maven.surefire + surefire-logger-api + 2.21.0 + true + + + + 1 + true + false + + 0${surefire.forkNumber} + ${project.basedir} + ${project.build.directory} + + -Xms256m -Xmx2048m -Dmvn.forkNumber=${surefire.forkNumber} + -XX:+UseG1GC + + + + + default-test + test + + test + + + + ${exclude-test-classes} + + + + + + + integration-tests + integration-test + + test + + + + **/*ITCase.* + + false + + + + + + + org.apache.maven.plugins + maven-checkstyle-plugin + 2.17 + + + com.puppycrawl.tools + checkstyle + + 8.14 + + + + + validate + validate + + check + + + + + /tools/maven/suppressions.xml + true + /tools/maven/checkstyle.xml + true + true + + + + + org.apache.maven.plugins + maven-jar-plugin + 2.2 + + + + test-jar + + + + + + + org.apache.maven.plugins + maven-source-plugin + 3.2.1 + + + attach-sources + + jar + + + + + + + com.diffplug.spotless + spotless-maven-plugin + ${spotless.version} + + + + 1.7 + + + + + + + com.alibaba.flink.shuffle,org.apache.flink,org.apache.flink.shaded,,javax,java,scala,\# + + + + + + + + + spotless-check + validate 
+ + check + + + + + + + org.apache.maven.plugins + maven-enforcer-plugin + 3.0.0-M1 + + + enforce-maven + + enforce + + + + + + [3.1.1,) + + + ${target.java.version} + + + + + + ban-unsafe-snakeyaml + + enforce + + + + + + org.yaml:snakeyaml:(,1.26] + + + + org.yaml:snakeyaml:(,1.26]:*:test + + + + + + + ban-unsafe-jackson + + enforce + + + + + + com.fasterxml.jackson*:*:(,2.12.0] + + + + + + + forbid-log4j-1 + + enforce + + + + + + log4j:log4j + org.slf4j:slf4j-log4j12 + + + + + + + dependency-convergence + + none + + enforce + + + + + + + + + + + + diff --git a/shuffle-common/pom.xml b/shuffle-common/pom.xml new file mode 100644 index 00000000..4162e84c --- /dev/null +++ b/shuffle-common/pom.xml @@ -0,0 +1,34 @@ + + + + + + com.alibaba.flink.shuffle + flink-shuffle-parent + 1.0-SNAPSHOT + + 4.0.0 + + shuffle-common + + + + + diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/config/ConfigOption.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/config/ConfigOption.java new file mode 100644 index 00000000..67ef27f6 --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/config/ConfigOption.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.config; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; + +import javax.annotation.Nullable; + +/** Utility class to define config options. */ +public class ConfigOption { + + private final String key; + + @Nullable private T defaultValue; + + private String description; + + public ConfigOption(String key) { + CommonUtils.checkArgument(key != null, "Must be not null."); + this.key = key; + } + + public String key() { + return key; + } + + public T defaultValue() { + return defaultValue; + } + + public ConfigOption defaultValue(@Nullable T defaultValue) { + this.defaultValue = defaultValue; + return this; + } + + public String description() { + return description; + } + + public ConfigOption description(String description) { + CommonUtils.checkArgument(description != null, "Must be not null."); + + this.description = description; + return this; + } + + @Override + public String toString() { + return key; + } +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/config/Configuration.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/config/Configuration.java new file mode 100644 index 00000000..8f2c28bb --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/config/Configuration.java @@ -0,0 +1,717 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.config; + +import com.alibaba.flink.shuffle.common.exception.ConfigurationException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.common.utils.TimeUtils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.time.Duration; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.stream.Collectors; + +import static com.alibaba.flink.shuffle.common.config.StructuredOptionsSplitter.escapeWithSingleQuote; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument; + +/** A simple read-only configuration implementation based on {@link Properties}. */ +public class Configuration { + + private static final Logger LOG = LoggerFactory.getLogger(Configuration.class); + + public static final String REMOTE_SHUFFLE_CONF_FILENAME = "remote-shuffle-conf.yaml"; + + private final Properties configuration; + + public Configuration() { + this.configuration = new Properties(); + } + + public Configuration(String confDir) throws IOException { + CommonUtils.checkArgument(confDir != null, "Must be not null."); + + this.configuration = loadConfiguration(confDir); + } + + public Configuration(Properties configuration) { + CommonUtils.checkArgument(configuration != null, "Must be not null."); + + this.configuration = new Properties(); + this.configuration.putAll(configuration); + } + + /** Dynamic configuration has higher priority than that loaded from configuration file. 
*/ + public Configuration(String confDir, Properties dynamicConfiguration) throws IOException { + CommonUtils.checkArgument(confDir != null, "Must be not null."); + CommonUtils.checkArgument(dynamicConfiguration != null, "Must be not null."); + + this.configuration = new Properties(); + this.configuration.putAll(loadConfiguration(confDir)); + this.configuration.putAll(dynamicConfiguration); + } + + public Configuration(Configuration other) { + CommonUtils.checkArgument(other != null, "Must be not null."); + CommonUtils.checkArgument(other.toProperties() != null, "Must be not null."); + + this.configuration = new Properties(); + this.configuration.putAll(other.toProperties()); + } + + public void addAll(Configuration other) { + CommonUtils.checkArgument(other != null, "Must be not null."); + CommonUtils.checkArgument(other.toProperties() != null, "Must be not null."); + + this.configuration.putAll(other.toProperties()); + } + + private void setValueInternal(String key, T value) { + CommonUtils.checkArgument( + key != null && !key.trim().isEmpty(), "key must not be null or empty."); + configuration.put(key, convertToString(value)); + } + + private void setValueInternal(ConfigOption option, T value) { + CommonUtils.checkArgument(option != null); + + setValueInternal(option.key(), value); + } + + public void setByte(ConfigOption option, byte value) { + setValueInternal(option, value); + } + + public void setShort(ConfigOption option, short value) { + setValueInternal(option, value); + } + + public void setInteger(ConfigOption option, int value) { + setValueInternal(option, value); + } + + public void setLong(ConfigOption option, long value) { + setValueInternal(option, value); + } + + public void setFloat(ConfigOption option, float value) { + setValueInternal(option, value); + } + + public void setDouble(ConfigOption option, double value) { + setValueInternal(option, value); + } + + public void setBoolean(ConfigOption option, boolean value) { + setValueInternal(option, value); + } + + public void setString(ConfigOption option, String value) { + setValueInternal(option, value); + } + + public void setString(String key, String value) { + setValueInternal(key, value); + } + + public void setDuration(ConfigOption option, Duration value) { + setValueInternal(option, value); + } + + public void setMemorySize(ConfigOption option, MemorySize value) { + setValueInternal(option, value); + } + + public void setMap(ConfigOption> option, Map value) { + setValueInternal(option, value); + } + + public void setList(ConfigOption option, T value) { + setValueInternal(option, value); + } + + public Byte getByte(String key) { + return getByte(key, null); + } + + public Byte getByte(ConfigOption configOption) { + CommonUtils.checkArgument(configOption != null, "Must be not null."); + + return getByte(configOption.key(), configOption.defaultValue()); + } + + public Byte getByte(ConfigOption configOption, Byte defaultValue) { + CommonUtils.checkArgument(configOption != null, "Must be not null."); + + return getByte(configOption.key(), defaultValue); + } + + public Byte getByte(String key, Byte defaultValue) { + CommonUtils.checkArgument(key != null, "Must be not null."); + + String value = configuration.getProperty(key); + if (value == null) { + return defaultValue; + } + + try { + return convertToByte(value); + } catch (Exception exception) { + throw new ConfigurationException("Illegal config value for " + key + "."); + } + } + + public Short getShort(String key) { + return getShort(key, null); + } + + public Short 
getShort(ConfigOption configOption) { + CommonUtils.checkArgument(configOption != null, "Must be not null."); + + return getShort(configOption.key(), configOption.defaultValue()); + } + + public Short getShort(ConfigOption configOption, Short defaultValue) { + CommonUtils.checkArgument(configOption != null, "Must be not null."); + + return getShort(configOption.key(), defaultValue); + } + + public Short getShort(String key, Short defaultValue) { + CommonUtils.checkArgument(key != null, "Must be not null."); + + String value = configuration.getProperty(key); + if (value == null) { + return defaultValue; + } + + try { + return convertToShort(value); + } catch (Exception exception) { + throw new ConfigurationException("Illegal config value for " + key + "."); + } + } + + public Integer getInteger(String key) { + return getInteger(key, null); + } + + public Integer getInteger(ConfigOption configOption) { + CommonUtils.checkArgument(configOption != null, "Must be not null."); + + return getInteger(configOption.key(), configOption.defaultValue()); + } + + public Integer getInteger(ConfigOption configOption, Integer defaultValue) { + CommonUtils.checkArgument(configOption != null, "Must be not null."); + + return getInteger(configOption.key(), defaultValue); + } + + public Integer getInteger(String key, Integer defaultValue) { + CommonUtils.checkArgument(key != null, "Must be not null."); + + String value = configuration.getProperty(key); + if (value == null) { + return defaultValue; + } + + try { + return convertToInteger(value); + } catch (Exception exception) { + throw new ConfigurationException("Illegal config value for " + key + "."); + } + } + + public Long getLong(String key) { + return getLong(key, null); + } + + public Long getLong(ConfigOption configOption) { + CommonUtils.checkArgument(configOption != null, "Must be not null."); + + return getLong(configOption.key(), configOption.defaultValue()); + } + + public Long getLong(ConfigOption configOption, Long defaultValue) { + CommonUtils.checkArgument(configOption != null, "Must be not null."); + + return getLong(configOption.key(), defaultValue); + } + + public Long getLong(String key, Long defaultValue) { + CommonUtils.checkArgument(key != null, "Must be not null."); + + String value = configuration.getProperty(key); + if (value == null) { + return defaultValue; + } + + try { + return convertToLong(value); + } catch (Exception exception) { + throw new ConfigurationException("Illegal config value for " + key + "."); + } + } + + public Double getDouble(String key) { + return getDouble(key, null); + } + + public Double getDouble(ConfigOption configOption) { + CommonUtils.checkArgument(configOption != null, "Must be not null."); + + return getDouble(configOption.key(), configOption.defaultValue()); + } + + public Double getDouble(ConfigOption configOption, Double defaultValue) { + CommonUtils.checkArgument(configOption != null, "Must be not null."); + + return getDouble(configOption.key(), defaultValue); + } + + public Double getDouble(String key, Double defaultValue) { + CommonUtils.checkArgument(key != null, "Must be not null."); + + String value = configuration.getProperty(key); + if (value == null) { + return defaultValue; + } + + try { + return convertToDouble(value); + } catch (Exception exception) { + throw new ConfigurationException("Illegal config value for " + key + "."); + } + } + + public Float getFloat(String key) { + return getFloat(key, null); + } + + public Float getFloat(ConfigOption configOption) { + 
CommonUtils.checkArgument(configOption != null, "Must be not null."); + + return getFloat(configOption.key(), configOption.defaultValue()); + } + + public Float getFloat(ConfigOption configOption, Float defaultValue) { + CommonUtils.checkArgument(configOption != null, "Must be not null."); + + return getFloat(configOption.key(), defaultValue); + } + + public Float getFloat(String key, Float defaultValue) { + CommonUtils.checkArgument(key != null, "Must be not null."); + + String value = configuration.getProperty(key); + if (value == null) { + return defaultValue; + } + + try { + return convertToFloat(value); + } catch (Exception exception) { + throw new ConfigurationException("Illegal config value for " + key + "."); + } + } + + public Boolean getBoolean(String key) { + return getBoolean(key, null); + } + + public Boolean getBoolean(ConfigOption configOption) { + CommonUtils.checkArgument(configOption != null, "Must be not null."); + + return getBoolean(configOption.key(), configOption.defaultValue()); + } + + public Boolean getBoolean(ConfigOption configOption, Boolean defaultValue) { + CommonUtils.checkArgument(configOption != null, "Must be not null."); + + return getBoolean(configOption.key(), defaultValue); + } + + public Boolean getBoolean(String key, Boolean defaultValue) { + CommonUtils.checkArgument(key != null, "Must be not null."); + + String value = configuration.getProperty(key); + if (value == null) { + return defaultValue; + } + + try { + return convertToBoolean(value); + } catch (Exception exception) { + throw new ConfigurationException( + "Illegal boolean config value for " + key + ", must be 'true' or 'false'."); + } + } + + public String getString(String key) { + return getString(key, null); + } + + public String getString(ConfigOption configOption) { + CommonUtils.checkArgument(configOption != null, "Must be not null."); + + return getString(configOption.key(), configOption.defaultValue()); + } + + public String getString(ConfigOption configOption, String defaultValue) { + CommonUtils.checkArgument(configOption != null, "Must be not null."); + + return getString(configOption.key(), defaultValue); + } + + public String getString(String key, String defaultValue) { + CommonUtils.checkArgument(key != null, "Must be not null."); + + String value = configuration.getProperty(key); + if (value == null) { + return defaultValue; + } + return value; + } + + public MemorySize getMemorySize(ConfigOption configOption) { + checkArgument(configOption != null, "Must be not null."); + return getMemorySize(configOption, configOption.defaultValue()); + } + + public MemorySize getMemorySize( + ConfigOption configOption, MemorySize defaultValue) { + checkArgument(configOption != null, "Must be not null."); + return getMemorySize(configOption.key(), defaultValue); + } + + public MemorySize getMemorySize(String key, MemorySize defaultValue) { + checkArgument(key != null, "Must be not null."); + + String value = configuration.getProperty(key); + if (value == null) { + return defaultValue; + } + + return convertToMemorySize(value); + } + + public Duration getDuration(ConfigOption configOption) { + checkArgument(configOption != null, "Must be not null."); + return getDuration(configOption, configOption.defaultValue()); + } + + public Duration getDuration(ConfigOption configOption, Duration defaultValue) { + checkArgument(configOption != null, "Must be not null."); + return getDuration(configOption.key(), defaultValue); + } + + public Duration getDuration(String key, Duration defaultValue) { + 
checkArgument(key != null, "Must be not null."); + + String value = configuration.getProperty(key); + if (value == null) { + return defaultValue; + } + + return convertToDuration(value); + } + + public Map getMap(ConfigOption> configOption) { + checkArgument(configOption != null, "Must be not null."); + return getMap(configOption, configOption.defaultValue()); + } + + public Map getMap( + ConfigOption> configOption, Map defaultValue) { + checkArgument(configOption != null, "Must be not null."); + return getMap(configOption.key(), defaultValue); + } + + public Map getMap(String key, Map defaultValue) { + checkArgument(key != null, "Must be not null."); + + String value = configuration.getProperty(key); + if (value == null) { + return defaultValue; + } + + return convertToMap(value); + } + + @SuppressWarnings("unchecked") + public T getList(ConfigOption configOption, Class clazz) { + checkArgument(configOption != null && configOption.key() != null); + + String value = configuration.getProperty(configOption.key()); + + if (value == null) { + return configOption.defaultValue(); + } + + return (T) + StructuredOptionsSplitter.splitEscaped(value, ';').stream() + .map(s -> convertValue(s, clazz)) + .collect(Collectors.toList()); + } + + private static Byte convertToByte(String value) { + return Byte.parseByte(value); + } + + private static Integer convertToInteger(String value) { + return Integer.parseInt(value); + } + + private static Long convertToLong(String value) { + return Long.parseLong(value); + } + + private static Short convertToShort(String value) { + return Short.parseShort(value); + } + + private static Double convertToDouble(String value) { + return Double.parseDouble(value); + } + + private static Float convertToFloat(String value) { + return Float.parseFloat(value); + } + + private static Boolean convertToBoolean(String value) { + if (value.equalsIgnoreCase("true")) { + return Boolean.TRUE; + } else if (value.equalsIgnoreCase("false")) { + return Boolean.FALSE; + } else { + throw new IllegalArgumentException( + "Illegal boolean config value, must be 'true' or 'false'."); + } + } + + private static Duration convertToDuration(String value) { + return TimeUtils.parseDuration(value); + } + + private static MemorySize convertToMemorySize(String value) { + return MemorySize.parse(value); + } + + private static Map convertToMap(String value) { + List listOfRawProperties = StructuredOptionsSplitter.splitEscaped(value, ','); + return listOfRawProperties.stream() + .map(s -> StructuredOptionsSplitter.splitEscaped(s, ':')) + .peek( + pair -> { + if (pair.size() != 2) { + throw new IllegalArgumentException( + "Could not parse pair in the map " + pair); + } + }) + .collect(Collectors.toMap(a -> a.get(0), a -> a.get(1))); + } + + /** Get the value of option. Return null if value key not present in the configuration. 
*/ + @SuppressWarnings("unchecked") + private static T convertValue(String value, Class clazz) { + if (Byte.class.equals(clazz)) { + return (T) convertToByte(value); + } else if (Integer.class.equals(clazz)) { + return (T) convertToInteger(value); + } else if (Long.class.equals(clazz)) { + return (T) convertToLong(value); + } else if (Short.class.equals(clazz)) { + return (T) convertToShort(value); + } else if (Double.class.equals(clazz)) { + return (T) convertToDouble(value); + } else if (Float.class.equals(clazz)) { + return (T) convertToFloat(value); + } else if (Boolean.class.equals(clazz)) { + return (T) convertToBoolean(value); + } else if (String.class.equals(clazz)) { + return (T) value; + } else if (clazz == Duration.class) { + return (T) convertToDuration(value); + } else if (clazz == MemorySize.class) { + return (T) convertToMemorySize(value); + } else if (clazz == Map.class) { + return (T) convertToMap(value); + } + + throw new IllegalArgumentException("Unsupported type: " + clazz); + } + + public Map toMap() { + Map result = new HashMap<>(); + for (String propertyName : configuration.stringPropertyNames()) { + result.put(propertyName, configuration.getProperty(propertyName)); + } + return result; + } + + public static Configuration fromMap(Map map) { + Properties properties = new Properties(); + for (Map.Entry entry : map.entrySet()) { + properties.setProperty(entry.getKey(), entry.getValue()); + } + return new Configuration(properties); + } + + public Properties toProperties() { + Properties clonedConfiguration = new Properties(); + clonedConfiguration.putAll(configuration); + return clonedConfiguration; + } + + /** Loads configuration from the configuration file. */ + /** + * Loads a YAML-file of key-value pairs. + * + *

Colon and whitespace ": " separate key and value (one per line). The hash tag "#" starts a + * single-line comment. + * + *

Example: + * + *

+     * remote-shuffle.manager.rpc-address: localhost # network address for communication with the shuffle manager
+     * remote-shuffle.manager.rpc-port   : 23123     # network port to connect to for communication with the shuffle manager
+     * 
+ * + *

This does not span the whole YAML specification, but only the *syntax* of simple YAML + * key-value pairs (see issue #113 on GitHub). If at any point in time, there is a need to go + * beyond simple key-value pairs syntax compatibility will allow to introduce a YAML parser + * library. + * + * @param confDir the conf dir. + * @see YAML 1.2 specification + */ + /** Loads configuration from the configuration file. */ + private static Properties loadConfiguration(String confDir) throws IOException { + File confFile = new File(confDir, REMOTE_SHUFFLE_CONF_FILENAME); + Properties configuration = new Properties(); + + if (!confFile.exists()) { + LOG.warn( + "Configuration file {} does not exist, only dynamic parameters will be used.", + confFile.getAbsolutePath()); + return configuration; + } + + if (!confFile.isFile()) { + throw new ConfigurationException( + String.format( + "Configuration file %s is not a normal file.", + confFile.getAbsoluteFile())); + } + + LOG.info("Loading configurations from config file: {}", confFile); + try (BufferedReader reader = + new BufferedReader(new InputStreamReader(new FileInputStream(confFile)))) { + + String line; + int lineNo = 0; + while ((line = reader.readLine()) != null) { + lineNo++; + // 1. check for comments + String[] comments = line.split("#", 2); + String conf = comments[0].trim(); + + // 2. get key and value + if (conf.length() > 0) { + String[] kv = conf.split(": ", 2); + + // skip line with no valid key-value pair + if (kv.length == 1) { + LOG.warn( + "Error while trying to split key and value in configuration file " + + confFile + + ":" + + lineNo + + ": \"" + + line + + "\""); + continue; + } + + String key = kv[0].trim(); + String value = kv[1].trim(); + + // sanity check + if (key.length() == 0 || value.length() == 0) { + LOG.warn( + "Error after splitting key and value in configuration file " + + confFile + + ":" + + lineNo + + ": \"" + + line + + "\""); + continue; + } + + LOG.info("Loading configuration property: {}, {}", key, value); + configuration.setProperty(key, value); + } + } + } catch (IOException e) { + throw new RuntimeException("Error parsing YAML configuration.", e); + } + return configuration; + } + + private static String convertToString(Object o) { + if (o.getClass() == String.class) { + return (String) o; + } else if (o.getClass() == Duration.class) { + Duration duration = (Duration) o; + return String.format("%d ns", duration.toNanos()); + } else if (o instanceof List) { + return ((List) o) + .stream() + .map(e -> escapeWithSingleQuote(convertToString(e), ";")) + .collect(Collectors.joining(";")); + } else if (o instanceof Map) { + return ((Map) o) + .entrySet().stream() + .map( + e -> { + String escapedKey = + escapeWithSingleQuote(e.getKey().toString(), ":"); + String escapedValue = + escapeWithSingleQuote(e.getValue().toString(), ":"); + + return escapeWithSingleQuote( + escapedKey + ":" + escapedValue, ","); + }) + .collect(Collectors.joining(",")); + } + + return o.toString(); + } +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/config/MemorySize.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/config/MemorySize.java new file mode 100644 index 00000000..da92101f --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/config/MemorySize.java @@ -0,0 +1,421 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.config; + +import java.math.BigDecimal; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Locale; +import java.util.Optional; +import java.util.stream.IntStream; + +import static com.alibaba.flink.shuffle.common.config.MemorySize.MemoryUnit.BYTES; +import static com.alibaba.flink.shuffle.common.config.MemorySize.MemoryUnit.GIGA_BYTES; +import static com.alibaba.flink.shuffle.common.config.MemorySize.MemoryUnit.KILO_BYTES; +import static com.alibaba.flink.shuffle.common.config.MemorySize.MemoryUnit.MEGA_BYTES; +import static com.alibaba.flink.shuffle.common.config.MemorySize.MemoryUnit.TERA_BYTES; +import static com.alibaba.flink.shuffle.common.config.MemorySize.MemoryUnit.hasUnit; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** + * MemorySize is a representation of a number of bytes, viewable in different units. + * + *

Parsing

+ * + *

The size can be parsed from a text expression. If the expression is a pure number, the value + * will be interpreted as bytes. + * + *

This class is copied from Apache Flink (org.apache.flink.configuration.MemorySize) + */ +public class MemorySize implements java.io.Serializable, Comparable { + + private static final long serialVersionUID = 450443291938254568L; + + public static final MemorySize ZERO = new MemorySize(0L); + + public static final MemorySize MAX_VALUE = new MemorySize(Long.MAX_VALUE); + + private static final List ORDERED_UNITS = + Arrays.asList(BYTES, KILO_BYTES, MEGA_BYTES, GIGA_BYTES, TERA_BYTES); + + // ------------------------------------------------------------------------ + + /** The memory size, in bytes. */ + private final long bytes; + + /** The memorized value returned by toString(). */ + private transient String stringified; + + /** The memorized value returned by toHumanReadableString(). */ + private transient String humanReadableStr; + + /** + * Constructs a new MemorySize. + * + * @param bytes The size, in bytes. Must be zero or larger. + */ + public MemorySize(long bytes) { + checkArgument(bytes >= 0, "bytes must be >= 0"); + this.bytes = bytes; + } + + public static MemorySize ofMebiBytes(long mebiBytes) { + return new MemorySize(mebiBytes << 20); + } + + // ------------------------------------------------------------------------ + + /** Gets the memory size in bytes. */ + public long getBytes() { + return bytes; + } + + /** Gets the memory size in Kibibytes (= 1024 bytes). */ + public long getKibiBytes() { + return bytes >> 10; + } + + /** Gets the memory size in Mebibytes (= 1024 Kibibytes). */ + public int getMebiBytes() { + return (int) (bytes >> 20); + } + + /** Gets the memory size in Gibibytes (= 1024 Mebibytes). */ + public long getGibiBytes() { + return bytes >> 30; + } + + /** Gets the memory size in Tebibytes (= 1024 Gibibytes). */ + public long getTebiBytes() { + return bytes >> 40; + } + + // ------------------------------------------------------------------------ + + @Override + public int hashCode() { + return (int) (bytes ^ (bytes >>> 32)); + } + + @Override + public boolean equals(Object obj) { + return obj == this + || (obj != null + && obj.getClass() == this.getClass() + && ((MemorySize) obj).bytes == this.bytes); + } + + @Override + public String toString() { + if (stringified == null) { + stringified = formatToString(); + } + + return stringified; + } + + private String formatToString() { + MemoryUnit highestIntegerUnit = + IntStream.range(0, ORDERED_UNITS.size()) + .sequential() + .filter(idx -> bytes % ORDERED_UNITS.get(idx).getMultiplier() != 0) + .boxed() + .findFirst() + .map( + idx -> { + if (idx == 0) { + return ORDERED_UNITS.get(0); + } else { + return ORDERED_UNITS.get(idx - 1); + } + }) + .orElse(BYTES); + + return String.format( + "%d %s", + bytes / highestIntegerUnit.getMultiplier(), highestIntegerUnit.getUnits()[1]); + } + + public String toHumanReadableString() { + if (humanReadableStr == null) { + humanReadableStr = formatToHumanReadableString(); + } + + return humanReadableStr; + } + + private String formatToHumanReadableString() { + MemoryUnit highestUnit = + IntStream.range(0, ORDERED_UNITS.size()) + .sequential() + .filter(idx -> bytes > ORDERED_UNITS.get(idx).getMultiplier()) + .boxed() + .max(Comparator.naturalOrder()) + .map(ORDERED_UNITS::get) + .orElse(BYTES); + + if (highestUnit == BYTES) { + return String.format("%d %s", bytes, BYTES.getUnits()[1]); + } else { + double approximate = 1.0 * bytes / highestUnit.getMultiplier(); + return String.format( + Locale.ROOT, + "%.3f%s (%d bytes)", + approximate, + highestUnit.getUnits()[1], + bytes); + } + 
} + + @Override + public int compareTo(MemorySize that) { + return Long.compare(this.bytes, that.bytes); + } + + // ------------------------------------------------------------------------ + // Calculations + // ------------------------------------------------------------------------ + + public MemorySize add(MemorySize that) { + return new MemorySize(Math.addExact(this.bytes, that.bytes)); + } + + public MemorySize subtract(MemorySize that) { + return new MemorySize(Math.subtractExact(this.bytes, that.bytes)); + } + + public MemorySize multiply(double multiplier) { + checkArgument(multiplier >= 0, "multiplier must be >= 0"); + + BigDecimal product = + BigDecimal.valueOf(this.bytes).multiply(BigDecimal.valueOf(multiplier)); + if (product.compareTo(BigDecimal.valueOf(Long.MAX_VALUE)) > 0) { + throw new ArithmeticException("long overflow"); + } + return new MemorySize(product.longValue()); + } + + public MemorySize divide(long by) { + checkArgument(by >= 0, "divisor must be >= 0"); + return new MemorySize(bytes / by); + } + + // ------------------------------------------------------------------------ + // Parsing + // ------------------------------------------------------------------------ + + /** + * Parses the given string as as MemorySize. + * + * @param text The string to parse + * @return The parsed MemorySize + * @throws IllegalArgumentException Thrown, if the expression cannot be parsed. + */ + public static MemorySize parse(String text) throws IllegalArgumentException { + return new MemorySize(parseBytes(text)); + } + + /** + * Parses the given string with a default unit. + * + * @param text The string to parse. + * @param defaultUnit specify the default unit. + * @return The parsed MemorySize. + * @throws IllegalArgumentException Thrown, if the expression cannot be parsed. + */ + public static MemorySize parse(String text, MemoryUnit defaultUnit) + throws IllegalArgumentException { + if (!hasUnit(text)) { + return parse(text + defaultUnit.getUnits()[0]); + } + + return parse(text); + } + + /** + * Parses the given string as bytes. The supported expressions are listed under {@link + * MemorySize}. + * + * @param text The string to parse + * @return The parsed size, in bytes. + * @throws IllegalArgumentException Thrown, if the expression cannot be parsed. 
+ */ + public static long parseBytes(String text) throws IllegalArgumentException { + checkNotNull(text); + + final String trimmed = text.trim(); + checkArgument(!trimmed.isEmpty(), "argument is an empty- or whitespace-only string"); + + final int len = trimmed.length(); + int pos = 0; + + char current; + while (pos < len && (current = trimmed.charAt(pos)) >= '0' && current <= '9') { + pos++; + } + + final String number = trimmed.substring(0, pos); + final String unit = trimmed.substring(pos).trim().toLowerCase(Locale.US); + + if (number.isEmpty()) { + throw new NumberFormatException("text does not start with a number"); + } + + final long value; + try { + value = Long.parseLong(number); // this throws a NumberFormatException on overflow + } catch (NumberFormatException e) { + throw new IllegalArgumentException( + "The value '" + + number + + "' cannot be re represented as 64bit number (numeric overflow)."); + } + + final long multiplier = parseUnit(unit).map(MemoryUnit::getMultiplier).orElse(1L); + final long result = value * multiplier; + + // check for overflow + if (result / multiplier != value) { + throw new IllegalArgumentException( + "The value '" + + text + + "' cannot be re represented as 64bit number of bytes (numeric overflow)."); + } + + return result; + } + + private static Optional parseUnit(String unit) { + if (matchesAny(unit, BYTES)) { + return Optional.of(BYTES); + } else if (matchesAny(unit, KILO_BYTES)) { + return Optional.of(KILO_BYTES); + } else if (matchesAny(unit, MEGA_BYTES)) { + return Optional.of(MEGA_BYTES); + } else if (matchesAny(unit, GIGA_BYTES)) { + return Optional.of(GIGA_BYTES); + } else if (matchesAny(unit, TERA_BYTES)) { + return Optional.of(TERA_BYTES); + } else if (!unit.isEmpty()) { + throw new IllegalArgumentException( + "Memory size unit '" + + unit + + "' does not match any of the recognized units: " + + MemoryUnit.getAllUnits()); + } + + return Optional.empty(); + } + + private static boolean matchesAny(String str, MemoryUnit unit) { + for (String s : unit.getUnits()) { + if (s.equals(str)) { + return true; + } + } + return false; + } + + /** + * Enum which defines memory unit, mostly used to parse value from configuration file. + * + *

To make larger values more compact, the common size suffixes are supported: + * + *

    + *
  • 1b or 1bytes (bytes) + *
  • 1k or 1kb or 1kibibytes (interpreted as kibibytes = 1024 bytes) + *
  • 1m or 1mb or 1mebibytes (interpreted as mebibytes = 1024 kibibytes) + *
  • 1g or 1gb or 1gibibytes (interpreted as gibibytes = 1024 mebibytes) + *
  • 1t or 1tb or 1tebibytes (interpreted as tebibytes = 1024 gibibytes) + *
+ */ + public enum MemoryUnit { + BYTES(new String[] {"b", "bytes"}, 1L), + KILO_BYTES(new String[] {"k", "kb", "kibibytes"}, 1024L), + MEGA_BYTES(new String[] {"m", "mb", "mebibytes"}, 1024L * 1024L), + GIGA_BYTES(new String[] {"g", "gb", "gibibytes"}, 1024L * 1024L * 1024L), + TERA_BYTES(new String[] {"t", "tb", "tebibytes"}, 1024L * 1024L * 1024L * 1024L); + + private final String[] units; + + private final long multiplier; + + MemoryUnit(String[] units, long multiplier) { + this.units = units; + this.multiplier = multiplier; + } + + public String[] getUnits() { + return units; + } + + public long getMultiplier() { + return multiplier; + } + + public static String getAllUnits() { + return concatenateUnits( + BYTES.getUnits(), + KILO_BYTES.getUnits(), + MEGA_BYTES.getUnits(), + GIGA_BYTES.getUnits(), + TERA_BYTES.getUnits()); + } + + public static boolean hasUnit(String text) { + checkNotNull(text); + + final String trimmed = text.trim(); + checkArgument(!trimmed.isEmpty(), "argument is an empty- or whitespace-only string"); + + final int len = trimmed.length(); + int pos = 0; + + char current; + while (pos < len && (current = trimmed.charAt(pos)) >= '0' && current <= '9') { + pos++; + } + + final String unit = trimmed.substring(pos).trim().toLowerCase(Locale.US); + + return unit.length() > 0; + } + + private static String concatenateUnits(final String[]... allUnits) { + final StringBuilder builder = new StringBuilder(128); + + for (String[] units : allUnits) { + builder.append('('); + + for (String unit : units) { + builder.append(unit); + builder.append(" | "); + } + + builder.setLength(builder.length() - 3); + builder.append(") / "); + } + + builder.setLength(builder.length() - 3); + return builder.toString(); + } + } +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/config/StructuredOptionsSplitter.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/config/StructuredOptionsSplitter.java new file mode 100644 index 00000000..91d19397 --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/config/StructuredOptionsSplitter.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.config; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** + * Helper class for splitting a string on a given delimiter with quoting logic. + * + *

This class is copied from Apache Flink + * (org.apache.flink.configuration.StructuredOptionsSplitter). + */ +class StructuredOptionsSplitter { + + /** + * Splits the given string on the given delimiter. It supports quoting parts of the string with + * either single (') or double quotes ("). Quotes can be escaped by doubling the quotes. + * + *

Examples: + * + *

    + *
  • 'A;B';C => [A;B], [C] + *
  • "AB'D";B;C => [AB'D], [B], [C] + *
  • "AB'""D;B";C => [AB'\"D;B], [C] + *
+ * + *

For more examples check the tests. + * + * @param string a string to split + * @param delimiter delimiter to split on + * @return a list of splits + */ + static List splitEscaped(String string, char delimiter) { + List tokens = tokenize(checkNotNull(string), delimiter); + return processTokens(tokens); + } + + /** + * Escapes the given string with single quotes, if the input string contains a double quote or + * any of the given {@code charsToEscape}. Any single quotes in the input string will be escaped + * by doubling. + * + *

Given that the escapeChar is (;) + * + *

Examples: + * + *

    + *
  • A,B,C,D => A,B,C,D + *
  • A'B'C'D => 'A''B''C''D' + *
  • A;BCD => 'A;BCD' + *
  • AB"C"D => 'AB"C"D' + *
  • AB'"D:B => 'AB''"D:B' + *
+ * + * @param string a string which needs to be escaped + * @param charsToEscape escape chars for the escape conditions + * @return escaped string by single quote + */ + static String escapeWithSingleQuote(String string, String... charsToEscape) { + boolean escape = + Arrays.stream(charsToEscape).anyMatch(string::contains) + || string.contains("\"") + || string.contains("'"); + + if (escape) { + return "'" + string.replaceAll("'", "''") + "'"; + } + + return string; + } + + private static List processTokens(List tokens) { + final List splits = new ArrayList<>(); + for (int i = 0; i < tokens.size(); i++) { + Token token = tokens.get(i); + switch (token.getTokenType()) { + case DOUBLE_QUOTED: + case SINGLE_QUOTED: + if (i + 1 < tokens.size() + && tokens.get(i + 1).getTokenType() != TokenType.DELIMITER) { + int illegalPosition = tokens.get(i + 1).getPosition() - 1; + throw new IllegalArgumentException( + "Could not split string. Illegal quoting at position: " + + illegalPosition); + } + splits.add(token.getString()); + break; + case UNQUOTED: + splits.add(token.getString()); + break; + case DELIMITER: + if (i + 1 < tokens.size() + && tokens.get(i + 1).getTokenType() == TokenType.DELIMITER) { + splits.add(""); + } + break; + } + } + + return splits; + } + + private static List tokenize(String string, char delimiter) { + final List tokens = new ArrayList<>(); + final StringBuilder builder = new StringBuilder(); + for (int cursor = 0; cursor < string.length(); ) { + final char c = string.charAt(cursor); + + int nextChar = cursor + 1; + if (c == '\'') { + nextChar = consumeInQuotes(string, '\'', cursor, builder); + tokens.add(new Token(TokenType.SINGLE_QUOTED, builder.toString(), cursor)); + } else if (c == '"') { + nextChar = consumeInQuotes(string, '"', cursor, builder); + tokens.add(new Token(TokenType.DOUBLE_QUOTED, builder.toString(), cursor)); + } else if (c == delimiter) { + tokens.add(new Token(TokenType.DELIMITER, String.valueOf(c), cursor)); + } else if (!Character.isWhitespace(c)) { + nextChar = consumeUnquoted(string, delimiter, cursor, builder); + tokens.add(new Token(TokenType.UNQUOTED, builder.toString().trim(), cursor)); + } + builder.setLength(0); + cursor = nextChar; + } + + return tokens; + } + + private static int consumeInQuotes( + String string, char quote, int cursor, StringBuilder builder) { + for (int i = cursor + 1; i < string.length(); i++) { + char c = string.charAt(i); + if (c == quote) { + if (i + 1 < string.length() && string.charAt(i + 1) == quote) { + builder.append(c); + i += 1; + } else { + return i + 1; + } + } else { + builder.append(c); + } + } + + throw new IllegalArgumentException( + "Could not split string. 
Quoting was not closed properly."); + } + + private static int consumeUnquoted( + String string, char delimiter, int cursor, StringBuilder builder) { + int i; + for (i = cursor; i < string.length(); i++) { + char c = string.charAt(i); + if (c == delimiter) { + return i; + } + + builder.append(c); + } + + return i; + } + + private enum TokenType { + DOUBLE_QUOTED, + SINGLE_QUOTED, + UNQUOTED, + DELIMITER + } + + private static class Token { + private final TokenType tokenType; + private final String string; + private final int position; + + private Token(TokenType tokenType, String string, int position) { + this.tokenType = tokenType; + this.string = string; + this.position = position; + } + + public TokenType getTokenType() { + return tokenType; + } + + public String getString() { + return string; + } + + public int getPosition() { + return position; + } + } + + private StructuredOptionsSplitter() {} +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/exception/ConfigurationException.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/exception/ConfigurationException.java new file mode 100644 index 00000000..880c395e --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/exception/ConfigurationException.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.exception; + +/** Exception to be thrown when any configuration error occurs. */ +public class ConfigurationException extends ShuffleException { + + private static final long serialVersionUID = 5012483304677960591L; + + public ConfigurationException(String message) { + super(message); + } + + public ConfigurationException(String message, Throwable throwable) { + super(message, throwable); + } +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/exception/ShuffleException.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/exception/ShuffleException.java new file mode 100644 index 00000000..39fa4b0a --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/exception/ShuffleException.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.exception; + +/** Basic checked exception type of the flink remote shuffle service. */ +public class ShuffleException extends RuntimeException { + + private static final long serialVersionUID = 4354119805642345969L; + + public ShuffleException(Throwable cause) { + super(cause); + } + + public ShuffleException(String message) { + super(message); + } + + public ShuffleException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/functions/AutoCloseableAsync.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/functions/AutoCloseableAsync.java new file mode 100644 index 00000000..557b1dd4 --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/functions/AutoCloseableAsync.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.functions; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +/** Closeable interface which allows to close a resource in a non blocking fashion. */ +public interface AutoCloseableAsync extends AutoCloseable { + + /** + * Trigger the closing of the resource and return the corresponding close future. + * + * @return Future which is completed once the resource has been closed + */ + CompletableFuture closeAsync(); + + @Override + default void close() throws Exception { + try { + closeAsync().get(); + } catch (ExecutionException exception) { + throw new ShuffleException( + "Could not close resource.", + ExceptionUtils.stripException(exception, ExecutionException.class)); + } + } +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/functions/BiConsumerWithException.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/functions/BiConsumerWithException.java new file mode 100644 index 00000000..37e4c0ca --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/functions/BiConsumerWithException.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.functions; + +/** Enhancement version of {@link java.util.function.BiConsumer} which can throw exceptions. */ +public interface BiConsumerWithException { + void accept(T var1, U var2) throws E; +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/functions/ConsumerWithException.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/functions/ConsumerWithException.java new file mode 100644 index 00000000..3f636832 --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/functions/ConsumerWithException.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.functions; + +/** Enhancement version of {@link java.util.function.Consumer} which can throw exceptions. */ +public interface ConsumerWithException { + void accept(T var) throws E; +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/functions/RunnableWithException.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/functions/RunnableWithException.java new file mode 100644 index 00000000..740d2e7d --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/functions/RunnableWithException.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.functions; + +/** Enhancement version of {@link Runnable} which can throw exceptions. */ +public interface RunnableWithException { + + void run() throws Exception; +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/functions/SupplierWithException.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/functions/SupplierWithException.java new file mode 100644 index 00000000..98fc62a8 --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/functions/SupplierWithException.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.functions; + +/** + * A functional interface for a {@link java.util.function.Supplier} that may throw exceptions. + * + * @param The type of the result of the supplier. + * @param The type of Exceptions thrown by this function. + */ +@FunctionalInterface +public interface SupplierWithException { + + /** + * Gets the result of this supplier. + * + * @return The result of thus supplier. + * @throws E This function may throw an exception. + */ + R get() throws E; +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/handler/FatalErrorHandler.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/handler/FatalErrorHandler.java new file mode 100644 index 00000000..2dcf41b8 --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/handler/FatalErrorHandler.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.handler; + +/** Handler interface for fatal error. 
*/ +public interface FatalErrorHandler { + + void onFatalError(Throwable throwable); +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/CommonUtils.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/CommonUtils.java new file mode 100644 index 00000000..7ea01c21 --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/CommonUtils.java @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.utils; + +import com.alibaba.flink.shuffle.common.functions.RunnableWithException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.function.Supplier; + +/** Utility methods can be used by all modules. */ +public class CommonUtils { + + private static final Logger LOG = LoggerFactory.getLogger(CommonUtils.class); + + private static final char[] HEX_CHARS = { + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F' + }; + + private static final int DEFAULT_RETRY_TIMES = 3; + + private static final ByteOrder DEFAULT_BYTE_ORDER = ByteOrder.BIG_ENDIAN; + + /** + * Ensures that the target object is not null and returns it. It will throw {@link + * NullPointerException} if the target object is null. + */ + public static T checkNotNull(T object) { + if (object == null) { + throw new NullPointerException("Must be not null."); + } + return object; + } + + /** + * Check the legality of method arguments. It will throw {@link IllegalArgumentException} if the + * given condition is not true. + */ + public static void checkArgument(boolean condition, @Nullable String message) { + if (!condition) { + throw new IllegalArgumentException(message); + } + } + + /** + * Check the legality of method arguments. It will throw {@link IllegalArgumentException} if the + * given condition is not true. + */ + public static void checkArgument(boolean condition) { + if (!condition) { + throw new IllegalArgumentException("Illegal argument."); + } + } + + /** + * Checks the legality of program state. It will throw {@link IllegalStateException} if the + * given condition is not true. + */ + public static void checkState(boolean condition, @Nullable String message) { + if (!condition) { + throw new IllegalStateException(message); + } + } + + /** + * Checks the legality of program state. It will throw {@link IllegalStateException} if the + * given condition is not true. 
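As a usage sketch for the precondition helpers above (illustrative only; PreconditionExample is a hypothetical class), checkArgument, checkState, and checkNotNull map onto IllegalArgumentException, IllegalStateException, and NullPointerException respectively:

    import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument;
    import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull;
    import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState;

    public class PreconditionExample {

        private byte[] buffer;

        void allocate(int size) {
            checkArgument(size > 0, "Buffer size must be positive.");   // IllegalArgumentException
            checkState(buffer == null, "Buffer already allocated.");    // IllegalStateException
            buffer = new byte[size];
        }

        int sizeOf(byte[] bytes) {
            return checkNotNull(bytes).length;                          // NullPointerException
        }
    }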
+ */ + public static void checkState(boolean condition, Supplier message) { + if (!condition) { + throw new IllegalStateException(message.get()); + } + } + + /** + * Checks the legality of program state. It will throw {@link IllegalStateException} if the + * given condition is not true. + */ + public static void checkState(boolean condition) { + if (!condition) { + throw new IllegalStateException("Illegal state."); + } + } + + /** Exists the current process and logs the error when any unrecoverable exception occurs. */ + public static void exitOnFatalError(Throwable throwable) { + LOG.error("Exiting on fatal error.", throwable); + FatalErrorExitUtils.exitProcessIfNeeded(-101); + } + + /** Generates a random byte array of the given length. */ + public static byte[] randomBytes(int length) { + checkArgument(length > 0, "Must be positive."); + + Random random = new Random(); + byte[] bytes = new byte[length]; + random.nextBytes(bytes); + return bytes; + } + + /** Converts the given byte array to a printable hex string. */ + public static String bytesToHexString(byte[] bytes) { + checkArgument(bytes != null, "Must be not null."); + + char[] chars = new char[bytes.length * 2]; + + for (int i = 0; i < chars.length; i += 2) { + int index = i >>> 1; + chars[i] = HEX_CHARS[(0xF0 & bytes[index]) >>> 4]; + chars[i + 1] = HEX_CHARS[0x0F & bytes[index]]; + } + + return new String(chars); + } + + public static byte[] hexStringToBytes(String hexString) { + byte[] bytes = new byte[hexString.length() / 2]; + for (int i = 0; i < hexString.length(); i += 2) { + byte high = Byte.parseByte(hexString.charAt(i) + "", 16); + byte low = Byte.parseByte(hexString.charAt(i + 1) + "", 16); + bytes[i / 2] = (byte) ((high << 4) | low); + } + + return bytes; + } + + /** Generates a random hex string of the given length. */ + public static String randomHexString(int length) { + checkArgument(length > 0, "Must be positive."); + + char[] chars = new char[length]; + Random random = new Random(); + + for (int i = 0; i < length; ++i) { + chars[i] = HEX_CHARS[random.nextInt(HEX_CHARS.length)]; + } + + return new String(chars); + } + + public static byte[] concatByteArrays(byte[]... byteArrays) { + int totalLength = + Arrays.stream(byteArrays).map(array -> 4 + array.length).reduce(0, Integer::sum); + ByteBuffer buffer = allocateHeapByteBuffer(totalLength); + for (byte[] array : byteArrays) { + buffer.putInt(array.length); + buffer.put(array); + } + + return buffer.array(); + } + + public static List splitByteArrays(byte[] concatArray) { + List arrays = new ArrayList<>(); + ByteBuffer buffer = ByteBuffer.wrap(concatArray); + while (buffer.hasRemaining()) { + int length = buffer.getInt(); + byte[] array = new byte[length]; + buffer.get(array); + arrays.add(array); + } + + return arrays; + } + + public static byte[] intToBytes(int i) { + ByteBuffer bb = allocateHeapByteBuffer(4); + bb.putInt(i); + return bb.array(); + } + + public static byte[] longToBytes(long i) { + ByteBuffer bb = allocateHeapByteBuffer(8); + bb.putLong(i); + return bb.array(); + } + + public static boolean isValidHostPort(int port) { + return 0 <= port && port <= 65535; + } + + /** + * Runs the given {@link RunnableWithException} in current thread silently and do nothing even + * when any exception occurs. + */ + public static void runQuietly(@Nullable RunnableWithException runnable) { + runQuietly(runnable, false); + } + + /** + * Runs the given {@link RunnableWithException} in current thread and may log the encountered + * exception if any. 
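The byte-array helpers above are easiest to read as a round trip: concatByteArrays writes each array as a 4-byte big-endian length followed by its payload, and splitByteArrays reads that framing back. A hypothetical sketch (the return type List<byte[]> is assumed, since type parameters are elided in this rendering):

    import com.alibaba.flink.shuffle.common.utils.CommonUtils;

    import java.util.List;

    public class ByteArrayRoundTrip {
        public static void main(String[] args) {
            byte[] a = CommonUtils.randomBytes(8);
            byte[] b = CommonUtils.randomBytes(16);

            // Each array is stored as a 4-byte big-endian length plus payload,
            // which is what makes the packed form splittable again.
            byte[] packed = CommonUtils.concatByteArrays(a, b);
            List<byte[]> unpacked = CommonUtils.splitByteArrays(packed);

            System.out.println(packed.length); // 4 + 8 + 4 + 16 = 32
            System.out.println(CommonUtils.bytesToHexString(unpacked.get(1))
                    .equals(CommonUtils.bytesToHexString(b))); // true
        }
    }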
+ */ + public static void runQuietly(@Nullable RunnableWithException runnable, boolean logFailure) { + if (runnable == null) { + return; + } + + try { + runnable.run(); + } catch (Throwable throwable) { + if (logFailure) { + LOG.warn("Failed to run task.", throwable); + } + } + } + + /** + * Closes the target {@link AutoCloseable} and retries a maximum of {@link #DEFAULT_RETRY_TIMES} + * times. It will throw exception if still fails after that. + */ + public static void closeWithRetry(@Nullable AutoCloseable closeable) throws Exception { + closeWithRetry(closeable, DEFAULT_RETRY_TIMES); + } + + /** + * Closes the target {@link AutoCloseable} and retries a maximum of the given times. It will + * throw exception if still fails after that. + */ + public static void closeWithRetry(@Nullable AutoCloseable closeable, int retryTimes) + throws Exception { + Throwable exception = null; + for (int i = 0; i < retryTimes; ++i) { + try { + if (closeable != null) { + closeable.close(); + } + return; + } catch (Throwable throwable) { + exception = exception != null ? exception : throwable; + } + } + + if (exception != null) { + ExceptionUtils.rethrowException(exception); + } + } + + /** + * Deletes the target file and retries a maximum of {@link #DEFAULT_RETRY_TIMES} times. It will + * throw exception if still fails after that. + */ + public static void deleteFileWithRetry(@Nullable Path path) throws Exception { + deleteFileWithRetry(path, DEFAULT_RETRY_TIMES); + } + + /** + * Deletes the target file and retries a maximum of the given times. It will throw exception if + * still fails after that. + */ + public static void deleteFileWithRetry(@Nullable Path path, int retryTimes) throws Exception { + Throwable exception = null; + for (int i = 0; i < retryTimes; ++i) { + try { + if (path != null) { + Files.deleteIfExists(path); + } + return; + } catch (Throwable throwable) { + exception = exception != null ? exception : throwable; + } + } + + if (exception != null) { + ExceptionUtils.rethrowException(exception); + } + } + + /** Allocates a piece of unmanaged heap {@link ByteBuffer} of the given size. */ + public static ByteBuffer allocateHeapByteBuffer(int size) { + checkArgument(size >= 0, "Must be non-negative."); + + ByteBuffer buffer = ByteBuffer.allocate(size); + buffer.order(DEFAULT_BYTE_ORDER); + return buffer; + } + + /** Allocates a piece of unmanaged direct {@link ByteBuffer} of the given size. */ + public static ByteBuffer allocateDirectByteBuffer(int size) { + checkArgument(size >= 0, "Must be non-negative."); + + ByteBuffer buffer = ByteBuffer.allocateDirect(size); + buffer.order(DEFAULT_BYTE_ORDER); + return buffer; + } + + /** Casts the given long value to int and ensures there is no loss. */ + public static int checkedDownCast(long value) { + int downCast = (int) value; + if ((long) downCast != value) { + throw new IllegalArgumentException( + "Cannot downcast long value " + value + " to integer."); + } + + return downCast; + } +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/ExceptionUtils.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/ExceptionUtils.java new file mode 100644 index 00000000..6871299f --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/ExceptionUtils.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.utils; + +import javax.annotation.Nullable; + +import java.io.PrintWriter; +import java.io.StringWriter; +import java.util.Optional; + +/** Utility for manipulating exceptions. */ +public class ExceptionUtils { + + /** + * Checks whether the given exception indicates a situation that may leave the JVM in a + * corrupted state, meaning a state where continued normal operation can only be guaranteed via + * clean process restart. + */ + public static boolean isJvmFatalError(Throwable t) { + return (t instanceof InternalError) + || (t instanceof UnknownError) + || (t instanceof ThreadDeath); + } + + /** + * Checks whether the given exception indicates a situation that may leave the JVM in a + * corrupted state, or an out-of-memory error. + */ + public static boolean isJvmFatalOrOutOfMemoryError(Throwable t) { + return isJvmFatalError(t) || t instanceof OutOfMemoryError; + } + + /** + * Checks whether the given exception indicates a JVM metaspace out-of-memory error. + * + * @param t The exception to check. + * @return True, if the exception is the metaspace {@link OutOfMemoryError}, false otherwise. + */ + public static boolean isMetaspaceOutOfMemoryError(@Nullable Throwable t) { + return isOutOfMemoryErrorWithMessageStartingWith(t, "Metaspace"); + } + + /** Checks whether the given exception indicates a JVM direct out-of-memory error. */ + public static boolean isDirectOutOfMemoryError(@Nullable Throwable t) { + return isOutOfMemoryErrorWithMessageStartingWith(t, "Direct buffer memory"); + } + + public static boolean isHeapSpaceOutOfMemoryError(@Nullable Throwable t) { + return isOutOfMemoryErrorWithMessageStartingWith(t, "Java heap space"); + } + + private static boolean isOutOfMemoryErrorWithMessageStartingWith( + @Nullable Throwable t, String prefix) { + // the exact matching of the class is checked to avoid matching any custom subclasses of + // OutOfMemoryError as we are interested in the original exceptions, generated by JVM. + return isOutOfMemoryError(t) && t.getMessage() != null && t.getMessage().startsWith(prefix); + } + + private static boolean isOutOfMemoryError(@Nullable Throwable t) { + return t != null && t.getClass() == OutOfMemoryError.class; + } + + /** Rethrows the target {@link Throwable} as {@link Error} or {@link RuntimeException}. */ + public static void rethrowAsRuntimeException(Throwable t) { + if (t instanceof Error) { + throw (Error) t; + } else if (t instanceof RuntimeException) { + throw (RuntimeException) t; + } else { + throw new RuntimeException(t); + } + } + + /** Rethrows the target {@link Throwable} as {@link Error} or {@link Exception}. 
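The out-of-memory classifiers above deliberately match on the exact OutOfMemoryError class plus the JVM's message prefix, so custom subclasses are not recognized. A small illustrative sketch (OomClassification is a hypothetical name):

    import com.alibaba.flink.shuffle.common.utils.ExceptionUtils;

    public class OomClassification {
        public static void main(String[] args) {
            OutOfMemoryError direct = new OutOfMemoryError("Direct buffer memory");
            OutOfMemoryError heap = new OutOfMemoryError("Java heap space");

            // Matching keys on the exact class and the message prefix the JVM uses.
            System.out.println(ExceptionUtils.isDirectOutOfMemoryError(direct));   // true
            System.out.println(ExceptionUtils.isHeapSpaceOutOfMemoryError(heap));  // true
            System.out.println(ExceptionUtils.isJvmFatalOrOutOfMemoryError(heap)); // true
        }
    }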
*/ + public static void rethrowException(Throwable throwable) throws Exception { + if (throwable instanceof Error) { + throw (Error) throwable; + } else { + throw (Exception) throwable; + } + } + + /** + * Unpacks an specified exception and returns its cause. Otherwise the given {@link Throwable} + * is returned. + */ + public static Throwable stripException( + Throwable throwableToStrip, Class typeToStrip) { + while (typeToStrip.isAssignableFrom(throwableToStrip.getClass()) + && throwableToStrip.getCause() != null) { + throwableToStrip = throwableToStrip.getCause(); + } + return throwableToStrip; + } + + /** Checks whether a throwable chain contains a specific type of exception and returns it. */ + public static Optional findThrowable( + Throwable throwable, Class searchType) { + if (throwable == null || searchType == null) { + return Optional.empty(); + } + + Throwable cause = throwable; + while (cause != null) { + if (searchType.isAssignableFrom(cause.getClass())) { + return Optional.of(searchType.cast(cause)); + } else { + cause = cause.getCause(); + } + } + return Optional.empty(); + } + + public static String summaryErrorMessageStack(Throwable t) { + StringBuilder sb = new StringBuilder(); + do { + if (sb.length() != 0) { + sb.append(" -> "); + } + sb.append("[") + .append(t.getClass().getName()) + .append(": ") + .append(t.getMessage()) + .append("]"); + } while ((t = t.getCause()) != null); + return sb.toString(); + } + + /** + * Makes a string representation of the exception's stack trace, or "(null)", if the exception + * is null. + */ + public static String stringifyException(Throwable exception) { + if (exception == null) { + return "(null)"; + } + + try { + StringWriter stm = new StringWriter(); + PrintWriter wrt = new PrintWriter(stm); + exception.printStackTrace(wrt); + wrt.close(); + return stm.toString(); + } catch (Throwable throwable) { + return exception.getClass().getName() + " (error while printing stack trace)"; + } + } +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/FatalErrorExitUtils.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/FatalErrorExitUtils.java new file mode 100644 index 00000000..aa7ea934 --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/FatalErrorExitUtils.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.utils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * System.exit() need to be called when a fatal error is encountered. However, in some cases, such + * as deployment on yarn, System.exit() can't be called because it may affect other processes, e.g., + * Yarn Node Manager process. + * + *
In order to avoid affecting other processes, this class will use different strategies to deal + * with these fatal errors according to different deployment environments. + */ +public class FatalErrorExitUtils { + private static final Logger LOG = LoggerFactory.getLogger(FatalErrorExitUtils.class); + + private static volatile boolean needStopProcess = true; + + public static void setNeedStopProcess(boolean needStopProcess) { + FatalErrorExitUtils.needStopProcess = needStopProcess; + } + + /** Only for tests. */ + public static boolean isNeedStopProcess() { + return needStopProcess; + } + + public static void exitProcessIfNeeded(int exitCode) { + exitProcessIfNeeded(exitCode, null); + } + + public static void exitProcessIfNeeded(int exitCode, Throwable t) { + StringBuilder sb = + new StringBuilder("Stopping the process with code ").append(exitCode).append(". "); + sb.append("Whether the process should be exit? ").append(needStopProcess).append(". "); + if (!needStopProcess) { + sb.append("Ignore the stop operation and return directly."); + } + if (t != null) { + LOG.error(sb.toString(), t); + } else { + LOG.error(sb.toString()); + } + + if (needStopProcess) { + System.exit(exitCode); + } + } +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/FutureUtils.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/FutureUtils.java new file mode 100644 index 00000000..0e2b5176 --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/FutureUtils.java @@ -0,0 +1,352 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.utils; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiFunction; +import java.util.function.Supplier; + +/** + * A collection of utilities that expand the usage of {@link CompletableFuture}. + * + *
This class is partly copied from Apache Flink + * (org.apache.flink.runtime.concurrent.FutureUtils). + */ +public class FutureUtils { + + /** + * Run the given asynchronous action after the completion of the given future. The given future + * can be completed normally or exceptionally. In case of an exceptional completion, the + * asynchronous action's exception will be added to the initial exception. + * + * @param future to wait for its completion + * @param composedAction asynchronous action which is triggered after the future's completion + * @return Future which is completed after the asynchronous action has completed. This future + * can contain an exception if an error occurred in the given future or asynchronous action. + */ + public static CompletableFuture composeAfterwards( + CompletableFuture future, Supplier> composedAction) { + final CompletableFuture resultFuture = new CompletableFuture<>(); + + future.whenComplete( + (Object outerIgnored, Throwable outerThrowable) -> { + final CompletableFuture composedActionFuture = composedAction.get(); + + composedActionFuture.whenComplete( + (Object innerIgnored, Throwable innerThrowable) -> { + if (innerThrowable != null) { + resultFuture.completeExceptionally( + outerThrowable == null + ? innerThrowable + : outerThrowable); + } else if (outerThrowable != null) { + resultFuture.completeExceptionally(outerThrowable); + } else { + resultFuture.complete(null); + } + }); + }); + + return resultFuture; + } + + /** + * Creates a {@link ConjunctFuture} which is only completed after all given futures have + * completed. Unlike {@link FutureUtils#waitForAll(Collection)}, the resulting future won't be + * completed directly if one of the given futures is completed exceptionally. Instead, all + * occurring exception will be collected and combined to a single exception. If at least on + * exception occurs, then the resulting future will be completed exceptionally. + * + * @param futuresToComplete futures to complete + * @return Future which is completed after all given futures have been completed. + */ + public static ConjunctFuture completeAll( + Collection> futuresToComplete) { + return new CompletionConjunctFuture(futuresToComplete); + } + + /** + * Returns an exceptionally completed {@link CompletableFuture}. + * + * @param cause to complete the future with + * @param type of the future + * @return An exceptionally completed CompletableFuture + */ + public static CompletableFuture completedExceptionally(Throwable cause) { + CompletableFuture result = new CompletableFuture<>(); + result.completeExceptionally(cause); + return result; + } + + /** + * Creates a future that is complete once multiple other futures completed. The future fails + * (completes exceptionally) once one of the futures in the conjunction fails. Upon successful + * completion, the future returns the collection of the futures' results. + * + *
The ConjunctFuture gives access to how many Futures in the conjunction have already + * completed successfully, via {@link ConjunctFuture#getNumFuturesCompleted()}. + * + * @param futures The futures that make up the conjunction. No null entries are allowed. + * @return The ConjunctFuture that completes once all given futures are complete (or one fails). + */ + public static ConjunctFuture> combineAll( + Collection> futures) { + CommonUtils.checkArgument(futures != null, "Must be not null."); + return new ResultConjunctFuture<>(futures); + } + + /** + * Creates a future that is complete once all of the given futures have completed. The future + * fails (completes exceptionally) once one of the given futures fails. + * + *
The ConjunctFuture gives access to how many Futures have already completed successfully, + * via {@link ConjunctFuture#getNumFuturesCompleted()}. + * + * @param futures The futures to wait on. No null entries are allowed. + * @return The WaitingFuture that completes once all given futures are complete (or one fails). + */ + public static ConjunctFuture waitForAll( + Collection> futures) { + CommonUtils.checkArgument(futures != null, "Must be not null."); + return new WaitingConjunctFuture(futures); + } + + /** + * A future that is complete once multiple other futures completed. The futures are not + * necessarily of the same type. The ConjunctFuture fails (completes exceptionally) once one of + * the Futures in the conjunction fails. + * + *
The advantage of using the ConjunctFuture over chaining all the futures (such as via + * {@link CompletableFuture#thenCombine(CompletionStage, BiFunction)} )}) is that ConjunctFuture + * also tracks how many of the Futures are already complete. + */ + public abstract static class ConjunctFuture extends CompletableFuture { + + /** + * Gets the total number of Futures in the conjunction. + * + * @return The total number of Futures in the conjunction. + */ + public abstract int getNumFuturesTotal(); + + /** + * Gets the number of Futures in the conjunction that are already complete. + * + * @return The number of Futures in the conjunction that are already complete + */ + public abstract int getNumFuturesCompleted(); + } + + /** + * Implementation of the {@link ConjunctFuture} interface which waits only for the completion of + * its futures and does not return their values. + */ + private static final class WaitingConjunctFuture extends ConjunctFuture { + + /** Number of completed futures. */ + private final AtomicInteger numCompleted = new AtomicInteger(0); + + /** Total number of futures to wait on. */ + private final int numTotal; + + /** + * Method which increments the atomic completion counter and completes or fails the + * WaitingFutureImpl. + */ + private void handleCompletedFuture(Object ignored, Throwable throwable) { + if (throwable == null) { + if (numTotal == numCompleted.incrementAndGet()) { + complete(null); + } + } else { + completeExceptionally(throwable); + } + } + + private WaitingConjunctFuture(Collection> futures) { + this.numTotal = futures.size(); + + if (futures.isEmpty()) { + complete(null); + } else { + for (java.util.concurrent.CompletableFuture future : futures) { + future.whenComplete(this::handleCompletedFuture); + } + } + } + + @Override + public int getNumFuturesTotal() { + return numTotal; + } + + @Override + public int getNumFuturesCompleted() { + return numCompleted.get(); + } + } + + /** + * The implementation of the {@link ConjunctFuture} which returns its Futures' result as a + * collection. + */ + private static class ResultConjunctFuture extends ConjunctFuture> { + + /** The total number of futures in the conjunction. */ + private final int numTotal; + + /** The number of futures in the conjunction that are already complete. */ + private final AtomicInteger numCompleted = new AtomicInteger(0); + + /** The set of collected results so far. */ + private final T[] results; + + /** + * The function that is attached to all futures in the conjunction. Once a future is + * complete, this function tracks the completion or fails the conjunct. + */ + private void handleCompletedFuture(int index, T value, Throwable throwable) { + if (throwable != null) { + completeExceptionally(throwable); + } else { + /** + * This {@link #results} update itself is not synchronised in any way and it's fine + * because: + * + *
  • There is a happens-before relationship for each thread (that is completing + * the future) between setting {@link #results} and incrementing {@link + * #numCompleted}. + *
  • Each thread is updating uniquely different field of the {@link #results} + * array. + *
  • There is a happens-before relationship between all of the writing threads + * and the last thread (thanks to the {@code numCompleted.incrementAndGet() == numTotal} + * check). +
  • The last thread will be completing the future, so it has a transitive + * happens-before relationship with all of the preceding updates/writes to {@link + * #results}. +
  • {@link AtomicInteger#incrementAndGet} is an equivalent of both volatile + * read & write. +
+ */ + results[index] = value; + + if (numCompleted.incrementAndGet() == numTotal) { + complete(Arrays.asList(results)); + } + } + } + + @SuppressWarnings("unchecked") + ResultConjunctFuture(Collection> resultFutures) { + this.numTotal = resultFutures.size(); + results = (T[]) new Object[numTotal]; + + if (resultFutures.isEmpty()) { + complete(Collections.emptyList()); + } else { + int counter = 0; + for (CompletableFuture future : resultFutures) { + final int index = counter; + counter++; + future.whenComplete( + (value, throwable) -> handleCompletedFuture(index, value, throwable)); + } + } + } + + @Override + public int getNumFuturesTotal() { + return numTotal; + } + + @Override + public int getNumFuturesCompleted() { + return numCompleted.get(); + } + } + + /** + * {@link ConjunctFuture} implementation which is completed after all the given futures have + * been completed. Exceptional completions of the input futures will be recorded but it won't + * trigger the early completion of this future. + */ + private static final class CompletionConjunctFuture extends ConjunctFuture { + + private final Object lock = new Object(); + + private final int numFuturesTotal; + + private int futuresCompleted; + + private Throwable globalThrowable; + + private CompletionConjunctFuture( + Collection> futuresToComplete) { + numFuturesTotal = futuresToComplete.size(); + + futuresCompleted = 0; + + globalThrowable = null; + + if (futuresToComplete.isEmpty()) { + complete(null); + } else { + for (CompletableFuture completableFuture : futuresToComplete) { + completableFuture.whenComplete(this::completeFuture); + } + } + } + + private void completeFuture(Object ignored, Throwable throwable) { + synchronized (lock) { + futuresCompleted++; + + if (throwable != null) { + globalThrowable = globalThrowable == null ? throwable : globalThrowable; + } + + if (futuresCompleted == numFuturesTotal) { + if (globalThrowable != null) { + completeExceptionally(globalThrowable); + } else { + complete(null); + } + } + } + } + + @Override + public int getNumFuturesTotal() { + return numFuturesTotal; + } + + @Override + public int getNumFuturesCompleted() { + synchronized (lock) { + return futuresCompleted; + } + } + } +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/Hardware.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/Hardware.java new file mode 100644 index 00000000..c01aa3d8 --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/Hardware.java @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
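To make the ConjunctFuture bookkeeping concrete, a hypothetical sketch, with the result typed as a wildcard because combineAll's generic signature is elided in this rendering:

    import com.alibaba.flink.shuffle.common.utils.FutureUtils;

    import java.util.Arrays;
    import java.util.concurrent.CompletableFuture;

    public class ConjunctFutureExample {
        public static void main(String[] args) {
            CompletableFuture<Integer> f1 = CompletableFuture.completedFuture(1);
            CompletableFuture<Integer> f2 = new CompletableFuture<>();

            // Completes with the list of results once every input has completed;
            // completes exceptionally as soon as any input fails.
            FutureUtils.ConjunctFuture<?> all =
                    FutureUtils.combineAll(Arrays.asList(f1, f2));

            System.out.println(all.getNumFuturesCompleted()); // 1, only f1 is done
            f2.complete(2);
            System.out.println(all.join());                   // [1, 2]
        }
    }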
+ */ + +package com.alibaba.flink.shuffle.common.utils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.lang.management.ManagementFactory; +import java.lang.management.OperatingSystemMXBean; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.nio.charset.StandardCharsets; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Convenience class to extract hardware specifics of the computer executing the running JVM. + * + *
This class is copied from Apache Flink (org.apache.flink.runtime.util.Hardware). + */ +public class Hardware { + + private static final Logger LOG = LoggerFactory.getLogger(Hardware.class); + + private static final String LINUX_MEMORY_INFO_PATH = "/proc/meminfo"; + + private static final Pattern LINUX_MEMORY_REGEX = + Pattern.compile("^MemTotal:\\s*(\\d+)\\s+kB$"); + + // ------------------------------------------------------------------------ + + /** + * Returns the size of the physical memory in bytes. + * + * @return the size of the physical memory in bytes or {@code -1}, if the size could not be + * determined. + */ + public static long getSizeOfPhysicalMemory() { + // first try if the JVM can directly tell us what the system memory is + // this works only on Oracle JVMs + try { + Class clazz = Class.forName("com.sun.management.OperatingSystemMXBean"); + Method method = clazz.getMethod("getTotalPhysicalMemorySize"); + OperatingSystemMXBean operatingSystemMXBean = + ManagementFactory.getOperatingSystemMXBean(); + + // someone may install different beans, so we need to check whether the bean + // is in fact the sun management bean + if (clazz.isInstance(operatingSystemMXBean)) { + return (Long) method.invoke(operatingSystemMXBean); + } + } catch (ClassNotFoundException e) { + // this happens on non-Oracle JVMs, do nothing and use the alternative code paths + } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { + LOG.warn( + "Access to physical memory size: " + + "com.sun.management.OperatingSystemMXBean incompatibly changed.", + e); + } + + // we now try the OS specific access paths + switch (OperatingSystem.getCurrentOperatingSystem()) { + case LINUX: + return getSizeOfPhysicalMemoryForLinux(); + + case WINDOWS: + return getSizeOfPhysicalMemoryForWindows(); + + case MAC_OS: + return getSizeOfPhysicalMemoryForMac(); + + case FREE_BSD: + return getSizeOfPhysicalMemoryForFreeBSD(); + + case UNKNOWN: + LOG.error("Cannot determine size of physical memory for unknown operating system"); + return -1; + + default: + LOG.error("Unrecognized OS: " + OperatingSystem.getCurrentOperatingSystem()); + return -1; + } + } + + /** + * Returns the size of the physical memory in bytes on a Linux-based operating system. + * + * @return the size of the physical memory in bytes or {@code -1}, if the size could not be + * determined + */ + private static long getSizeOfPhysicalMemoryForLinux() { + try (BufferedReader lineReader = + new BufferedReader(new FileReader(LINUX_MEMORY_INFO_PATH))) { + String line; + while ((line = lineReader.readLine()) != null) { + Matcher matcher = LINUX_MEMORY_REGEX.matcher(line); + if (matcher.matches()) { + String totalMemory = matcher.group(1); + return Long.parseLong(totalMemory) * 1024L; // Convert from kilobyte to byte + } + } + // expected line did not come + LOG.error( + "Cannot determine the size of the physical memory for Linux host (using '/proc/meminfo'). " + + "Unexpected format."); + return -1; + } catch (NumberFormatException e) { + LOG.error( + "Cannot determine the size of the physical memory for Linux host (using '/proc/meminfo'). " + + "Unexpected format."); + return -1; + } catch (Throwable t) { + LOG.error( + "Cannot determine the size of the physical memory for Linux host (using '/proc/meminfo') ", + t); + return -1; + } + } + + /** + * Returns the size of the physical memory in bytes on a Mac OS-based operating system. 
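A hypothetical sketch of the public entry point above, relying only on the documented contract that -1 means the size could not be determined:

    import com.alibaba.flink.shuffle.common.utils.Hardware;

    public class MemoryProbe {
        public static void main(String[] args) {
            long bytes = Hardware.getSizeOfPhysicalMemory();
            if (bytes < 0) {
                System.out.println("Physical memory size could not be determined.");
            } else {
                System.out.printf("Physical memory: %.1f GiB%n", bytes / (1024.0 * 1024 * 1024));
            }
        }
    }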
+ * + * @return the size of the physical memory in bytes or {@code -1}, if the size could not be + * determined + */ + private static long getSizeOfPhysicalMemoryForMac() { + BufferedReader bi = null; + try { + Process proc = Runtime.getRuntime().exec("sysctl hw.memsize"); + + bi = + new BufferedReader( + new InputStreamReader(proc.getInputStream(), StandardCharsets.UTF_8)); + + String line; + while ((line = bi.readLine()) != null) { + if (line.startsWith("hw.memsize")) { + long memsize = Long.parseLong(line.split(":")[1].trim()); + bi.close(); + proc.destroy(); + return memsize; + } + } + + } catch (Throwable t) { + LOG.error("Cannot determine physical memory of machine for MacOS host", t); + return -1; + } finally { + if (bi != null) { + try { + bi.close(); + } catch (IOException ignored) { + } + } + } + return -1; + } + + /** + * Returns the size of the physical memory in bytes on FreeBSD. + * + * @return the size of the physical memory in bytes or {@code -1}, if the size could not be + * determined + */ + private static long getSizeOfPhysicalMemoryForFreeBSD() { + BufferedReader bi = null; + try { + Process proc = Runtime.getRuntime().exec("sysctl hw.physmem"); + + bi = + new BufferedReader( + new InputStreamReader(proc.getInputStream(), StandardCharsets.UTF_8)); + + String line; + while ((line = bi.readLine()) != null) { + if (line.startsWith("hw.physmem")) { + long memsize = Long.parseLong(line.split(":")[1].trim()); + bi.close(); + proc.destroy(); + return memsize; + } + } + + LOG.error( + "Cannot determine the size of the physical memory for FreeBSD host " + + "(using 'sysctl hw.physmem')."); + return -1; + } catch (Throwable t) { + LOG.error( + "Cannot determine the size of the physical memory for FreeBSD host " + + "(using 'sysctl hw.physmem')", + t); + return -1; + } finally { + if (bi != null) { + try { + bi.close(); + } catch (IOException ignored) { + } + } + } + } + + /** + * Returns the size of the physical memory in bytes on Windows. 
+ * + * @return the size of the physical memory in bytes or {@code -1}, if the size could not be + * determined + */ + private static long getSizeOfPhysicalMemoryForWindows() { + BufferedReader bi = null; + try { + Process proc = Runtime.getRuntime().exec("wmic memorychip get capacity"); + + bi = + new BufferedReader( + new InputStreamReader(proc.getInputStream(), StandardCharsets.UTF_8)); + + String line = bi.readLine(); + if (line == null) { + return -1L; + } + + if (!line.startsWith("Capacity")) { + return -1L; + } + + long sizeOfPhyiscalMemory = 0L; + while ((line = bi.readLine()) != null) { + if (line.isEmpty()) { + continue; + } + + line = line.replaceAll(" ", ""); + sizeOfPhyiscalMemory += Long.parseLong(line); + } + return sizeOfPhyiscalMemory; + } catch (Throwable t) { + LOG.error( + "Cannot determine the size of the physical memory for Windows host " + + "(using 'wmic memorychip')", + t); + return -1L; + } finally { + if (bi != null) { + try { + bi.close(); + } catch (Throwable ignored) { + } + } + } + } + + // -------------------------------------------------------------------------------------------- + + private Hardware() {} +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/JvmShutdownSafeguard.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/JvmShutdownSafeguard.java new file mode 100644 index 00000000..cc0f57a8 --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/JvmShutdownSafeguard.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.utils; + +import org.slf4j.Logger; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument; + +/** + * A utility that guards against blocking shutdown hooks that block JVM shutdown. + * + *
When the JVM shuts down cleanly (on SIGTERM or {@link System#exit(int)}), it runs all + * installed shutdown hooks. It is possible that one of the shutdown hooks blocks, which causes the + * JVM to get stuck and not exit at all. + *
This utility installs a shutdown hook that forcibly terminates the JVM if it is still alive a + * certain time after clean shutdown was initiated. Even if some shutdown hooks block, the JVM is + * thus guaranteed to terminate within a bounded time. + *
This class is copied from Apache Flink (org.apache.flink.runtime.util.JvmShutdownSafeguard). + */ +public class JvmShutdownSafeguard extends Thread { + + /** + * Default delay to wait after clean shutdown was stared, before forcibly terminating the JVM. + */ + private static final long DEFAULT_DELAY = 5000L; + + /** The exit code returned by the JVM process if it is killed by the safeguard. */ + private static final int EXIT_CODE = -17; + + /** The thread that actually does the termination. */ + private final Thread terminator; + + private JvmShutdownSafeguard(long delayMillis) { + setName("JVM Terminator Launcher"); + + this.terminator = new Thread(new DelayedTerminator(delayMillis), "Jvm Terminator"); + this.terminator.setDaemon(true); + } + + @Override + public void run() { + // Because this thread is registered as a shutdown hook, we cannot + // wait here and then call for termination. That would always delay the JVM shutdown. + // Instead, we spawn a non shutdown hook thread from here. + // That thread is a daemon, so it does not keep the JVM alive. + terminator.start(); + } + + // ------------------------------------------------------------------------ + // The actual Shutdown thread + // ------------------------------------------------------------------------ + + private static class DelayedTerminator implements Runnable { + + private final long delayMillis; + + private DelayedTerminator(long delayMillis) { + this.delayMillis = delayMillis; + } + + @Override + public void run() { + try { + Thread.sleep(delayMillis); + } catch (Throwable t) { + // catch all, including thread death, etc + } + + Runtime.getRuntime().halt(EXIT_CODE); + } + } + + // ------------------------------------------------------------------------ + // Installing as a shutdown hook + // ------------------------------------------------------------------------ + + /** + * Installs the safeguard shutdown hook. The maximum time that the JVM is allowed to spend on + * shutdown before being killed is five seconds. + * + * @param logger The logger to log errors to. + */ + public static void installAsShutdownHook(Logger logger) { + installAsShutdownHook(logger, DEFAULT_DELAY); + } + + /** + * Installs the safeguard shutdown hook. The maximum time that the JVM is allowed to spend on + * shutdown before being killed is the given number of milliseconds. + * + * @param logger The logger to log errors to. + * @param delayMillis The delay (in milliseconds) to wait after clean shutdown was stared, + * before forcibly terminating the JVM. + */ + public static void installAsShutdownHook(Logger logger, long delayMillis) { + checkArgument(delayMillis >= 0, "delay must be >= 0"); + + // install the blocking shutdown hook + Thread shutdownHook = new JvmShutdownSafeguard(delayMillis); + ShutdownHookUtil.addShutdownHookThread( + shutdownHook, JvmShutdownSafeguard.class.getSimpleName(), logger); + } +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/OperatingSystem.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/OperatingSystem.java new file mode 100644 index 00000000..8028dec1 --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/OperatingSystem.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
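A hypothetical sketch of the safeguard in action; Entrypoint and the deliberately misbehaving hook are illustrative:

    import com.alibaba.flink.shuffle.common.utils.JvmShutdownSafeguard;

    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;

    public class Entrypoint {

        private static final Logger LOG = LoggerFactory.getLogger(Entrypoint.class);

        public static void main(String[] args) {
            // If any other shutdown hook hangs, the JVM is halted 3 seconds after
            // clean shutdown starts, with exit code -17.
            JvmShutdownSafeguard.installAsShutdownHook(LOG, 3000L);

            Runtime.getRuntime().addShutdownHook(new Thread(() -> {
                while (true) { /* a misbehaving hook that never returns */ }
            }));

            System.exit(0); // still terminates within roughly 3 seconds
        }
    }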
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.utils; + +/** + * An enumeration indicating the operating system that the JVM runs on. + * + *
This class is copied from Apache Flink (org.apache.flink.util.OperatingSystem). + */ +public enum OperatingSystem { + LINUX, + WINDOWS, + MAC_OS, + FREE_BSD, + SOLARIS, + UNKNOWN; + + // ------------------------------------------------------------------------ + + /** + * Gets the operating system that the JVM runs on from the java system properties. this method + * returns UNKNOWN, if the operating system was not successfully determined. + * + * @return The enum constant for the operating system, or UNKNOWN, if it was not + * possible to determine. + */ + public static OperatingSystem getCurrentOperatingSystem() { + return os; + } + + /** + * Checks whether the operating system this JVM runs on is Windows. + * + * @return true if the operating system this JVM runs on is Windows, false + * otherwise + */ + public static boolean isWindows() { + return getCurrentOperatingSystem() == WINDOWS; + } + + /** + * Checks whether the operating system this JVM runs on is Linux. + * + * @return true if the operating system this JVM runs on is Linux, false + * otherwise + */ + public static boolean isLinux() { + return getCurrentOperatingSystem() == LINUX; + } + + /** + * Checks whether the operating system this JVM runs on is Windows. + * + * @return true if the operating system this JVM runs on is Windows, false + * otherwise + */ + public static boolean isMac() { + return getCurrentOperatingSystem() == MAC_OS; + } + + /** + * Checks whether the operating system this JVM runs on is FreeBSD. + * + * @return true if the operating system this JVM runs on is FreeBSD, false + * otherwise + */ + public static boolean isFreeBSD() { + return getCurrentOperatingSystem() == FREE_BSD; + } + + /** + * Checks whether the operating system this JVM runs on is Solaris. + * + * @return true if the operating system this JVM runs on is Solaris, false + * otherwise + */ + public static boolean isSolaris() { + return getCurrentOperatingSystem() == SOLARIS; + } + + /** The enum constant for the operating system. */ + private static final OperatingSystem os = readOSFromSystemProperties(); + + /** + * Parses the operating system that the JVM runs on from the java system properties. If the + * operating system was not successfully determined, this method returns {@code UNKNOWN}. + * + * @return The enum constant for the operating system, or {@code UNKNOWN}, if it was not + * possible to determine. + */ + private static OperatingSystem readOSFromSystemProperties() { + String osName = System.getProperty(OS_KEY); + + if (osName.startsWith(LINUX_OS_PREFIX)) { + return LINUX; + } + if (osName.startsWith(WINDOWS_OS_PREFIX)) { + return WINDOWS; + } + if (osName.startsWith(MAC_OS_PREFIX)) { + return MAC_OS; + } + if (osName.startsWith(FREEBSD_OS_PREFIX)) { + return FREE_BSD; + } + String osNameLowerCase = osName.toLowerCase(); + if (osNameLowerCase.contains(SOLARIS_OS_INFIX_1) + || osNameLowerCase.contains(SOLARIS_OS_INFIX_2)) { + return SOLARIS; + } + + return UNKNOWN; + } + + // -------------------------------------------------------------------------------------------- + // Constants to extract the OS type from the java environment + // -------------------------------------------------------------------------------------------- + + /** The key to extract the operating system name from the system properties. */ + private static final String OS_KEY = "os.name"; + + /** The expected prefix for Linux operating systems. 
*/ + private static final String LINUX_OS_PREFIX = "Linux"; + + /** The expected prefix for Windows operating systems. */ + private static final String WINDOWS_OS_PREFIX = "Windows"; + + /** The expected prefix for Mac OS operating systems. */ + private static final String MAC_OS_PREFIX = "Mac"; + + /** The expected prefix for FreeBSD. */ + private static final String FREEBSD_OS_PREFIX = "FreeBSD"; + + /** One expected infix for Solaris. */ + private static final String SOLARIS_OS_INFIX_1 = "sunos"; + + /** One expected infix for Solaris. */ + private static final String SOLARIS_OS_INFIX_2 = "solaris"; +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/ProcessUtils.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/ProcessUtils.java new file mode 100644 index 00000000..a9279233 --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/ProcessUtils.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.utils; + +import java.lang.management.ManagementFactory; + +/** Utilities related to the properties of the processes. */ +public class ProcessUtils { + + public static int getProcessID() { + return Integer.parseInt(ManagementFactory.getRuntimeMXBean().getName().split("@")[0]); + } +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/ProtocolUtils.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/ProtocolUtils.java new file mode 100644 index 00000000..144831b5 --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/ProtocolUtils.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.utils; + +import static com.alibaba.flink.shuffle.common.utils.StringUtils.stringToBytes; + +/** Utils to handle versions and protocol compatibility. 
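Combining the two small utilities above in a hypothetical probe:

    import com.alibaba.flink.shuffle.common.utils.OperatingSystem;
    import com.alibaba.flink.shuffle.common.utils.ProcessUtils;

    public class OsProbe {
        public static void main(String[] args) {
            // The OS is parsed once from the "os.name" system property; UNKNOWN otherwise.
            System.out.println("OS:  " + OperatingSystem.getCurrentOperatingSystem());
            System.out.println("PID: " + ProcessUtils.getProcessID());

            // isWindows()/isLinux()/isMac() are convenience checks over the same value.
            System.out.println("Windows? " + OperatingSystem.isWindows());
        }
    }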
*/ +public class ProtocolUtils { + + private static final int CURRENT_VERSION = 0; + + private static final int EMPTY_BUFFER_SIZE = Integer.MAX_VALUE; + + private static final long EMPTY_OFFSET = -2; + + private static final String EMPTY_EXTRA_MESSAGE = "{}"; + + private static final String EMPTY_DATA_PARTITION_TYPE_FACTORY = "emptyDataPartitionTypeFactory"; + + /** Returns the current protocol version used. */ + public static int currentProtocolVersion() { + return CURRENT_VERSION; + } + + /** Returns the number of empty byte buffers. */ + public static int emptyBufferSize() { + return EMPTY_BUFFER_SIZE; + } + + /** Returns the empty offset value. */ + public static long emptyOffset() { + return EMPTY_OFFSET; + } + + /** Returns the empty string of data partition type factory. */ + public static String emptyDataPartitionType() { + return EMPTY_DATA_PARTITION_TYPE_FACTORY; + } + + /** Returns the empty extra message. */ + public static String emptyExtraMessage() { + return EMPTY_EXTRA_MESSAGE; + } + + /** Returns the empty extra message bytes. */ + public static byte[] emptyExtraMessageBytes() { + return stringToBytes(EMPTY_EXTRA_MESSAGE); + } + + /** + * Returns true if the client protocol version is compatible with the server protocol version, + * including both control flow and data flow. This method is used at the server side. + */ + public static boolean isClientProtocolCompatible(int clientProtocolVersion) { + return compatibleVersion() <= clientProtocolVersion; + } + + /** + * Returns true if the client protocol version is compatible with the server protocol version, + * including both control flow and data flow. This method is used at the client side. + */ + public static boolean isServerProtocolCompatible( + int serverProtocolVersion, int serverCompatibleVersion) { + return currentProtocolVersion() <= serverProtocolVersion + && serverCompatibleVersion <= currentProtocolVersion(); + } + + /** Returns the minimum supported version compatible with the current version. */ + public static int compatibleVersion() { + return 0; + } +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/ShutdownHookUtil.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/ShutdownHookUtil.java new file mode 100644 index 00000000..5af7466c --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/ShutdownHookUtil.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.utils; + +import org.slf4j.Logger; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** + * Utils class for dealing with JVM shutdown hooks. + * + *
This class is copied from Apache Flink (org.apache.flink.util.ShutdownHookUtil). + */ +public class ShutdownHookUtil { + + /** Adds a shutdown hook to the JVM and returns the Thread, which has been registered. */ + public static Thread addShutdownHook( + final AutoCloseable service, final String serviceName, final Logger logger) { + + checkNotNull(service); + checkNotNull(logger); + + final Thread shutdownHook = + new Thread( + () -> { + try { + service.close(); + } catch (Throwable t) { + logger.error( + "Error during shutdown of {} via JVM shutdown hook.", + serviceName, + t); + } + }, + serviceName + " shutdown hook"); + + return addShutdownHookThread(shutdownHook, serviceName, logger) ? shutdownHook : null; + } + + /** + * Adds a shutdown hook to the JVM. + * + * @param shutdownHook Shutdown hook to be registered. + * @param serviceName The name of service. + * @param logger The logger to log. + * @return Whether the hook has been successfully registered. + */ + public static boolean addShutdownHookThread( + final Thread shutdownHook, final String serviceName, final Logger logger) { + + checkNotNull(shutdownHook); + checkNotNull(logger); + + try { + // Add JVM shutdown hook to call shutdown of service + Runtime.getRuntime().addShutdownHook(shutdownHook); + return true; + } catch (IllegalStateException e) { + // JVM is already shutting down. no need to do our work + } catch (Throwable t) { + logger.error( + "Cannot register shutdown hook that cleanly terminates {}.", serviceName, t); + } + return false; + } + + /** Removes a shutdown hook from the JVM. */ + public static void removeShutdownHook( + final Thread shutdownHook, final String serviceName, final Logger logger) { + + // Do not run if this is invoked by the shutdown hook itself + if (shutdownHook == null || shutdownHook == Thread.currentThread()) { + return; + } + + checkNotNull(logger); + + try { + Runtime.getRuntime().removeShutdownHook(shutdownHook); + } catch (IllegalStateException e) { + // race, JVM is in shutdown already, we can safely ignore this + logger.debug( + "Unable to remove shutdown hook for {}, shutdown already in progress", + serviceName, + e); + } catch (Throwable t) { + logger.warn("Exception while un-registering {}'s shutdown hook.", serviceName, t); + } + } + + private ShutdownHookUtil() { + throw new AssertionError(); + } +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/SignalHandler.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/SignalHandler.java new file mode 100644 index 00000000..938a2773 --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/SignalHandler.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
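A hypothetical lifecycle sketch for the hook utilities above; ServerMain and the no-op service are illustrative:

    import com.alibaba.flink.shuffle.common.utils.ShutdownHookUtil;

    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;

    public class ServerMain {

        private static final Logger LOG = LoggerFactory.getLogger(ServerMain.class);

        public static void main(String[] args) throws Exception {
            AutoCloseable server = () -> LOG.info("Server closed.");

            // Registers a hook that closes the service on JVM shutdown. Returns the
            // registered thread, or null if the JVM is already shutting down.
            Thread hook = ShutdownHookUtil.addShutdownHook(server, "server", LOG);

            // ... serve traffic ...

            // On an orderly stop, unregister the hook (null-safe) and close directly.
            ShutdownHookUtil.removeShutdownHook(hook, "server", LOG);
            server.close();
        }
    }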
+ */ + +package com.alibaba.flink.shuffle.common.utils; + +import org.slf4j.Logger; +import sun.misc.Signal; + +/** + * This signal handler / signal logger is based on Apache Hadoop's + * org.apache.hadoop.util.SignalLogger. + * + *
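 * <p>A minimal usage sketch from the editor (the logger name is illustrative):
 *
 * <pre>{@code
 * // Typically invoked once during process bootstrap; on UNIX this logs
 * // TERM, HUP and INT before delegating to the previous handler.
 * SignalHandler.register(LoggerFactory.getLogger("bootstrap"));
 * }</pre>
 *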
<p>
This class is copied from Apache Flink (org.apache.flink.runtime.util.SignalHandler). + */ +public class SignalHandler { + + private static boolean registered = false; + + /** Our signal handler. */ + private static class Handler implements sun.misc.SignalHandler { + + private final Logger log; + private final sun.misc.SignalHandler prevHandler; + + Handler(String name, Logger log) { + this.log = log; + prevHandler = Signal.handle(new Signal(name), this); + } + + /** + * Handle an incoming signal. + * + * @param signal The incoming signal + */ + @Override + public void handle(Signal signal) { + log.info( + "RECEIVED SIGNAL {}: SIG{}. Shutting down as requested.", + signal.getNumber(), + signal.getName()); + prevHandler.handle(signal); + } + } + + /** + * Register some signal handlers. + * + * @param log The slf4j logger + */ + public static void register(final Logger log) { + synchronized (SignalHandler.class) { + if (registered) { + return; + } + registered = true; + + final String[] signals = + OperatingSystem.isWindows() + ? new String[] {"TERM", "INT"} + : new String[] {"TERM", "HUP", "INT"}; + + StringBuilder bld = new StringBuilder(); + bld.append("Registered UNIX signal handlers for ["); + + String separator = ""; + for (String signalName : signals) { + try { + new Handler(signalName, log); + bld.append(separator); + bld.append(signalName); + separator = ", "; + } catch (Exception e) { + log.info("Error while registering signal handler", e); + } + } + bld.append("]"); + log.info(bld.toString()); + } + } +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/SingleThreadExecutorValidator.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/SingleThreadExecutorValidator.java new file mode 100644 index 00000000..21867061 --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/SingleThreadExecutorValidator.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.utils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** + * A tool class to ensure the current method is running in the thread of a specified single thread + * executor. 
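 *
 * <p>A sketch of the intended use (editor's illustration; names are assumptions):
 *
 * <pre>{@code
 * ExecutorService executor = Executors.newSingleThreadExecutor();
 * SingleThreadExecutorValidator validator = new SingleThreadExecutorValidator(executor);
 *
 * // Succeeds only when called from the executor's single thread; from any
 * // other thread it logs a warning and throws a RuntimeException.
 * executor.submit(validator::assertRunningInTargetThread);
 * }</pre>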
+ */ +public class SingleThreadExecutorValidator { + + private static final Logger LOG = LoggerFactory.getLogger(SingleThreadExecutorValidator.class); + + private final Thread targetThread; + + public SingleThreadExecutorValidator(Executor executor) { + CompletableFuture<Thread> targetThreadFuture = new CompletableFuture<>(); + executor.execute(() -> targetThreadFuture.complete(Thread.currentThread())); + targetThread = checkNotNull(targetThreadFuture.join()); + } + + public SingleThreadExecutorValidator(Thread targetThread) { + this.targetThread = targetThread; + } + + public void assertRunningInTargetThread() { + if (Thread.currentThread() != targetThread) { + RuntimeException exception = + new RuntimeException( + "Expected running in " + + targetThread + + ", but running in " + + Thread.currentThread()); + LOG.warn("Validation failed", exception); + throw exception; + } + } +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/StringUtils.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/StringUtils.java new file mode 100644 index 00000000..594219c4 --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/StringUtils.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.utils; + +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; + +/** Utilities to handle {@link String} operations. */ +public class StringUtils { + + public static final Charset UTF_8 = StandardCharsets.UTF_8; + + public static String bytesToString(byte[] inputBytes) { + return new String(inputBytes, UTF_8); + } + + public static byte[] stringToBytes(String inputString) { + return inputString.getBytes(UTF_8); + } + + /** + * Checks if the string is null, empty, or contains only whitespace characters. A whitespace + * character is defined via {@link Character#isWhitespace(char)}. + */ + public static boolean isNullOrWhitespaceOnly(String string) { + if (string == null || string.length() == 0) { + return true; + } + + final int len = string.length(); + for (int i = 0; i < len; i++) { + if (!Character.isWhitespace(string.charAt(i))) { + return false; + } + } + return true; + } +} diff --git a/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/TimeUtils.java b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/TimeUtils.java new file mode 100644 index 00000000..cf85b3e7 --- /dev/null +++ b/shuffle-common/src/main/java/com/alibaba/flink/shuffle/common/utils/TimeUtils.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.utils; + +import java.time.Duration; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Collection of utilities about time intervals. + * + *
<p>
This class is copied from Apache Flink (org.apache.flink.util.TimeUtils). + */ +public class TimeUtils { + + private static final Map<String, ChronoUnit> LABEL_TO_UNIT_MAP = + Collections.unmodifiableMap(initMap()); + + /** + * Parse the given string to a java {@link Duration}. The string is in format "{length + * value}{time unit label}", e.g. "123ms", "321 s". If no time unit label is specified, it will + * be considered as milliseconds. + * + *
<p>Supported time unit labels are: + * + * <ul> + *   <li>DAYS: "d", "day" + *   <li>HOURS: "h", "hour" + *   <li>MINUTES: "min", "minute" + *   <li>SECONDS: "s", "sec", "second" + *   <li>MILLISECONDS: "ms", "milli", "millisecond" + *   <li>MICROSECONDS: "µs", "micro", "microsecond" + *   <li>NANOSECONDS: "ns", "nano", "nanosecond" + * </ul>
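 *
 * <p>For example (editor's illustration):
 *
 * <pre>{@code
 * TimeUtils.parseDuration("123ms"); // 123 milliseconds
 * TimeUtils.parseDuration("321 s"); // 321 seconds
 * TimeUtils.parseDuration("10");    // no unit label, treated as 10 milliseconds
 * }</pre>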
+ * + * @param text string to parse. + */ + public static Duration parseDuration(String text) { + CommonUtils.checkNotNull(text); + + final String trimmed = text.trim(); + CommonUtils.checkArgument( + !trimmed.isEmpty(), "argument is an empty- or whitespace-only string"); + + final int len = trimmed.length(); + int pos = 0; + + char current; + while (pos < len && (current = trimmed.charAt(pos)) >= '0' && current <= '9') { + pos++; + } + + final String number = trimmed.substring(0, pos); + final String unitLabel = trimmed.substring(pos).trim().toLowerCase(Locale.US); + + if (number.isEmpty()) { + throw new NumberFormatException("text does not start with a number"); + } + + final long value; + try { + value = Long.parseLong(number); // this throws a NumberFormatException on overflow + } catch (NumberFormatException e) { + throw new IllegalArgumentException( + "The value '" + + number + + "' cannot be represented as a 64 bit number (numeric overflow)."); + } + + if (unitLabel.isEmpty()) { + return Duration.of(value, ChronoUnit.MILLIS); + } + + ChronoUnit unit = LABEL_TO_UNIT_MAP.get(unitLabel); + if (unit != null) { + return Duration.of(value, unit); + } else { + throw new IllegalArgumentException( + "Time interval unit label '" + + unitLabel + + "' does not match any of the recognized units: " + + TimeUnit.getAllUnits()); + } + } + + private static Map<String, ChronoUnit> initMap() { + Map<String, ChronoUnit> labelToUnit = new HashMap<>(); + for (TimeUnit timeUnit : TimeUnit.values()) { + for (String label : timeUnit.getLabels()) { + labelToUnit.put(label, timeUnit.getUnit()); + } + } + return labelToUnit; + } + + /** Enum which defines time unit, mostly used to parse value from configuration file. */ + private enum TimeUnit { + DAYS(ChronoUnit.DAYS, singular("d"), plural("day")), + HOURS(ChronoUnit.HOURS, singular("h"), plural("hour")), + MINUTES(ChronoUnit.MINUTES, singular("min"), plural("minute")), + SECONDS(ChronoUnit.SECONDS, singular("s"), plural("sec"), plural("second")), + MILLISECONDS(ChronoUnit.MILLIS, singular("ms"), plural("milli"), plural("millisecond")), + MICROSECONDS(ChronoUnit.MICROS, singular("µs"), plural("micro"), plural("microsecond")), + NANOSECONDS(ChronoUnit.NANOS, singular("ns"), plural("nano"), plural("nanosecond")); + + private static final String PLURAL_SUFFIX = "s"; + + private final List<String> labels; + + private final ChronoUnit unit; + + TimeUnit(ChronoUnit unit, String[]...
labels) { + this.unit = unit; + this.labels = + Arrays.stream(labels) + .flatMap(ls -> Arrays.stream(ls)) + .collect(Collectors.toList()); + } + + /** + * @param label the original label + * @return the singular format of the original label + */ + private static String[] singular(String label) { + return new String[] {label}; + } + + /** + * @param label the original label + * @return both the singular format and plural format of the original label + */ + private static String[] plural(String label) { + return new String[] {label, label + PLURAL_SUFFIX}; + } + + public List getLabels() { + return labels; + } + + public ChronoUnit getUnit() { + return unit; + } + + public static String getAllUnits() { + return Arrays.stream(TimeUnit.values()) + .map(TimeUnit::createTimeUnitString) + .collect(Collectors.joining(", ")); + } + + private static String createTimeUnitString(TimeUnit timeUnit) { + return timeUnit.name() + ": (" + String.join(" | ", timeUnit.getLabels()) + ")"; + } + } +} diff --git a/shuffle-common/src/test/java/com/alibaba/flink/shuffle/common/config/ConfigurationTest.java b/shuffle-common/src/test/java/com/alibaba/flink/shuffle/common/config/ConfigurationTest.java new file mode 100644 index 00000000..084939ed --- /dev/null +++ b/shuffle-common/src/test/java/com/alibaba/flink/shuffle/common/config/ConfigurationTest.java @@ -0,0 +1,329 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.config; + +import com.alibaba.flink.shuffle.common.exception.ConfigurationException; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.io.FileWriter; +import java.util.Properties; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; + +/** Tests for {@link Configuration}. 
*/ +public class ConfigurationTest { + + private static final String KEY_1 = "remote-shuffle.test.key1"; + + private static final String KEY_2 = "remote-shuffle.test.key2"; + + private static final String KEY_3 = "remote-shuffle.test.key3"; + + @Rule public TemporaryFolder temporaryFolder = new TemporaryFolder(); + + @Test + public void testGetBoolean() { + ConfigOption option1 = new ConfigOption(KEY_1).defaultValue(false); + ConfigOption option2 = new ConfigOption(KEY_2).defaultValue(false); + ConfigOption option3 = new ConfigOption(KEY_3).defaultValue(false); + + Properties properties = new Properties(); + properties.put(KEY_1, "true"); + properties.put(KEY_2, "illegal"); + + Configuration configuration = new Configuration(properties); + + assertEquals(true, configuration.getBoolean(KEY_1)); + assertEquals(true, configuration.getBoolean(KEY_1, false)); + assertEquals(true, configuration.getBoolean(option1)); + assertEquals(true, configuration.getBoolean(option1, false)); + + checkGetIllegalValue(() -> configuration.getBoolean(KEY_2)); + checkGetIllegalValue(() -> configuration.getBoolean(KEY_2, false)); + checkGetIllegalValue(() -> configuration.getBoolean(option2)); + checkGetIllegalValue(() -> configuration.getBoolean(option2, false)); + + assertNull(configuration.getBoolean(KEY_3)); + assertEquals(false, configuration.getBoolean(KEY_3, false)); + assertEquals(true, configuration.getBoolean(KEY_3, true)); + assertEquals(option3.defaultValue(), configuration.getBoolean(option3)); + assertEquals(false, configuration.getBoolean(option3, false)); + assertEquals(true, configuration.getBoolean(option3, true)); + } + + @Test + public void testGetByte() { + Byte defaultValue = 'X'; + ConfigOption option1 = new ConfigOption(KEY_1).defaultValue(defaultValue); + ConfigOption option2 = new ConfigOption(KEY_2).defaultValue(defaultValue); + ConfigOption option3 = new ConfigOption(KEY_3).defaultValue(defaultValue); + + Byte value1 = 'O'; + Properties properties = new Properties(); + properties.put(KEY_1, value1.toString()); + properties.put(KEY_2, "illegal"); + + Configuration configuration = new Configuration(properties); + + assertEquals(value1, configuration.getByte(KEY_1)); + assertEquals(value1, configuration.getByte(KEY_1, defaultValue)); + assertEquals(value1, configuration.getByte(option1)); + assertEquals(value1, configuration.getByte(option1, defaultValue)); + + checkGetIllegalValue(() -> configuration.getByte(KEY_2)); + checkGetIllegalValue(() -> configuration.getByte(KEY_2, defaultValue)); + checkGetIllegalValue(() -> configuration.getByte(option2)); + checkGetIllegalValue(() -> configuration.getByte(option2, defaultValue)); + + assertNull(configuration.getByte(KEY_3)); + assertEquals(defaultValue, configuration.getByte(KEY_3, defaultValue)); + assertEquals(option3.defaultValue(), configuration.getByte(option3)); + assertEquals(value1, configuration.getByte(option3, value1)); + } + + @Test + public void testGetShort() { + Short defaultValue = 1; + ConfigOption option1 = new ConfigOption(KEY_1).defaultValue(defaultValue); + ConfigOption option2 = new ConfigOption(KEY_2).defaultValue(defaultValue); + ConfigOption option3 = new ConfigOption(KEY_3).defaultValue(defaultValue); + + Short value1 = 1024; + Properties properties = new Properties(); + properties.put(KEY_1, value1.toString()); + properties.put(KEY_2, "illegal"); + + Configuration configuration = new Configuration(properties); + + assertEquals(value1, configuration.getShort(KEY_1)); + assertEquals(value1, 
configuration.getShort(KEY_1, defaultValue)); + assertEquals(value1, configuration.getShort(option1)); + assertEquals(value1, configuration.getShort(option1, defaultValue)); + + checkGetIllegalValue(() -> configuration.getShort(KEY_2)); + checkGetIllegalValue(() -> configuration.getShort(KEY_2, defaultValue)); + checkGetIllegalValue(() -> configuration.getShort(option2)); + checkGetIllegalValue(() -> configuration.getShort(option2, defaultValue)); + + assertNull(configuration.getShort(KEY_3)); + assertEquals(defaultValue, configuration.getShort(KEY_3, defaultValue)); + assertEquals(option3.defaultValue(), configuration.getShort(option3)); + assertEquals(value1, configuration.getShort(option3, value1)); + } + + @Test + public void testGetInteger() { + Integer defaultValue = 1; + ConfigOption option1 = new ConfigOption(KEY_1).defaultValue(defaultValue); + ConfigOption option2 = new ConfigOption(KEY_2).defaultValue(defaultValue); + ConfigOption option3 = new ConfigOption(KEY_3).defaultValue(defaultValue); + + Integer value1 = 1024; + Properties properties = new Properties(); + properties.put(KEY_1, value1.toString()); + properties.put(KEY_2, "illegal"); + + Configuration configuration = new Configuration(properties); + + assertEquals(value1, configuration.getInteger(KEY_1)); + assertEquals(value1, configuration.getInteger(KEY_1, defaultValue)); + assertEquals(value1, configuration.getInteger(option1)); + assertEquals(value1, configuration.getInteger(option1, defaultValue)); + + checkGetIllegalValue(() -> configuration.getInteger(KEY_2)); + checkGetIllegalValue(() -> configuration.getInteger(KEY_2, defaultValue)); + checkGetIllegalValue(() -> configuration.getInteger(option2)); + checkGetIllegalValue(() -> configuration.getInteger(option2, defaultValue)); + + assertNull(configuration.getInteger(KEY_3)); + assertEquals(defaultValue, configuration.getInteger(KEY_3, defaultValue)); + assertEquals(option3.defaultValue(), configuration.getInteger(option3)); + assertEquals(value1, configuration.getInteger(option3, value1)); + } + + @Test + public void testGetLong() { + Long defaultValue = 1L; + ConfigOption option1 = new ConfigOption(KEY_1).defaultValue(defaultValue); + ConfigOption option2 = new ConfigOption(KEY_2).defaultValue(defaultValue); + ConfigOption option3 = new ConfigOption(KEY_3).defaultValue(defaultValue); + + Long value1 = 1024L; + Properties properties = new Properties(); + properties.put(KEY_1, value1.toString()); + properties.put(KEY_2, "illegal"); + + Configuration configuration = new Configuration(properties); + + assertEquals(value1, configuration.getLong(KEY_1)); + assertEquals(value1, configuration.getLong(KEY_1, defaultValue)); + assertEquals(value1, configuration.getLong(option1)); + assertEquals(value1, configuration.getLong(option1, defaultValue)); + + checkGetIllegalValue(() -> configuration.getLong(KEY_2)); + checkGetIllegalValue(() -> configuration.getLong(KEY_2, defaultValue)); + checkGetIllegalValue(() -> configuration.getLong(option2)); + checkGetIllegalValue(() -> configuration.getLong(option2, defaultValue)); + + assertNull(configuration.getLong(KEY_3)); + assertEquals(defaultValue, configuration.getLong(KEY_3, defaultValue)); + assertEquals(option3.defaultValue(), configuration.getLong(option3)); + assertEquals(value1, configuration.getLong(option3, value1)); + } + + @Test + public void testGetFloat() { + Float defaultValue = 1.0F; + ConfigOption option1 = new ConfigOption(KEY_1).defaultValue(defaultValue); + ConfigOption option2 = new 
ConfigOption(KEY_2).defaultValue(defaultValue); + ConfigOption option3 = new ConfigOption(KEY_3).defaultValue(defaultValue); + + Float value1 = 1024.0F; + Properties properties = new Properties(); + properties.put(KEY_1, value1.toString()); + properties.put(KEY_2, "illegal"); + + Configuration configuration = new Configuration(properties); + + assertEquals(value1, configuration.getFloat(KEY_1)); + assertEquals(value1, configuration.getFloat(KEY_1, defaultValue)); + assertEquals(value1, configuration.getFloat(option1)); + assertEquals(value1, configuration.getFloat(option1, defaultValue)); + + checkGetIllegalValue(() -> configuration.getFloat(KEY_2)); + checkGetIllegalValue(() -> configuration.getFloat(KEY_2, defaultValue)); + checkGetIllegalValue(() -> configuration.getFloat(option2)); + checkGetIllegalValue(() -> configuration.getFloat(option2, defaultValue)); + + assertNull(configuration.getFloat(KEY_3)); + assertEquals(defaultValue, configuration.getFloat(KEY_3, defaultValue)); + assertEquals(option3.defaultValue(), configuration.getFloat(option3)); + assertEquals(value1, configuration.getFloat(option3, value1)); + } + + @Test + public void testGetDouble() { + Double defaultValue = 1.0; + ConfigOption option1 = new ConfigOption(KEY_1).defaultValue(defaultValue); + ConfigOption option2 = new ConfigOption(KEY_2).defaultValue(defaultValue); + ConfigOption option3 = new ConfigOption(KEY_3).defaultValue(defaultValue); + + Double value1 = 1024.0; + Properties properties = new Properties(); + properties.put(KEY_1, value1.toString()); + properties.put(KEY_2, "illegal"); + + Configuration configuration = new Configuration(properties); + + assertEquals(value1, configuration.getDouble(KEY_1)); + assertEquals(value1, configuration.getDouble(KEY_1, defaultValue)); + assertEquals(value1, configuration.getDouble(option1)); + assertEquals(value1, configuration.getDouble(option1, defaultValue)); + + checkGetIllegalValue(() -> configuration.getDouble(KEY_2)); + checkGetIllegalValue(() -> configuration.getDouble(KEY_2, defaultValue)); + checkGetIllegalValue(() -> configuration.getDouble(option2)); + checkGetIllegalValue(() -> configuration.getDouble(option2, defaultValue)); + + assertNull(configuration.getDouble(KEY_3)); + assertEquals(defaultValue, configuration.getDouble(KEY_3, defaultValue)); + assertEquals(option3.defaultValue(), configuration.getDouble(option3)); + assertEquals(value1, configuration.getDouble(option3, value1)); + } + + @Test + public void testGetString() { + String defaultValue = "hello"; + ConfigOption option1 = new ConfigOption(KEY_1).defaultValue(defaultValue); + ConfigOption option2 = new ConfigOption(KEY_2).defaultValue(defaultValue); + + String value1 = "world"; + Properties properties = new Properties(); + properties.put(KEY_1, value1); + + Configuration configuration = new Configuration(properties); + + assertEquals(value1, configuration.getString(KEY_1)); + assertEquals(value1, configuration.getString(KEY_1, defaultValue)); + assertEquals(value1, configuration.getString(option1)); + assertEquals(value1, configuration.getString(option1, defaultValue)); + + assertNull(configuration.getString(KEY_3)); + assertEquals(defaultValue, configuration.getString(KEY_3, defaultValue)); + assertEquals(option2.defaultValue(), configuration.getString(option2)); + assertEquals(value1, configuration.getString(option2, value1)); + } + + @Test + public void testLoadConfigurationFromFile() throws Exception { + File confFile = temporaryFolder.newFile(Configuration.REMOTE_SHUFFLE_CONF_FILENAME); + 
+ String value1 = "value1"; + String value2 = "value2"; + String value3 = "value3"; + + try (FileWriter fileWriter = new FileWriter(confFile)) { + fileWriter.write(KEY_1 + ": " + value1 + "\n"); + fileWriter.write(KEY_2 + ": " + value2 + "\n"); + fileWriter.write("#" + KEY_3 + ": " + value3 + "\n"); + } + + Configuration configuration = new Configuration(confFile.getParent()); + assertEquals(value1, configuration.getString(KEY_1)); + assertEquals(value2, configuration.getString(KEY_2)); + assertNull(configuration.getString(KEY_3)); + } + + @Test + public void testDynamicConfiguration() throws Exception { + File confFile = temporaryFolder.newFile(Configuration.REMOTE_SHUFFLE_CONF_FILENAME); + + String value1 = "value1"; + String value2 = "value2"; + String value3 = "value3"; + + try (FileWriter fileWriter = new FileWriter(confFile)) { + fileWriter.write(KEY_1 + ": " + value1 + "\n"); + fileWriter.write(KEY_2 + ": " + value2 + "\n"); + } + + Properties dynamicConfiguration = new Properties(); + dynamicConfiguration.setProperty(KEY_1, value3); + + Configuration configuration = new Configuration(confFile.getParent(), dynamicConfiguration); + assertEquals(value3, configuration.getString(KEY_1)); + assertEquals(value2, configuration.getString(KEY_2)); + } + + private void checkGetIllegalValue(Runnable runnable) { + try { + runnable.run(); + } catch (ConfigurationException ignored) { + // expected + return; + } + fail("Should throw IllegalArgumentException."); + } +} diff --git a/shuffle-common/src/test/java/com/alibaba/flink/shuffle/common/utils/CommonUtilsTest.java b/shuffle-common/src/test/java/com/alibaba/flink/shuffle/common/utils/CommonUtilsTest.java new file mode 100644 index 00000000..52111543 --- /dev/null +++ b/shuffle-common/src/test/java/com/alibaba/flink/shuffle/common/utils/CommonUtilsTest.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.utils; + +import org.junit.Test; + +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +/** Tests for utilities in {@link CommonUtils}. 
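 * The expected length of 15 in testConcatAndSplitByteArrays assumes that
 * concatByteArrays length-prefixes each input array with a 4-byte int, so the
 * result is (4 + 4) + (4 + 3) = 15 bytes.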
*/ +public class CommonUtilsTest { + + @Test + public void testConcatAndSplitByteArrays() { + byte[] first = new byte[] {1, 2, 3, 4}; + byte[] second = new byte[] {1, 2, 3}; + + byte[] concat = CommonUtils.concatByteArrays(first, second); + assertEquals(15, concat.length); + + List<byte[]> split = CommonUtils.splitByteArrays(concat); + assertEquals(2, split.size()); + assertArrayEquals(first, split.get(0)); + assertArrayEquals(second, split.get(1)); + } + + @Test + public void testConvertBetweenHexAndByteArray() { + byte[] bytes = CommonUtils.randomBytes(16); + String hex = CommonUtils.bytesToHexString(bytes); + byte[] decoded = CommonUtils.hexStringToBytes(hex); + assertArrayEquals(bytes, decoded); + } +} diff --git a/shuffle-common/src/test/java/com/alibaba/flink/shuffle/common/utils/ExceptionUtilsTest.java b/shuffle-common/src/test/java/com/alibaba/flink/shuffle/common/utils/ExceptionUtilsTest.java new file mode 100644 index 00000000..3caa57e5 --- /dev/null +++ b/shuffle-common/src/test/java/com/alibaba/flink/shuffle/common/utils/ExceptionUtilsTest.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.utils; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** Tests for utilities in {@link ExceptionUtils}.
*/ +public class ExceptionUtilsTest { + + @Test + public void testSummaryErrorMessageStack() { + Throwable t = new Throwable("msg"); + assertEquals("[java.lang.Throwable: msg]", ExceptionUtils.summaryErrorMessageStack(t)); + + t = new Throwable(); + assertEquals("[java.lang.Throwable: null]", ExceptionUtils.summaryErrorMessageStack(t)); + + t = new Throwable("msga", new Throwable0("msgb")); + assertEquals( + "[java.lang.Throwable: msga] -> [com.alibaba.flink.shuffle.common.utils.ExceptionUtilsTest$Throwable0: msgb]", + ExceptionUtils.summaryErrorMessageStack(t)); + + t = new Throwable0("msga", new Throwable()); + assertEquals( + "[com.alibaba.flink.shuffle.common.utils.ExceptionUtilsTest$Throwable0: msga] -> [java.lang.Throwable: null]", + ExceptionUtils.summaryErrorMessageStack(t)); + + t = new Throwable0("msga", new Throwable(null, new Throwable0("msgb"))); + assertEquals( + "[com.alibaba.flink.shuffle.common.utils.ExceptionUtilsTest$Throwable0: msga] -> [java.lang.Throwable: null] -> [com.alibaba.flink.shuffle.common.utils.ExceptionUtilsTest$Throwable0: msgb]", + ExceptionUtils.summaryErrorMessageStack(t)); + } + + private class Throwable0 extends Throwable { + Throwable0(String msg) { + super(msg); + } + + Throwable0(String msg, Throwable t) { + super(msg, t); + } + } +} diff --git a/shuffle-common/src/test/java/com/alibaba/flink/shuffle/common/utils/FatalErrorsExitUtilsTest.java b/shuffle-common/src/test/java/com/alibaba/flink/shuffle/common/utils/FatalErrorsExitUtilsTest.java new file mode 100644 index 00000000..5d99d866 --- /dev/null +++ b/shuffle-common/src/test/java/com/alibaba/flink/shuffle/common/utils/FatalErrorsExitUtilsTest.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.utils; + +import org.junit.Test; + +import java.io.IOException; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** Tests for {@link FatalErrorExitUtils}. 
*/ +public class FatalErrorsExitUtilsTest { + + @Test + public void testNeedStopProcess() { + FatalErrorExitUtils.setNeedStopProcess(false); + FatalErrorExitUtils.exitProcessIfNeeded(-1, new IOException("ignore")); + assertFalse(FatalErrorExitUtils.isNeedStopProcess()); + + FatalErrorExitUtils.setNeedStopProcess(true); + assertTrue(FatalErrorExitUtils.isNeedStopProcess()); + } +} diff --git a/shuffle-common/src/test/java/com/alibaba/flink/shuffle/common/utils/SingleThreadExecutorValidatorTest.java b/shuffle-common/src/test/java/com/alibaba/flink/shuffle/common/utils/SingleThreadExecutorValidatorTest.java new file mode 100644 index 00000000..0de1e100 --- /dev/null +++ b/shuffle-common/src/test/java/com/alibaba/flink/shuffle/common/utils/SingleThreadExecutorValidatorTest.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.common.utils; + +import org.junit.Test; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import static org.junit.Assert.fail; + +/** Tests the {@link SingleThreadExecutorValidator}. */ +public class SingleThreadExecutorValidatorTest { + + @Test + public void testRunningInSameThread() throws ExecutionException, InterruptedException { + ExecutorService executor = Executors.newSingleThreadExecutor(); + try { + SingleThreadExecutorValidator validator = new SingleThreadExecutorValidator(executor); + + // The following call should succeed. + executor.submit(validator::assertRunningInTargetThread).get(); + + } finally { + executor.shutdownNow(); + } + } + + @Test + public void testRunningInDifferentThread() { + ExecutorService executor = Executors.newSingleThreadExecutor(); + try { + SingleThreadExecutorValidator validator = new SingleThreadExecutorValidator(executor); + + try { + validator.assertRunningInTargetThread(); + fail("The check should failed"); + } catch (RuntimeException e) { + // Expected exception + } + } finally { + executor.shutdownNow(); + } + } +} diff --git a/shuffle-common/src/test/resources/log4j2-test.properties b/shuffle-common/src/test/resources/log4j2-test.properties new file mode 100644 index 00000000..d7fcb327 --- /dev/null +++ b/shuffle-common/src/test/resources/log4j2-test.properties @@ -0,0 +1,26 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +# Set root logger level to OFF to not flood build logs +# set manually to INFO for debugging purposes +rootLogger.level=OFF +rootLogger.appenderRef.test.ref=TestLogger +appender.testlogger.name=TestLogger +appender.testlogger.type=CONSOLE +appender.testlogger.target=SYSTEM_ERR +appender.testlogger.layout.type=PatternLayout +appender.testlogger.layout.pattern=%-4r [%t] %-5p %c %x - %m%n diff --git a/shuffle-coordinator/pom.xml b/shuffle-coordinator/pom.xml new file mode 100644 index 00000000..7404c9aa --- /dev/null +++ b/shuffle-coordinator/pom.xml @@ -0,0 +1,172 @@ + + + + + com.alibaba.flink.shuffle + flink-shuffle-parent + 1.0-SNAPSHOT + + 4.0.0 + + shuffle-coordinator + + + 2.21.0 + + + + + com.alibaba.flink.shuffle + shuffle-rpc + ${project.version} + + + + com.alibaba.flink.shuffle + shuffle-core + ${project.version} + + + + com.alibaba.flink.shuffle + shuffle-common + ${project.version} + + + + com.alibaba.flink.shuffle + shuffle-storage + ${project.version} + + + + com.alibaba.flink.shuffle + shuffle-transfer + ${project.version} + + + + com.alibaba.flink.shuffle + shuffle-metrics + ${project.version} + + + + org.apache.flink + flink-shaded-zookeeper-3 + ${zookeeper.version} + + + + org.apache.commons + commons-lang3 + 3.3.2 + + + + org.apache.flink + flink-rpc-core + ${flink.version} + provided + + + + org.apache.flink + flink-core + ${flink.version} + provided + + + + org.apache.flink + flink-shaded-netty + 4.1.49.Final-${flink.shaded.version} + provided + + + + com.alibaba.flink.shuffle + shuffle-core + ${project.version} + test-jar + test + + + + org.apache.curator + curator-test + ${curator.version} + test + + + + log4j + log4j + + + + + + org.mockito + mockito-core + ${mockito.version} + jar + test + + + + + + + src/main/resources-filtered + true + + + + + + pl.project13.maven + git-commit-id-plugin + + + get-the-git-infos + validate + + revision + + + + + ${project.basedir}/../.git + false + false + false + + + + true + + + + + + diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/client/ShuffleManagerClient.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/client/ShuffleManagerClient.java new file mode 100644 index 00000000..c96cc95f --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/client/ShuffleManagerClient.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.client; + +import com.alibaba.flink.shuffle.coordinator.manager.JobDataPartitionDistribution; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleResource; +import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerMetrics; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; + +/** Client to interact with ShuffleManager to request and release shuffle resource. */ +public interface ShuffleManagerClient extends AutoCloseable { + + void start(); + + void synchronizeWorkerStatus(Set<InstanceID> initialWorkers) throws Exception; + + CompletableFuture<ShuffleResource> requestShuffleResource( + DataSetID dataSetId, + MapPartitionID mapPartitionId, + int numberOfSubpartitions, + String dataPartitionFactoryName); + + void releaseShuffleResource(DataSetID dataSetId, MapPartitionID mapPartitionId); + + CompletableFuture<Integer> getNumberOfRegisteredWorkers(); + + CompletableFuture<Map<InstanceID, ShuffleWorkerMetrics>> getShuffleWorkerMetrics(); + + CompletableFuture<List<JobID>> listJobs(boolean includeMyself); + + CompletableFuture<JobDataPartitionDistribution> getJobDataPartitionDistribution(JobID jobID); + + @Override + void close() throws Exception; +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/client/ShuffleManagerClientConfiguration.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/client/ShuffleManagerClientConfiguration.java new file mode 100644 index 00000000..ef6e58ad --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/client/ShuffleManagerClientConfiguration.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.client; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.coordinator.registration.RetryingRegistrationConfiguration; +import com.alibaba.flink.shuffle.core.config.ClusterOptions; +import com.alibaba.flink.shuffle.core.config.RpcOptions; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** The configuration of shuffle manager client.
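 *
 * <p>A construction sketch from the editor ({@code configuration} is assumed to be an
 * already-parsed {@link Configuration}):
 *
 * <pre>{@code
 * // Bundles the RPC timeout, the registration timeout and the
 * // retrying-registration settings parsed from the configuration.
 * ShuffleManagerClientConfiguration clientConfiguration =
 *         ShuffleManagerClientConfiguration.fromConfiguration(configuration);
 * long rpcTimeoutMillis = clientConfiguration.getRpcTimeout();
 * }</pre>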
*/ +public class ShuffleManagerClientConfiguration { + + private static final Logger LOG = + LoggerFactory.getLogger(ShuffleManagerClientConfiguration.class); + + private final Configuration configuration; + + private final long rpcTimeout; + + private final long maxRegistrationDuration; + + private final RetryingRegistrationConfiguration retryingRegistrationConfiguration; + + public ShuffleManagerClientConfiguration( + Configuration configuration, + long rpcTimeout, + long maxRegistrationDuration, + RetryingRegistrationConfiguration retryingRegistrationConfiguration) { + + this.configuration = configuration; + this.rpcTimeout = rpcTimeout; + this.maxRegistrationDuration = maxRegistrationDuration; + this.retryingRegistrationConfiguration = checkNotNull(retryingRegistrationConfiguration); + } + + public Configuration getConfiguration() { + return configuration; + } + + public long getRpcTimeout() { + return rpcTimeout; + } + + public long getMaxRegistrationDuration() { + return maxRegistrationDuration; + } + + public RetryingRegistrationConfiguration getRetryingRegistrationConfiguration() { + return retryingRegistrationConfiguration; + } + + // -------------------------------------------------------------------------------------------- + // Static factory methods + // -------------------------------------------------------------------------------------------- + + public static ShuffleManagerClientConfiguration fromConfiguration(Configuration configuration) { + long rpcTimeout; + try { + rpcTimeout = configuration.getDuration(RpcOptions.RPC_TIMEOUT).toMillis(); + } catch (Exception e) { + throw new IllegalArgumentException( + "Invalid format for '" + + RpcOptions.RPC_TIMEOUT.key() + + "'.Use formats like '50 s' or '1 min' to specify the timeout."); + } + LOG.debug("Messages have a max timeout of {}.", rpcTimeout); + + long maxRegistrationDuration; + try { + maxRegistrationDuration = + configuration.getDuration(ClusterOptions.REGISTRATION_TIMEOUT).toMillis(); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException( + String.format( + "Invalid format for parameter %s. Set the timeout to be infinite.", + ClusterOptions.REGISTRATION_TIMEOUT.key()), + e); + } + + RetryingRegistrationConfiguration retryingRegistrationConfiguration = + RetryingRegistrationConfiguration.fromConfiguration(configuration); + return new ShuffleManagerClientConfiguration( + configuration, + rpcTimeout, + maxRegistrationDuration, + retryingRegistrationConfiguration); + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/client/ShuffleManagerClientImpl.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/client/ShuffleManagerClientImpl.java new file mode 100644 index 00000000..5327ba95 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/client/ShuffleManagerClientImpl.java @@ -0,0 +1,616 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.client; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.common.handler.FatalErrorHandler; +import com.alibaba.flink.shuffle.common.utils.FutureUtils; +import com.alibaba.flink.shuffle.common.utils.SingleThreadExecutorValidator; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatListener; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatManager; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatServices; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatTarget; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalListener; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalService; +import com.alibaba.flink.shuffle.coordinator.manager.JobDataPartitionDistribution; +import com.alibaba.flink.shuffle.coordinator.manager.ManagerToJobHeartbeatPayload; +import com.alibaba.flink.shuffle.coordinator.manager.RegistrationSuccess; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleManagerJobGateway; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleResource; +import com.alibaba.flink.shuffle.coordinator.registration.ConnectingConnection; +import com.alibaba.flink.shuffle.coordinator.registration.EstablishedConnection; +import com.alibaba.flink.shuffle.coordinator.registration.RegistrationConnectionListener; +import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerMetrics; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcService; +import com.alibaba.flink.shuffle.rpc.RpcTargetAddress; +import com.alibaba.flink.shuffle.rpc.executor.ScheduledExecutorServiceAdapter; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; + +/** + * The client that is responsible for interacting with ShuffleManager to request and release + * resources. 
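 *
 * <p>A rough lifecycle sketch (editor's illustration; constructor arguments elided and
 * variable names assumed):
 *
 * <pre>{@code
 * ShuffleManagerClient client = new ShuffleManagerClientImpl(...);
 * client.start(); // blocks until a connection to the manager leader is established
 *
 * CompletableFuture<ShuffleResource> resource =
 *         client.requestShuffleResource(
 *                 dataSetId, mapPartitionId, numberOfSubpartitions, dataPartitionFactoryName);
 *
 * client.close();
 * }</pre>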
+ */ +public class ShuffleManagerClientImpl implements ShuffleManagerClient, LeaderRetrievalListener { + + private static final Logger LOG = LoggerFactory.getLogger(ShuffleManagerClientImpl.class); + + private final JobID jobID; + + private final InstanceID clientID; + + private final ShuffleWorkerStatusListener shuffleWorkerStatusListener; + + private final RemoteShuffleRpcService rpcService; + + private final FatalErrorHandler fatalErrorHandler; + + private final ShuffleManagerClientConfiguration shuffleManagerClientConfiguration; + + // ------------------------------------------------------------------------ + + private final ScheduledExecutorService mainThreadExecutor; + + private final SingleThreadExecutorValidator mainThreadExecutorValidator; + + private final HeartbeatManager shuffleManagerHeartbeatManager; + + private final LeaderRetrievalService shuffleManagerLeaderRetrieveService; + + @Nullable private RpcTargetAddress shuffleManagerAddress; + + @Nullable + private ConnectingConnection + shuffleManagerConnection; + + @Nullable + private EstablishedConnection + establishedConnection; + + @Nullable private UUID currentRegistrationTimeoutId; + + @Nullable private CompletableFuture connectionFuture; + + /** + * The list of cached shuffle workers. Here we could not directly use the list cached in the + * partition tracker since the list must be consistent with the result of the last heartbeat. + * However, after the last heartbeat, the PartitionTracker might have tracked more shuffle + * workers. If there are restarted shuffle workers, the next heartbeat would ignore these + * shuffle workers. + */ + private final Set relatedShuffleWorkers = new HashSet<>(); + + public ShuffleManagerClientImpl( + JobID jobID, + ShuffleWorkerStatusListener shuffleWorkerStatusListener, + RemoteShuffleRpcService rpcService, + FatalErrorHandler fatalErrorHandler, + ShuffleManagerClientConfiguration shuffleManagerClientConfiguration, + HaServices haServices, + HeartbeatServices heartbeatService, + InstanceID clientID) { + + this.jobID = checkNotNull(jobID); + this.clientID = checkNotNull(clientID); + this.shuffleWorkerStatusListener = checkNotNull(shuffleWorkerStatusListener); + this.rpcService = checkNotNull(rpcService); + this.fatalErrorHandler = checkNotNull(fatalErrorHandler); + this.shuffleManagerClientConfiguration = checkNotNull(shuffleManagerClientConfiguration); + + mainThreadExecutor = + Executors.newSingleThreadScheduledExecutor( + r -> { + Thread thread = new Thread(r); + thread.setName("shuffle-client-" + jobID); + return thread; + }); + mainThreadExecutorValidator = new SingleThreadExecutorValidator(mainThreadExecutor); + + this.shuffleManagerLeaderRetrieveService = + checkNotNull(haServices) + .createLeaderRetrievalService(HaServices.LeaderReceptor.SHUFFLE_CLIENT); + + shuffleManagerHeartbeatManager = + heartbeatService.createHeartbeatManagerSender( + new InstanceID(jobID.getId()), + new ManagerHeartbeatListener(), + new ScheduledExecutorServiceAdapter(mainThreadExecutor), + LOG); + } + + public ShuffleManagerClientImpl( + JobID jobID, + ShuffleWorkerStatusListener shuffleWorkerStatusListener, + RemoteShuffleRpcService rpcService, + FatalErrorHandler fatalErrorHandler, + ShuffleManagerClientConfiguration shuffleManagerClientConfiguration, + HaServices haServices, + HeartbeatServices heartbeatService) { + this( + jobID, + shuffleWorkerStatusListener, + rpcService, + fatalErrorHandler, + shuffleManagerClientConfiguration, + haServices, + heartbeatService, + new InstanceID()); + } + + // 
------------------------------------------------------------------------ + // Life Cycle + // ------------------------------------------------------------------------ + + @Override + public void start() { + try { + connectionFuture = new CompletableFuture<>(); + shuffleManagerLeaderRetrieveService.start(this); + + // Wait till the connection is established + connectionFuture.get(); + } catch (Exception e) { + fatalErrorHandler.onFatalError(e); + } + } + + @Override + public void synchronizeWorkerStatus(Set initialWorkers) throws Exception { + // Force a synchronous heartbeat to the shuffle manager + LOG.info("Synchronize worker status with the manager."); + + CompletableFuture payloadFuture = new CompletableFuture<>(); + mainThreadExecutor.submit( + () -> { + // Here we directly merge the two arrays. The only + // result might be unnecessary unrelated notification, + // which should not cause problems. + relatedShuffleWorkers.addAll(initialWorkers); + return emitHeartbeat() + .whenCompleteAsync( + (payload, future) -> { + updateClientStatusOnHeartbeatSuccess(payload, future); + payloadFuture.complete(payload); + }, + mainThreadExecutor); + }); + + ManagerToJobHeartbeatPayload heartbeatPayload = + payloadFuture.get( + shuffleManagerClientConfiguration.getRpcTimeout(), TimeUnit.MILLISECONDS); + updatePartitionTrackerOnHeartbeatSuccess(heartbeatPayload, null); + } + + @Override + public void close() { + try { + mainThreadExecutor + .submit( + () -> { + if (establishedConnection != null) { + shuffleManagerHeartbeatManager.unmonitorTarget( + establishedConnection.getResponse().getInstanceID()); + + if (establishedConnection != null) { + establishedConnection + .getGateway() + .unregisterClient(jobID, clientID); + } + } + }) + .get(); + } catch (Exception e) { + LOG.error("Failed to close the established connections", e); + } + + mainThreadExecutor.shutdownNow(); + } + + // ------------------------------------------------------------------------ + // Internal shuffle manager connection methods + // ------------------------------------------------------------------------ + + @Override + public void notifyLeaderAddress(LeaderInformation leaderInfo) { + mainThreadExecutor.execute(() -> notifyOfNewShuffleManagerLeader(leaderInfo)); + } + + @Override + public void handleError(Exception exception) { + fatalErrorHandler.onFatalError( + new Exception("Failed to retrieve shuffle manager address", exception)); + } + + private void notifyOfNewShuffleManagerLeader(LeaderInformation leaderInfo) { + shuffleManagerAddress = createShuffleManagerAddress(leaderInfo); + reconnectToShuffleManager( + new ShuffleException( + String.format( + "ShuffleManager leader changed to new address %s", + shuffleManagerAddress))); + } + + @Nullable + private RpcTargetAddress createShuffleManagerAddress(LeaderInformation leaderInfo) { + if (leaderInfo == LeaderInformation.empty()) { + return null; + } + return new RpcTargetAddress(leaderInfo.getLeaderAddress(), leaderInfo.getLeaderSessionID()); + } + + private void reconnectToShuffleManager(Exception cause) { + closeShuffleManagerConnection(cause); + startRegistrationTimeout(); + tryConnectToShuffleManager(); + } + + private void tryConnectToShuffleManager() { + if (shuffleManagerAddress != null) { + connectToShuffleManager(); + } + } + + private void connectToShuffleManager() { + assert (shuffleManagerAddress != null); + assert (establishedConnection == null); + assert (shuffleManagerConnection == null); + + LOG.info("Connecting to ShuffleManager {}.", shuffleManagerAddress); + 
+ shuffleManagerConnection = + new ConnectingConnection<>( + LOG, + "ShuffleManager", + ShuffleManagerJobGateway.class, + rpcService, + shuffleManagerClientConfiguration.getRetryingRegistrationConfiguration(), + shuffleManagerAddress.getTargetAddress(), + shuffleManagerAddress.getLeaderUUID(), + mainThreadExecutor, + new ShuffleManagerRegistrationListener(), + (gateway) -> gateway.registerClient(jobID, clientID)); + + shuffleManagerConnection.start(); + } + + private void establishShuffleManagerConnection( + ShuffleManagerJobGateway shuffleManagerGateway, RegistrationSuccess response) { + + // monitor the shuffle manager as heartbeat target + shuffleManagerHeartbeatManager.monitorTarget( + response.getInstanceID(), + new HeartbeatTarget() { + @Override + public void receiveHeartbeat(InstanceID instanceID, Void heartbeatPayload) { + // Will never call this + } + + @Override + public void requestHeartbeat(InstanceID instanceID, Void heartbeatPayload) { + heartbeatToShuffleManager(); + } + }); + + establishedConnection = new EstablishedConnection<>(shuffleManagerGateway, response); + + stopRegistrationTimeout(); + + checkState(connectionFuture != null); + connectionFuture.complete(null); + } + + private void closeShuffleManagerConnection(Exception cause) { + if (establishedConnection != null) { + final InstanceID shuffleManagerInstanceID = + establishedConnection.getResponse().getInstanceID(); + + if (LOG.isDebugEnabled()) { + LOG.debug("Close ShuffleManager connection {}.", shuffleManagerInstanceID, cause); + } else { + LOG.info("Close ShuffleManager connection {}.", shuffleManagerInstanceID); + } + shuffleManagerHeartbeatManager.unmonitorTarget(shuffleManagerInstanceID); + + ShuffleManagerJobGateway shuffleManagerGateway = establishedConnection.getGateway(); + shuffleManagerGateway.unregisterClient(jobID, clientID); + + establishedConnection = null; + } + + if (shuffleManagerConnection != null) { + if (!shuffleManagerConnection.isConnected()) { + if (LOG.isDebugEnabled()) { + LOG.debug( + "Terminating registration attempts towards ShuffleManager {}.", + shuffleManagerConnection.getTargetAddress(), + cause); + } else { + LOG.info( + "Terminating registration attempts towards ShuffleManager {}.", + shuffleManagerConnection.getTargetAddress()); + } + } + + shuffleManagerConnection.close(); + shuffleManagerConnection = null; + } + } + + private void startRegistrationTimeout() { + final UUID newRegistrationTimeoutId = UUID.randomUUID(); + currentRegistrationTimeoutId = newRegistrationTimeoutId; + + mainThreadExecutor.schedule( + () -> registrationTimeout(newRegistrationTimeoutId), + shuffleManagerClientConfiguration.getMaxRegistrationDuration(), + TimeUnit.MILLISECONDS); + } + + private void stopRegistrationTimeout() { + currentRegistrationTimeoutId = null; + } + + private void registrationTimeout(@Nonnull UUID registrationTimeoutId) { + if (registrationTimeoutId.equals(currentRegistrationTimeoutId)) { + fatalErrorHandler.onFatalError( + new Exception( + String.format( + "Could not register at the ShuffleManager within the specified maximum " + + "registration duration %s. This indicates a problem with this instance. 
Terminating now.", + shuffleManagerClientConfiguration + .getMaxRegistrationDuration()))); + } + } + + public void heartbeatToShuffleManager() { + emitHeartbeat() + .whenCompleteAsync( + (payload, exception) -> { + updateClientStatusOnHeartbeatSuccess(payload, exception); + updatePartitionTrackerOnHeartbeatSuccess(payload, exception); + }, + mainThreadExecutor); + } + + private CompletableFuture emitHeartbeat() { + if (establishedConnection != null) { + return establishedConnection + .getGateway() + .heartbeatFromClient(jobID, clientID, relatedShuffleWorkers); + } + + return FutureUtils.completedExceptionally( + new RuntimeException("Connection not established yet")); + } + + /** + * Updates the shuffle client status on heartbeat success. This method must be called in the + * {@code mainExecutor}. + */ + private void updateClientStatusOnHeartbeatSuccess( + ManagerToJobHeartbeatPayload payload, Throwable exception) { + mainThreadExecutorValidator.assertRunningInTargetThread(); + + if (exception != null) { + LOG.warn("Heartbeat to ShuffleManager failed.", exception); + return; + } + + shuffleManagerHeartbeatManager.receiveHeartbeat(payload.getManagerID(), null); + + payload.getJobChangedWorkerStatus() + .getIrrelevantWorkers() + .forEach(relatedShuffleWorkers::remove); + + payload.getJobChangedWorkerStatus() + .getRelevantWorkers() + .forEach((workerId, ignored) -> relatedShuffleWorkers.add(workerId)); + } + + /** Updates the partition tracker status on heartbeat success. */ + private void updatePartitionTrackerOnHeartbeatSuccess( + ManagerToJobHeartbeatPayload payload, Throwable exception) { + if (exception != null) { + LOG.warn("Heartbeat to ShuffleManager failed.", exception); + return; + } + + for (InstanceID unavailable : payload.getJobChangedWorkerStatus().getIrrelevantWorkers()) { + LOG.info("Got unrelated shuffle worker: " + unavailable); + + shuffleWorkerStatusListener.notifyIrrelevantWorker(unavailable); + } + + payload.getJobChangedWorkerStatus() + .getRelevantWorkers() + .forEach( + (workerId, partitions) -> { + LOG.info("Got newly related shuffle worker: " + workerId); + + shuffleWorkerStatusListener.notifyRelevantWorker(workerId, partitions); + }); + } + + // ------------------------------------------------------------------------ + // Manage the shuffle partitions + // ------------------------------------------------------------------------ + + @Override + public CompletableFuture requestShuffleResource( + DataSetID dataSetId, + MapPartitionID mapPartitionId, + int numberOfSubpartitions, + String dataPartitionFactoryName) { + return sendRpcCall( + (shuffleManagerJobGateway) -> + shuffleManagerJobGateway.requestShuffleResource( + jobID, + clientID, + dataSetId, + mapPartitionId, + numberOfSubpartitions, + dataPartitionFactoryName)); + } + + @Override + public void releaseShuffleResource(DataSetID dataSetId, MapPartitionID mapPartitionId) { + sendRpcCall( + shuffleManagerJobGateway -> + shuffleManagerJobGateway.releaseShuffleResource( + jobID, clientID, dataSetId, mapPartitionId)); + } + + @Override + public CompletableFuture getNumberOfRegisteredWorkers() { + return sendRpcCall(ShuffleManagerJobGateway::getNumberOfRegisteredWorkers); + } + + @Override + public CompletableFuture> getShuffleWorkerMetrics() { + return sendRpcCall(ShuffleManagerJobGateway::getShuffleWorkerMetrics); + } + + @Override + public CompletableFuture> listJobs(boolean includeMyself) { + return sendRpcCall( + shuffleManagerJobGateway -> + shuffleManagerJobGateway + .listJobs() + .thenApply( + jobIds -> { + 
if (includeMyself) { + return jobIds; + } + + return jobIds.stream() + .filter(id -> !id.equals(jobID)) + .collect(Collectors.toList()); + })); + } + + @Override + public CompletableFuture getJobDataPartitionDistribution( + JobID jobID) { + return sendRpcCall( + shuffleManagerJobGateway -> + shuffleManagerJobGateway.getJobDataPartitionDistribution(jobID)); + } + + private CompletableFuture sendRpcCall( + Function> rpcCallFunction) { + + return CompletableFuture.supplyAsync( + () -> { + if (establishedConnection == null) { + Exception e = new Exception("No connection to the shuffle manager"); + LOG.warn( + "No connection to the shuffle manager, abort the request", + e); + throw new CompletionException(e); + } + + return establishedConnection.getGateway(); + }, + mainThreadExecutor) + .thenComposeAsync(rpcCallFunction, mainThreadExecutor); + } + + // ------------------------------------------------------------------------ + // Static utility classes + // ------------------------------------------------------------------------ + + private final class ShuffleManagerRegistrationListener + implements RegistrationConnectionListener< + ConnectingConnection, + RegistrationSuccess> { + + @Override + public void onRegistrationSuccess( + ConnectingConnection connection, + RegistrationSuccess success) { + final ShuffleManagerJobGateway shuffleManagerGateway = connection.getTargetGateway(); + + mainThreadExecutor.execute( + () -> { + // filter out outdated connections + //noinspection ObjectEquality + if (shuffleManagerConnection == connection) { + try { + establishShuffleManagerConnection(shuffleManagerGateway, success); + } catch (Throwable t) { + LOG.error( + "Establishing Shuffle Manager connection in client failed", + t); + } + } + }); + } + + @Override + public void onRegistrationFailure(Throwable failure) { + fatalErrorHandler.onFatalError(failure); + } + } + + private class ManagerHeartbeatListener implements HeartbeatListener { + + @Override + public void notifyHeartbeatTimeout(InstanceID instanceID) { + LOG.info("Timeout with remote shuffle manager {}", instanceID); + if (establishedConnection != null + && establishedConnection.getResponse().getInstanceID().equals(instanceID)) { + reconnectToShuffleManager(new Exception("Heartbeat timeout")); + } + } + + @Override + public void reportPayload(InstanceID instanceID, Void payload) {} + + @Override + public Void retrievePayload(InstanceID instanceID) { + return null; + } + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/client/ShuffleWorkerStatusListener.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/client/ShuffleWorkerStatusListener.java new file mode 100644 index 00000000..3be31f69 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/client/ShuffleWorkerStatusListener.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.client;
+
+import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionCoordinate;
+import com.alibaba.flink.shuffle.core.ids.InstanceID;
+
+import java.util.Set;
+
+/** Listener for status changes of the shuffle workers related to this job. */
+public interface ShuffleWorkerStatusListener {
+
+    /**
+     * Notifies that a shuffle worker has become irrelevant to this job, e.g. because the worker
+     * became unavailable or because all of its data partitions have been released.
+     *
+     * @param workerID the resource id of the shuffle worker.
+     */
+    void notifyIrrelevantWorker(InstanceID workerID);
+
+    /**
+     * Notifies that a shuffle worker has become relevant to this job again.
+     *
+     * @param workerID the resource id of the shuffle worker.
+     * @param dataPartitions the recovered data partitions on the shuffle worker.
+     */
+    void notifyRelevantWorker(InstanceID workerID, Set<DataPartitionCoordinate> dataPartitions);
+}
diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatListener.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatListener.java
new file mode 100644
index 00000000..10557af1
--- /dev/null
+++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatListener.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.heartbeat;
+
+import com.alibaba.flink.shuffle.core.ids.InstanceID;
+
+/**
+ * Interface for the interaction with the {@link HeartbeatManager}. The heartbeat listener is used
+ * for the following things:
+ *
+ * <ul>
+ *   <li>Notifications about heartbeat timeouts
+ *   <li>Payload reports of incoming heartbeats
+ *   <li>Retrieval of payloads for outgoing heartbeats
+ * </ul>
+ * + * @param Type of the incoming payload + * @param Type of the outgoing payload + */ +public interface HeartbeatListener { + + /** + * Callback which is called if a heartbeat for the machine identified by the given resource ID + * times out. + * + * @param instanceID InstanceID ID of the machine whose heartbeat has timed out + */ + void notifyHeartbeatTimeout(InstanceID instanceID); + + /** + * Callback which is called whenever a heartbeat with an associated payload is received. The + * carried payload is given to this method. + * + * @param instanceID InstanceID ID identifying the sender of the payload + * @param payload Payload of the received heartbeat + */ + void reportPayload(InstanceID instanceID, I payload); + + /** + * Retrieves the payload value for the next heartbeat message. + * + * @param instanceID InstanceID ID identifying the receiver of the payload + * @return The payload for the next heartbeat + */ + O retrievePayload(InstanceID instanceID); +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatManager.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatManager.java new file mode 100644 index 00000000..d498a1bb --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatManager.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.heartbeat; + +import com.alibaba.flink.shuffle.core.ids.InstanceID; + +/** + * A heartbeat manager has to be able to start/stop monitoring a {@link HeartbeatTarget}, and report + * heartbeat timeouts for this target. + * + * @param Type of the incoming payload + * @param Type of the outgoing payload + */ +public interface HeartbeatManager extends HeartbeatTarget { + + /** + * Start monitoring a {@link HeartbeatTarget}. Heartbeat timeouts for this target are reported + * to the {@link HeartbeatListener} associated with this heartbeat manager. + * + * @param instanceID Resource ID identifying the heartbeat target + * @param heartbeatTarget Interface to send heartbeat requests and responses to the heartbeat + * target + */ + void monitorTarget(InstanceID instanceID, HeartbeatTarget heartbeatTarget); + + /** + * Stops monitoring the heartbeat target with the associated resource ID. + * + * @param instanceID Resource ID of the heartbeat target which shall no longer be monitored + */ + void unmonitorTarget(InstanceID instanceID); + + /** Stops the heartbeat manager. */ + void stop(); + + /** + * Returns the last received heartbeat from the given target. 
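Since the HeartbeatListener contract above is abstract, a tiny implementation may help. The payload types String (incoming) and Long (outgoing) are chosen purely for illustration and are not types used by the coordinator:

        // Illustrative only: a listener with String incoming and Long outgoing payloads.
        HeartbeatListener<String, Long> listener =
                new HeartbeatListener<String, Long>() {

                    @Override
                    public void notifyHeartbeatTimeout(InstanceID instanceID) {
                        // React to a lost target, e.g. trigger a reconnect.
                        System.err.println("Heartbeat of " + instanceID + " timed out.");
                    }

                    @Override
                    public void reportPayload(InstanceID instanceID, String payload) {
                        System.out.println("Payload from " + instanceID + ": " + payload);
                    }

                    @Override
                    public Long retrievePayload(InstanceID instanceID) {
                        // Attached to the next outgoing heartbeat, e.g. a local timestamp.
                        return System.currentTimeMillis();
                    }
                };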
+ * + * @param instanceID for which to return the last heartbeat + * @return Last heartbeat received from the given target or -1 if the target is not being + * monitored. + */ + long getLastHeartbeatFrom(InstanceID instanceID); +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatManagerImpl.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatManagerImpl.java new file mode 100644 index 00000000..c4623da9 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatManagerImpl.java @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.heartbeat; + +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.rpc.executor.ScheduledExecutor; + +import org.slf4j.Logger; + +import javax.annotation.concurrent.ThreadSafe; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** + * Heartbeat manager implementation. The heartbeat manager maintains a map of heartbeat monitors and + * resource IDs. Each monitor will be updated when a new heartbeat of the associated machine has + * been received. If the monitor detects that a heartbeat has timed out, it will notify the {@link + * HeartbeatListener} about it. A heartbeat times out iff no heartbeat signal has been received + * within a given timeout interval. + * + * @param Type of the incoming heartbeat payload + * @param Type of the outgoing heartbeat payload + */ +@ThreadSafe +public class HeartbeatManagerImpl implements HeartbeatManager { + + /** Heartbeat timeout interval in milli seconds. */ + private final long heartbeatTimeoutIntervalMs; + + /** Resource ID which is used to mark one own's heartbeat signals. */ + private final InstanceID ownInstanceID; + + /** Heartbeat listener with which the heartbeat manager has been associated. */ + private final HeartbeatListener heartbeatListener; + + /** Executor service used to run heartbeat timeout notifications. */ + private final ScheduledExecutor mainThreadExecutor; + + protected final Logger log; + + /** Map containing the heartbeat monitors associated with the respective resource ID. */ + private final ConcurrentHashMap> heartbeatTargets; + + private final HeartbeatMonitor.Factory heartbeatMonitorFactory; + + /** Running state of the heartbeat manager. 
*/ + protected volatile boolean stopped; + + public HeartbeatManagerImpl( + long heartbeatTimeoutIntervalMs, + InstanceID ownInstanceID, + HeartbeatListener heartbeatListener, + ScheduledExecutor mainThreadExecutor, + Logger log) { + this( + heartbeatTimeoutIntervalMs, + ownInstanceID, + heartbeatListener, + mainThreadExecutor, + log, + new HeartbeatMonitorImpl.Factory<>()); + } + + public HeartbeatManagerImpl( + long heartbeatTimeoutIntervalMs, + InstanceID ownInstanceID, + HeartbeatListener heartbeatListener, + ScheduledExecutor mainThreadExecutor, + Logger log, + HeartbeatMonitor.Factory heartbeatMonitorFactory) { + + checkArgument( + heartbeatTimeoutIntervalMs > 0L, "The heartbeat timeout has to be larger than 0."); + + this.heartbeatTimeoutIntervalMs = heartbeatTimeoutIntervalMs; + this.ownInstanceID = checkNotNull(ownInstanceID); + this.heartbeatListener = checkNotNull(heartbeatListener); + this.mainThreadExecutor = checkNotNull(mainThreadExecutor); + this.log = checkNotNull(log); + this.heartbeatMonitorFactory = heartbeatMonitorFactory; + this.heartbeatTargets = new ConcurrentHashMap<>(16); + + stopped = false; + } + + // ---------------------------------------------------------------------------------------------- + // Getters + // ---------------------------------------------------------------------------------------------- + + InstanceID getOwnInstanceID() { + return ownInstanceID; + } + + HeartbeatListener getHeartbeatListener() { + return heartbeatListener; + } + + public Map> getHeartbeatTargets() { + return heartbeatTargets; + } + + // ---------------------------------------------------------------------------------------------- + // HeartbeatManager methods + // ---------------------------------------------------------------------------------------------- + + @Override + public void monitorTarget(InstanceID instanceID, HeartbeatTarget heartbeatTarget) { + if (!stopped) { + if (heartbeatTargets.containsKey(instanceID)) { + log.debug("The target with instance ID {} is already been monitored.", instanceID); + } else { + HeartbeatMonitor heartbeatMonitor = + heartbeatMonitorFactory.createHeartbeatMonitor( + instanceID, + heartbeatTarget, + mainThreadExecutor, + heartbeatListener, + heartbeatTimeoutIntervalMs); + + heartbeatTargets.put(instanceID, heartbeatMonitor); + + // check if we have stopped in the meantime (concurrent stop operation) + if (stopped) { + heartbeatMonitor.cancel(); + + heartbeatTargets.remove(instanceID); + } + } + } + } + + @Override + public void unmonitorTarget(InstanceID instanceID) { + if (!stopped) { + HeartbeatMonitor heartbeatMonitor = heartbeatTargets.remove(instanceID); + + if (heartbeatMonitor != null) { + heartbeatMonitor.cancel(); + } + } + } + + @Override + public void stop() { + stopped = true; + + for (HeartbeatMonitor heartbeatMonitor : heartbeatTargets.values()) { + heartbeatMonitor.cancel(); + } + + heartbeatTargets.clear(); + } + + @Override + public long getLastHeartbeatFrom(InstanceID instanceID) { + HeartbeatMonitor heartbeatMonitor = heartbeatTargets.get(instanceID); + + if (heartbeatMonitor != null) { + return heartbeatMonitor.getLastHeartbeat(); + } else { + return -1L; + } + } + + ScheduledExecutor getMainThreadExecutor() { + return mainThreadExecutor; + } + + // ---------------------------------------------------------------------------------------------- + // HeartbeatTarget methods + // ---------------------------------------------------------------------------------------------- + + @Override + public void 
receiveHeartbeat(InstanceID heartbeatOrigin, I heartbeatPayload) { + if (!stopped) { + log.debug("Received heartbeat from {}.", heartbeatOrigin); + reportHeartbeat(heartbeatOrigin); + + if (heartbeatPayload != null) { + heartbeatListener.reportPayload(heartbeatOrigin, heartbeatPayload); + } + } + } + + @Override + public void requestHeartbeat(final InstanceID requestOrigin, I heartbeatPayload) { + if (!stopped) { + log.debug("Received heartbeat request from {}.", requestOrigin); + + final HeartbeatTarget heartbeatTarget = reportHeartbeat(requestOrigin); + + if (heartbeatTarget != null) { + if (heartbeatPayload != null) { + heartbeatListener.reportPayload(requestOrigin, heartbeatPayload); + } + + heartbeatTarget.receiveHeartbeat( + getOwnInstanceID(), heartbeatListener.retrievePayload(requestOrigin)); + } + } + } + + HeartbeatTarget reportHeartbeat(InstanceID instanceID) { + if (heartbeatTargets.containsKey(instanceID)) { + HeartbeatMonitor heartbeatMonitor = heartbeatTargets.get(instanceID); + heartbeatMonitor.reportHeartbeat(); + + return heartbeatMonitor.getHeartbeatTarget(); + } else { + return null; + } + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatManagerSenderImpl.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatManagerSenderImpl.java new file mode 100644 index 00000000..3a0895fc --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatManagerSenderImpl.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.heartbeat; + +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.rpc.executor.ScheduledExecutor; + +import org.slf4j.Logger; + +import java.util.concurrent.TimeUnit; + +/** + * {@link HeartbeatManager} implementation which regularly requests a heartbeat response from its + * monitored {@link HeartbeatTarget}. The heartbeat period is configurable. 
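The sender implementation that follows drives its periodic requests by rescheduling itself instead of using a fixed-rate schedule, so a stopped manager simply never arms the next tick. The same idea in standalone form, with illustrative names:

import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

public final class SelfReschedulingLoop implements Runnable {

    private final ScheduledExecutorService executor;
    private final long periodMs;
    private volatile boolean stopped;

    SelfReschedulingLoop(ScheduledExecutorService executor, long periodMs) {
        this.executor = executor;
        this.periodMs = periodMs;
        executor.schedule(this, 0L, TimeUnit.MILLISECONDS); // first tick immediately
    }

    @Override
    public void run() {
        if (!stopped) {
            System.out.println("trigger heartbeat requests");
            // Re-arm only while running; stop() ends the loop naturally.
            executor.schedule(this, periodMs, TimeUnit.MILLISECONDS);
        }
    }

    void stop() {
        stopped = true;
    }

    public static void main(String[] args) throws InterruptedException {
        ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
        SelfReschedulingLoop loop = new SelfReschedulingLoop(executor, 1_000L);
        Thread.sleep(3_500L);
        loop.stop();
        executor.shutdown();
    }
}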
+ * + * @param Type of the incoming heartbeat payload + * @param Type of the outgoing heartbeat payload + */ +public class HeartbeatManagerSenderImpl extends HeartbeatManagerImpl + implements Runnable { + + private final long heartbeatPeriod; + + public HeartbeatManagerSenderImpl( + long heartbeatPeriod, + long heartbeatTimeout, + InstanceID ownInstanceID, + HeartbeatListener heartbeatListener, + ScheduledExecutor mainThreadExecutor, + Logger log) { + this( + heartbeatPeriod, + heartbeatTimeout, + ownInstanceID, + heartbeatListener, + mainThreadExecutor, + log, + new HeartbeatMonitorImpl.Factory<>()); + } + + HeartbeatManagerSenderImpl( + long heartbeatPeriod, + long heartbeatTimeout, + InstanceID ownInstanceID, + HeartbeatListener heartbeatListener, + ScheduledExecutor mainThreadExecutor, + Logger log, + HeartbeatMonitor.Factory heartbeatMonitorFactory) { + super( + heartbeatTimeout, + ownInstanceID, + heartbeatListener, + mainThreadExecutor, + log, + heartbeatMonitorFactory); + + this.heartbeatPeriod = heartbeatPeriod; + mainThreadExecutor.schedule(this, 0L, TimeUnit.MILLISECONDS); + } + + @Override + public void run() { + if (!stopped) { + log.debug("Trigger heartbeat request."); + for (HeartbeatMonitor heartbeatMonitor : getHeartbeatTargets().values()) { + requestHeartbeat(heartbeatMonitor); + } + + getMainThreadExecutor().schedule(this, heartbeatPeriod, TimeUnit.MILLISECONDS); + } + } + + private void requestHeartbeat(HeartbeatMonitor heartbeatMonitor) { + O payload = getHeartbeatListener().retrievePayload(heartbeatMonitor.getHeartbeatTargetId()); + final HeartbeatTarget heartbeatTarget = heartbeatMonitor.getHeartbeatTarget(); + + heartbeatTarget.requestHeartbeat(getOwnInstanceID(), payload); + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatMonitor.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatMonitor.java new file mode 100644 index 00000000..413ee04b --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatMonitor.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.heartbeat; + +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.rpc.executor.ScheduledExecutor; + +/** + * Heartbeat monitor which manages the heartbeat state of the associated heartbeat target. The + * monitor notifies the {@link HeartbeatListener} whenever it has not seen a heartbeat signal in the + * specified heartbeat timeout interval. Each heartbeat signal resets this timer. 
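The reset-on-signal timer described here (implemented by HeartbeatMonitorImpl further below) reduces to a short standalone sketch: each reported heartbeat cancels the pending timeout task and arms a fresh one. Names are illustrative:

import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;

final class ResettableTimeout {

    private final ScheduledExecutorService executor;
    private final long timeoutMs;
    private final Runnable onTimeout;

    private volatile ScheduledFuture<?> pending;

    ResettableTimeout(ScheduledExecutorService executor, long timeoutMs, Runnable onTimeout) {
        this.executor = executor;
        this.timeoutMs = timeoutMs;
        this.onTimeout = onTimeout;
        reset();
    }

    // Called on every heartbeat signal: drop the old deadline, arm a new one.
    void reset() {
        ScheduledFuture<?> previous = pending;
        if (previous != null) {
            previous.cancel(true);
        }
        pending = executor.schedule(onTimeout, timeoutMs, TimeUnit.MILLISECONDS);
    }
}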
+ * + * @param Type of the payload being sent to the associated heartbeat target + */ +public interface HeartbeatMonitor { + + /** + * Gets heartbeat target. + * + * @return the heartbeat target + */ + HeartbeatTarget getHeartbeatTarget(); + + /** + * Gets heartbeat target id. + * + * @return the heartbeat target id + */ + InstanceID getHeartbeatTargetId(); + + /** Report heartbeat from the monitored target. */ + void reportHeartbeat(); + + /** Cancel this monitor. */ + void cancel(); + + /** + * Gets the last heartbeat. + * + * @return the last heartbeat + */ + long getLastHeartbeat(); + + /** + * This factory provides an indirection way to create {@link HeartbeatMonitor}. + * + * @param Type of the outgoing heartbeat payload + */ + interface Factory { + /** + * Create heartbeat monitor. + * + * @param instanceID the resource id + * @param heartbeatTarget the heartbeat target + * @param mainThreadExecutor the main thread executor + * @param heartbeatListener the heartbeat listener + * @param heartbeatTimeoutIntervalMs the heartbeat timeout interval ms + * @return the heartbeat monitor + */ + HeartbeatMonitor createHeartbeatMonitor( + InstanceID instanceID, + HeartbeatTarget heartbeatTarget, + ScheduledExecutor mainThreadExecutor, + HeartbeatListener heartbeatListener, + long heartbeatTimeoutIntervalMs); + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatMonitorImpl.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatMonitorImpl.java new file mode 100644 index 00000000..ae2840a8 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatMonitorImpl.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.heartbeat; + +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.rpc.executor.ScheduledExecutor; + +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** + * The default implementation of {@link HeartbeatMonitor}. + * + * @param Type of the payload being sent to the associated heartbeat target + */ +public class HeartbeatMonitorImpl implements HeartbeatMonitor, Runnable { + + /** Resource ID of the monitored heartbeat target. */ + private final InstanceID instanceID; + + /** Associated heartbeat target. 
*/ + private final HeartbeatTarget heartbeatTarget; + + private final ScheduledExecutor scheduledExecutor; + + /** Listener which is notified about heartbeat timeouts. */ + private final HeartbeatListener heartbeatListener; + + /** Maximum heartbeat timeout interval. */ + private final long heartbeatTimeoutIntervalMs; + + private volatile ScheduledFuture futureTimeout; + + private final AtomicReference state = new AtomicReference<>(State.RUNNING); + + private volatile long lastHeartbeat; + + HeartbeatMonitorImpl( + InstanceID instanceID, + HeartbeatTarget heartbeatTarget, + ScheduledExecutor scheduledExecutor, + HeartbeatListener heartbeatListener, + long heartbeatTimeoutIntervalMs) { + + this.instanceID = checkNotNull(instanceID); + this.heartbeatTarget = checkNotNull(heartbeatTarget); + this.scheduledExecutor = checkNotNull(scheduledExecutor); + this.heartbeatListener = checkNotNull(heartbeatListener); + + checkArgument( + heartbeatTimeoutIntervalMs > 0L, + "The heartbeat timeout interval has to be larger than 0."); + this.heartbeatTimeoutIntervalMs = heartbeatTimeoutIntervalMs; + + lastHeartbeat = 0L; + + resetHeartbeatTimeout(heartbeatTimeoutIntervalMs); + } + + @Override + public HeartbeatTarget getHeartbeatTarget() { + return heartbeatTarget; + } + + @Override + public InstanceID getHeartbeatTargetId() { + return instanceID; + } + + @Override + public long getLastHeartbeat() { + return lastHeartbeat; + } + + @Override + public void reportHeartbeat() { + lastHeartbeat = System.currentTimeMillis(); + resetHeartbeatTimeout(heartbeatTimeoutIntervalMs); + } + + @Override + public void cancel() { + // we can only cancel if we are in state running + if (state.compareAndSet(State.RUNNING, State.CANCELED)) { + cancelTimeout(); + } + } + + @Override + public void run() { + // The heartbeat has timed out if we're in state running + if (state.compareAndSet(State.RUNNING, State.TIMEOUT)) { + heartbeatListener.notifyHeartbeatTimeout(instanceID); + } + } + + public boolean isCanceled() { + return state.get() == State.CANCELED; + } + + void resetHeartbeatTimeout(long heartbeatTimeout) { + if (state.get() == State.RUNNING) { + cancelTimeout(); + + futureTimeout = + scheduledExecutor.schedule(this, heartbeatTimeout, TimeUnit.MILLISECONDS); + + // Double check for concurrent accesses (e.g. a firing of the scheduled future) + if (state.get() != State.RUNNING) { + cancelTimeout(); + } + } + } + + private void cancelTimeout() { + if (futureTimeout != null) { + futureTimeout.cancel(true); + } + } + + private enum State { + RUNNING, + TIMEOUT, + CANCELED + } + + /** + * The factory that instantiates {@link HeartbeatMonitorImpl}. 
+ * + * @param Type of the outgoing heartbeat payload + */ + static class Factory implements HeartbeatMonitor.Factory { + + @Override + public HeartbeatMonitor createHeartbeatMonitor( + InstanceID instanceID, + HeartbeatTarget heartbeatTarget, + ScheduledExecutor mainThreadExecutor, + HeartbeatListener heartbeatListener, + long heartbeatTimeoutIntervalMs) { + + return new HeartbeatMonitorImpl<>( + instanceID, + heartbeatTarget, + mainThreadExecutor, + heartbeatListener, + heartbeatTimeoutIntervalMs); + } + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatServices.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatServices.java new file mode 100644 index 00000000..82043386 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatServices.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.heartbeat; + +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.rpc.executor.ScheduledExecutor; + +import org.slf4j.Logger; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument; + +/** + * HeartbeatServices gives access to all services needed for heartbeat. This includes the creation + * of heartbeat receivers and heartbeat senders. + */ +public class HeartbeatServices { + + /** Heartbeat interval for the created services. */ + protected final long heartbeatInterval; + + /** Heartbeat timeout for the created services. */ + protected final long heartbeatTimeout; + + public HeartbeatServices(long heartbeatInterval, long heartbeatTimeout) { + checkArgument(0L < heartbeatInterval, "The heartbeat interval must be larger than 0."); + checkArgument( + heartbeatInterval <= heartbeatTimeout, + "The heartbeat timeout should be larger or equal than the heartbeat interval."); + + this.heartbeatInterval = heartbeatInterval; + this.heartbeatTimeout = heartbeatTimeout; + } + + /** + * Creates a heartbeat manager which does not actively send heartbeats. 
+ * + * @param instanceID Resource Id which identifies the owner of the heartbeat manager + * @param heartbeatListener Listener which will be notified upon heartbeat timeouts for + * registered targets + * @param mainThreadExecutor Scheduled executor to be used for scheduling heartbeat timeouts + * @param log Logger to be used for the logging + * @param Type of the incoming payload + * @param Type of the outgoing payload + * @return A new HeartbeatManager instance + */ + public HeartbeatManager createHeartbeatManager( + InstanceID instanceID, + HeartbeatListener heartbeatListener, + ScheduledExecutor mainThreadExecutor, + Logger log) { + + return new HeartbeatManagerImpl<>( + heartbeatTimeout, instanceID, heartbeatListener, mainThreadExecutor, log); + } + + /** + * Creates a heartbeat manager which actively sends heartbeats to monitoring targets. + * + * @param instanceID Resource Id which identifies the owner of the heartbeat manager + * @param heartbeatListener Listener which will be notified upon heartbeat timeouts for + * registered targets + * @param mainThreadExecutor Scheduled executor to be used for scheduling heartbeat timeouts and + * periodically send heartbeat requests + * @param log Logger to be used for the logging + * @param Type of the incoming payload + * @param Type of the outgoing payload + * @return A new HeartbeatManager instance which actively sends heartbeats + */ + public HeartbeatManager createHeartbeatManagerSender( + InstanceID instanceID, + HeartbeatListener heartbeatListener, + ScheduledExecutor mainThreadExecutor, + Logger log) { + + return new HeartbeatManagerSenderImpl<>( + heartbeatInterval, + heartbeatTimeout, + instanceID, + heartbeatListener, + mainThreadExecutor, + log); + } + + public long getHeartbeatInterval() { + return heartbeatInterval; + } + + public long getHeartbeatTimeout() { + return heartbeatTimeout; + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatServicesUtils.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatServicesUtils.java new file mode 100644 index 00000000..e5607962 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatServicesUtils.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.heartbeat; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.core.config.HeartbeatOptions; + +/** Utils for creating HeartServices. 
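Putting the pieces above together, a manager/worker pair would be wired roughly as sketched below. The instance IDs, listeners, logger, and the ScheduledExecutor adapter are assumed to be created elsewhere and are not part of this diff:

        // Illustrative wiring; managerId, workerId, executor and log are assumed.
        //   managerListener : HeartbeatListener<String, Void>  (receives worker payloads)
        //   workerListener  : HeartbeatListener<Void, String>  (supplies worker payloads)
        HeartbeatServices services = new HeartbeatServices(1_000L, 5_000L);

        // The manager side actively requests heartbeats from everything it monitors.
        HeartbeatManager<String, Void> managerSide =
                services.createHeartbeatManagerSender(managerId, managerListener, executor, log);

        // The worker side is passive: it answers each request, attaching the payload
        // returned by workerListener.retrievePayload(...).
        HeartbeatManager<Void, String> workerSide =
                services.createHeartbeatManager(workerId, workerListener, executor, log);

        // Each side monitors the other; a missed deadline surfaces as
        // notifyHeartbeatTimeout on the corresponding listener.
        managerSide.monitorTarget(workerId, workerSide);
        workerSide.monitorTarget(managerId, managerSide);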
*/ +public class HeartbeatServicesUtils { + + public static HeartbeatServices createManagerJobHeartbeatServices(Configuration configuration) { + long heartbeatTimeout = + configuration.getDuration(HeartbeatOptions.HEARTBEAT_JOB_TIMEOUT).toMillis(); + + long heartbeatInterval = + configuration.getDuration(HeartbeatOptions.HEARTBEAT_JOB_INTERVAL).toMillis(); + + return new HeartbeatServices(heartbeatInterval, heartbeatTimeout); + } + + public static HeartbeatServices createManagerWorkerHeartbeatServices( + Configuration configuration) { + long heartbeatInterval = + configuration.getDuration(HeartbeatOptions.HEARTBEAT_WORKER_INTERVAL).toMillis(); + + long heartbeatTimeout = + configuration.getDuration(HeartbeatOptions.HEARTBEAT_WORKER_TIMEOUT).toMillis(); + + return new HeartbeatServices(heartbeatInterval, heartbeatTimeout); + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatTarget.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatTarget.java new file mode 100644 index 00000000..370fe279 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatTarget.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.heartbeat; + +import com.alibaba.flink.shuffle.core.ids.InstanceID; + +/** + * Interface for components which can be sent heartbeats and from which one can request a heartbeat + * response. Both the heartbeat response and the heartbeat request can carry a payload. This payload + * is reported to the heartbeat target and contains additional information. The payload can be empty + * which is indicated by a null value. + * + * @param Type of the payload which is sent to the heartbeat target + */ +public interface HeartbeatTarget { + + /** + * Sends a heartbeat response to the target. Each heartbeat response can carry a payload which + * contains additional information for the heartbeat target. + * + * @param heartbeatOrigin Resource ID identifying the machine for which a heartbeat shall be + * reported. + * @param heartbeatPayload Payload of the heartbeat. Null indicates an empty payload. + */ + void receiveHeartbeat(InstanceID heartbeatOrigin, I heartbeatPayload); + + /** + * Requests a heartbeat from the target. Each heartbeat request can carry a payload which + * contains additional information for the heartbeat target. + * + * @param requestOrigin Resource ID identifying the machine issuing the heartbeat request. + * @param heartbeatPayload Payload of the heartbeat request. Null indicates an empty payload. 
+ */ + void requestHeartbeat(InstanceID requestOrigin, I heartbeatPayload); +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/NoOpHeartbeatManager.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/NoOpHeartbeatManager.java new file mode 100644 index 00000000..5c833009 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/heartbeat/NoOpHeartbeatManager.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.heartbeat; + +import com.alibaba.flink.shuffle.core.ids.InstanceID; + +/** + * A {@link HeartbeatManager} implementation which does nothing. + * + * @param ignored + * @param ignored + */ +public class NoOpHeartbeatManager implements HeartbeatManager { + private static final NoOpHeartbeatManager INSTANCE = + new NoOpHeartbeatManager<>(); + + private NoOpHeartbeatManager() {} + + @Override + public void monitorTarget(InstanceID instanceID, HeartbeatTarget heartbeatTarget) {} + + @Override + public void unmonitorTarget(InstanceID instanceID) {} + + @Override + public void stop() {} + + @Override + public long getLastHeartbeatFrom(InstanceID instanceID) { + return 0; + } + + @Override + public void receiveHeartbeat(InstanceID heartbeatOrigin, I heartbeatPayload) {} + + @Override + public void requestHeartbeat(InstanceID requestOrigin, I heartbeatPayload) {} + + @SuppressWarnings("unchecked") + public static NoOpHeartbeatManager getInstance() { + return (NoOpHeartbeatManager) INSTANCE; + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/DefaultLeaderElectionService.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/DefaultLeaderElectionService.java new file mode 100644 index 00000000..3f9ab6c3 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/DefaultLeaderElectionService.java @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability; + +import com.alibaba.flink.shuffle.common.handler.FatalErrorHandler; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.concurrent.GuardedBy; + +import java.util.UUID; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; + +/** + * Default implementation for leader election service. Composed with different {@link + * LeaderElectionDriver}, we could perform a leader election for the contender, and then persist the + * leader information to various storage. + * + *
This class is copied from Apache Flink + * (org.apache.flink.runtime.leaderelection.DefaultLeaderElectionService). + */ +public class DefaultLeaderElectionService + implements LeaderElectionService, LeaderElectionEventHandler { + + private static final Logger LOG = LoggerFactory.getLogger(DefaultLeaderElectionService.class); + + private final Object lock = new Object(); + + private final LeaderElectionDriverFactory leaderElectionDriverFactory; + + /** The leader contender which applies for leadership. */ + @GuardedBy("lock") + private LeaderContender leaderContender; + + @GuardedBy("lock") + private UUID issuedLeaderSessionID; + + @GuardedBy("lock") + private LeaderInformation confirmedLeaderInfo; + + @GuardedBy("lock") + private boolean running; + + @GuardedBy("lock") + private LeaderElectionDriver leaderElectionDriver; + + public DefaultLeaderElectionService(LeaderElectionDriverFactory leaderElectionDriverFactory) { + checkArgument(leaderElectionDriverFactory != null, "Must be not null."); + this.leaderElectionDriverFactory = leaderElectionDriverFactory; + } + + @Override + public final void start(LeaderContender contender) throws Exception { + checkArgument(contender != null, "Contender must not be null."); + + synchronized (lock) { + checkState(leaderContender == null, "Contender was already set."); + leaderContender = contender; + leaderElectionDriver = + leaderElectionDriverFactory.createLeaderElectionDriver( + this, + new LeaderElectionFatalErrorHandler(), + leaderContender.getDescription()); + running = true; + LOG.info("Starting DefaultLeaderElectionService with {}.", leaderElectionDriver); + } + } + + @Override + public final void stop() throws Exception { + LOG.info("Stopping DefaultLeaderElectionService."); + + synchronized (lock) { + if (!running) { + return; + } + running = false; + + clearConfirmedLeaderInformation(); + leaderElectionDriver.close(); + } + } + + @Override + public void confirmLeadership(LeaderInformation leaderInfo) { + LOG.info("Confirm leader {}.", leaderInfo); + UUID leaderSessionID = checkNotNull(leaderInfo.getLeaderSessionID()); + + synchronized (lock) { + if (hasLeadership(leaderSessionID)) { + if (running) { + confirmLeaderInformation(leaderInfo); + return; + } + LOG.debug( + "Ignoring the leader session Id {} confirmation, since the " + + "LeaderElectionService has already been stopped.", + leaderSessionID); + return; + } + + // Received an old confirmation call + if (!leaderSessionID.equals(issuedLeaderSessionID)) { + LOG.warn( + "Receive an old confirmation call of leader session ID {}, current " + + "issued session ID is {}", + leaderSessionID, + issuedLeaderSessionID); + } else { + LOG.warn( + "The leader session ID {} was confirmed even though the " + + "corresponding contender was not elected as the leader.", + leaderSessionID); + } + } + } + + @Override + public boolean hasLeadership(UUID leaderSessionId) { + synchronized (lock) { + if (running) { + return leaderElectionDriver.hasLeadership() + && leaderSessionId.equals(issuedLeaderSessionID); + } + LOG.debug("hasLeadership is called after the service is stopped, returning false."); + return false; + } + } + + private void confirmLeaderInformation(LeaderInformation leaderInfo) { + assert Thread.holdsLock(lock); + confirmedLeaderInfo = leaderInfo; + leaderElectionDriver.writeLeaderInformation(leaderInfo); + } + + private void clearConfirmedLeaderInformation() { + assert Thread.holdsLock(lock); + confirmedLeaderInfo = null; + } + + @Override + public void onGrantLeadership() { + synchronized 
(lock) { + if (!running) { + LOG.warn( + "Ignoring the grant leadership notification for {} has already been closed.", + leaderElectionDriver); + return; + } + + issuedLeaderSessionID = UUID.randomUUID(); + clearConfirmedLeaderInformation(); + + LOG.info( + "Grant leadership to contender {} with session ID {}.", + leaderContender.getDescription(), + issuedLeaderSessionID); + leaderContender.grantLeadership(issuedLeaderSessionID); + } + } + + @Override + public void onRevokeLeadership() { + synchronized (lock) { + if (!running) { + LOG.warn( + "Ignoring the revoke leadership notification since {} has already been closed.", + leaderElectionDriver); + return; + } + + LOG.info( + "Revoke leadership of {}-{}.", + leaderContender.getDescription(), + confirmedLeaderInfo); + + issuedLeaderSessionID = null; + clearConfirmedLeaderInformation(); + leaderContender.revokeLeadership(); + + LOG.info("Clearing the leader information on {}.", leaderElectionDriver); + // Clear the old leader information on the external storage + leaderElectionDriver.writeLeaderInformation(LeaderInformation.empty()); + } + } + + @Override + public void onLeaderInformationChange(LeaderInformation leaderInfo) { + synchronized (lock) { + if (!running) { + LOG.warn( + "Ignoring change notification since the {} has already been closed.", + leaderElectionDriver); + return; + } + + LOG.info( + "Leader node changed while {} is the leader. Old leader information {}; New " + + "leader information {}.", + leaderContender.getDescription(), + confirmedLeaderInfo, + leaderInfo); + + if (confirmedLeaderInfo == null) { + return; + } + + if (leaderInfo.isEmpty() || !leaderInfo.equals(confirmedLeaderInfo)) { + LOG.info( + "Writing leader information {} of {}.", + confirmedLeaderInfo, + leaderContender.getDescription()); + leaderElectionDriver.writeLeaderInformation(confirmedLeaderInfo); + } + } + } + + private class LeaderElectionFatalErrorHandler implements FatalErrorHandler { + + @Override + public void onFatalError(Throwable throwable) { + synchronized (lock) { + if (!running) { + LOG.debug("Ignoring error notification since the service has been stopped."); + return; + } + leaderContender.handleError(throwable); + } + } + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/DefaultLeaderRetrievalService.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/DefaultLeaderRetrievalService.java new file mode 100644 index 00000000..61446619 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/DefaultLeaderRetrievalService.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
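For orientation, the contender side of the contract that DefaultLeaderElectionService drives might look as follows. The LeaderContender interface itself is not part of this diff, so the sketch relies only on the four callbacks the service invokes above (grantLeadership, revokeLeadership, getDescription, handleError) and may not match the full interface:

import java.util.UUID;

// Minimal contender sketch. A real contender would start serving once granted
// and then call leaderElectionService.confirmLeadership(...) with its own
// leader information; how that LeaderInformation is built is not shown here.
final class LoggingContender implements LeaderContender {

    private volatile UUID currentSessionId;

    @Override
    public void grantLeadership(UUID leaderSessionID) {
        currentSessionId = leaderSessionID; // remember the issued session ID
    }

    @Override
    public void revokeLeadership() {
        currentSessionId = null; // stop acting as leader
    }

    @Override
    public String getDescription() {
        return "logging-contender";
    }

    @Override
    public void handleError(Throwable throwable) {
        throwable.printStackTrace(); // escalate in a real implementation
    }
}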
+ */ + +package com.alibaba.flink.shuffle.coordinator.highavailability; + +import com.alibaba.flink.shuffle.common.handler.FatalErrorHandler; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.concurrent.GuardedBy; + +import java.util.Objects; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; + +/** + * The counterpart to the {@link DefaultLeaderElectionService}. Composed with different {@link + * LeaderRetrievalDriver}, we could retrieve the leader information from different storage. The + * leader address as well as the current leader session ID will be retrieved from {@link + * LeaderRetrievalDriver}. + * + *
This class is copied from Apache Flink + * (org.apache.flink.runtime.leaderretrieval.DefaultLeaderRetrievalService). + */ +public class DefaultLeaderRetrievalService + implements LeaderRetrievalService, LeaderRetrievalEventHandler { + + private static final Logger LOG = LoggerFactory.getLogger(DefaultLeaderRetrievalService.class); + + private final Object lock = new Object(); + + private final LeaderRetrievalDriverFactory leaderRetrievalDriverFactory; + + @GuardedBy("lock") + private LeaderInformation lastLeaderInfo; + + @GuardedBy("lock") + private boolean running; + + /** Listener which will be notified about leader changes. */ + @GuardedBy("lock") + private LeaderRetrievalListener leaderListener; + + @GuardedBy("lock") + private LeaderRetrievalDriver leaderRetrievalDriver; + + public DefaultLeaderRetrievalService( + LeaderRetrievalDriverFactory leaderRetrievalDriverFactory) { + checkArgument(leaderRetrievalDriverFactory != null, "Must be not null."); + this.leaderRetrievalDriverFactory = leaderRetrievalDriverFactory; + this.lastLeaderInfo = LeaderInformation.empty(); + } + + @Override + public void start(LeaderRetrievalListener listener) throws Exception { + checkArgument(listener != null, "Listener must not be null."); + + synchronized (lock) { + checkState( + leaderListener == null, + "DefaultLeaderRetrievalService can only be started once."); + running = true; + leaderListener = listener; + leaderRetrievalDriver = + leaderRetrievalDriverFactory.createLeaderRetrievalDriver( + this, + new DefaultLeaderRetrievalService.LeaderRetrievalFatalErrorHandler()); + LOG.info("Starting DefaultLeaderRetrievalService with {}.", leaderRetrievalDriver); + } + } + + @Override + public void stop() throws Exception { + LOG.info("Stopping DefaultLeaderRetrievalService."); + + synchronized (lock) { + if (!running) { + return; + } + running = false; + + leaderRetrievalDriver.close(); + } + } + + /** + * Called by specific {@link LeaderRetrievalDriver} to notify leader address. + * + * @param leaderInfo New notified leader information. The exception will be handled by leader + * listener. + */ + @Override + public void notifyLeaderAddress(LeaderInformation leaderInfo) { + synchronized (lock) { + if (!running) { + LOG.debug( + "Ignoring notification since the {} has already been closed.", + leaderRetrievalDriver); + return; + } + + LOG.info("New leader information: {}.", leaderInfo); + if (!Objects.equals(lastLeaderInfo, leaderInfo)) { + lastLeaderInfo = leaderInfo; + // Notify the listener only when the leader is truly changed. + leaderListener.notifyLeaderAddress(leaderInfo); + } + } + } + + private class LeaderRetrievalFatalErrorHandler implements FatalErrorHandler { + + @Override + public void onFatalError(Throwable throwable) { + synchronized (lock) { + if (!running) { + LOG.debug("Ignoring error notification since the service has been stopped."); + return; + } + leaderListener.handleError(new Exception(throwable)); + } + } + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/HaMode.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/HaMode.java new file mode 100644 index 00000000..ba43f3ad --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/HaMode.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.core.config.HighAvailabilityOptions; + +/** + * High availability mode for remote shuffle cluster execution. Currently, supported modes are: + * + *

<ul> + * <li>NONE: No high availability. + * <li>ZOOKEEPER: ShuffleManager high availability via ZooKeeper. ZooKeeper is used to elect a + * leader among a group of ShuffleManagers. The elected ShuffleManager is responsible for + * managing the ShuffleWorkers. Upon failure of the leader, a new leader is elected which will + * take over the responsibilities of the old leader. + * <li>FACTORY_CLASS: Use an implementation of {@link HaServicesFactory} specified in the + * configuration property high-availability. + * </ul> + */ +public enum HaMode { + NONE(false), + ZOOKEEPER(true), + FACTORY_CLASS(true); + + private final boolean haActive; + + HaMode(boolean haActive) { + this.haActive = haActive; + } + + public static HaMode fromConfig(Configuration config) { + String haMode = config.getString(HighAvailabilityOptions.HA_MODE); + + if (haMode == null) { + return HaMode.NONE; + } else if (haMode.equalsIgnoreCase("NONE")) { + // Map old default to new default + return HaMode.NONE; + } else { + try { + return HaMode.valueOf(haMode.toUpperCase()); + } catch (IllegalArgumentException e) { + return FACTORY_CLASS; + } + } + } + + /** + * Returns true if the defined recovery mode supports high availability. + * + * @param configuration Configuration which contains the recovery mode + * @return true if high availability is supported by the recovery mode, otherwise false + */ + public static boolean isHighAvailabilityModeActivated(Configuration configuration) { + HaMode mode = fromConfig(configuration); + return mode.haActive; + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/HaServiceUtils.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/HaServiceUtils.java new file mode 100644 index 00000000..5eab43c7 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/HaServiceUtils.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
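The resolution rules in HaMode.fromConfig are worth spelling out: a missing value falls back to NONE, known constants are matched case-insensitively, and any other non-null string is treated as the name of a custom factory class. A hedged sketch of the three cases follows; Configuration#setString is assumed for illustration (only the getString side appears in this patch), and the factory class name is hypothetical.

import com.alibaba.flink.shuffle.common.config.Configuration;
import com.alibaba.flink.shuffle.coordinator.highavailability.HaMode;
import com.alibaba.flink.shuffle.core.config.HighAvailabilityOptions;

public final class HaModeResolutionExample {
    public static void main(String[] args) {
        Configuration conf = new Configuration();

        // 1) HA_MODE unset -> NONE.
        assert HaMode.fromConfig(conf) == HaMode.NONE;

        // 2) Known constant, matched case-insensitively -> ZOOKEEPER.
        conf.setString(HighAvailabilityOptions.HA_MODE, "zookeeper"); // setString assumed
        assert HaMode.fromConfig(conf) == HaMode.ZOOKEEPER;

        // 3) Any other value is assumed to be an HaServicesFactory class name -> FACTORY_CLASS.
        conf.setString(HighAvailabilityOptions.HA_MODE, "com.example.MyHaServicesFactory");
        assert HaMode.fromConfig(conf) == HaMode.FACTORY_CLASS;
    }
}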
+ */ + +package com.alibaba.flink.shuffle.coordinator.highavailability; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.exception.ConfigurationException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.coordinator.highavailability.embeded.EmbeddedHaServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.standalone.StandaloneHaServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper.ZooKeeperHaServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper.ZooKeeperUtils; +import com.alibaba.flink.shuffle.core.config.HighAvailabilityOptions; +import com.alibaba.flink.shuffle.core.config.ManagerOptions; +import com.alibaba.flink.shuffle.rpc.utils.AkkaRpcServiceUtils; + +import org.apache.commons.lang3.tuple.Pair; + +import java.util.concurrent.Executor; + +/** Utils class to instantiate {@link HaServices} implementations. */ +public class HaServiceUtils { + + private static final String SHUFFLE_MANAGER_NAME = "shufflemanager"; + + /** Used by tests only. */ + public static HaServices createAvailableOrEmbeddedServices( + Configuration config, Executor executor) throws Exception { + + HaMode highAvailabilityMode = HaMode.fromConfig(config); + + switch (highAvailabilityMode) { + case NONE: + return new EmbeddedHaServices(executor); + case ZOOKEEPER: + return new ZooKeeperHaServices( + config, ZooKeeperUtils.startCuratorFramework(config)); + case FACTORY_CLASS: + return createCustomHAServices(config); + default: + throw new Exception( + "High availability mode " + highAvailabilityMode + " is not supported."); + } + } + + public static HaServices createHAServices(Configuration config) throws Exception { + HaMode highAvailabilityMode = HaMode.fromConfig(config); + switch (highAvailabilityMode) { + case NONE: + final Pair<String, Integer> hostnamePort = getShuffleManagerAddress(config); + + final String shuffleManagerRpcUrl = + AkkaRpcServiceUtils.getRpcUrl( + hostnamePort.getLeft(), + hostnamePort.getRight(), + AkkaRpcServiceUtils.createWildcardName(SHUFFLE_MANAGER_NAME), + AkkaRpcServiceUtils.AkkaProtocol.TCP); + return new StandaloneHaServices(shuffleManagerRpcUrl); + case ZOOKEEPER: + return new ZooKeeperHaServices( + config, ZooKeeperUtils.startCuratorFramework(config)); + case FACTORY_CLASS: + return createCustomHAServices(config); + default: + throw new Exception("Recovery mode " + highAvailabilityMode + " is not supported."); + } + } + + /** + * Returns the ShuffleManager's hostname and port extracted from the given {@link + * Configuration}. + * + * @param configuration Configuration to extract the ShuffleManager's address from + * @return The ShuffleManager's hostname and port + * @throws ConfigurationException if the ShuffleManager's address cannot be extracted from the + * configuration + */ + public static Pair<String, Integer> getShuffleManagerAddress(Configuration configuration) + throws ConfigurationException { + + final String hostname = configuration.getString(ManagerOptions.RPC_ADDRESS); + final int port = configuration.getInteger(ManagerOptions.RPC_PORT); + + if (hostname == null) { + throw new ConfigurationException( + "Config parameter '" + + ManagerOptions.RPC_ADDRESS.key() + + "' is missing (hostname/address of ShuffleManager to connect to)."); + } + + if (!CommonUtils.isValidHostPort(port)) { + throw new ConfigurationException( + "Invalid value for '" + + ManagerOptions.RPC_PORT.key() + + "' (port of the ShuffleManager actor system) : " + + port + + ". It must be greater than 0 and less than 65536."); + } + + return Pair.of(hostname, port); + } + + private static HaServices createCustomHAServices(Configuration config) throws Exception { + Class<?> clazz = Class.forName(config.getString(HighAvailabilityOptions.HA_MODE)); + HaServicesFactory haServicesFactory = (HaServicesFactory) clazz.newInstance(); + + try { + return haServicesFactory.createHAServices(config); + } catch (Exception e) { + throw new Exception( + String.format( + "Could not create the ha services from the instantiated HighAvailabilityServicesFactory %s.", + haServicesFactory.getClass().getName()), + e); + } + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/HaServices.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/HaServices.java new file mode 100644 index 00000000..a6a95761 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/HaServices.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability; + +import java.util.UUID; + +/** + * The HighAvailabilityServices gives access to all services needed for a highly-available setup. + */ +public interface HaServices extends AutoCloseable { + + // ------------------------------------------------------------------------ + // Constants + // ------------------------------------------------------------------------ + + /** + * This UUID should be used when no proper leader election happens, but a simple pre-configured + * leader is used. That is for example the case in non-highly-available standalone setups.
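In the standalone (NONE) case above, createHAServices only succeeds if the ShuffleManager address is present in the configuration, since it is turned directly into a fixed RPC URL. A hedged sketch of the minimal setup; the setter names setString/setInteger, the hostname, and the port value are illustrative assumptions (this patch only shows the getter side).

import com.alibaba.flink.shuffle.common.config.Configuration;
import com.alibaba.flink.shuffle.coordinator.highavailability.HaServiceUtils;
import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices;
import com.alibaba.flink.shuffle.core.config.ManagerOptions;

public final class StandaloneHaSetupExample {
    static HaServices create() throws Exception {
        Configuration conf = new Configuration();
        conf.setString(ManagerOptions.RPC_ADDRESS, "shuffle-manager.example.com"); // assumed setter
        conf.setInteger(ManagerOptions.RPC_PORT, 23123); // any port in (0, 65536); value illustrative

        // HA_MODE unset -> HaMode.NONE -> StandaloneHaServices with a fixed ShuffleManager URL.
        return HaServiceUtils.createHAServices(conf);
    }
}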
+ */ + UUID DEFAULT_LEADER_ID = new UUID(0, 0); + + /** Creates the shuffle manager leader retriever for the shuffle worker and shuffle client. */ + LeaderRetrievalService createLeaderRetrievalService(LeaderReceptor receptor); + + /** + * Creates the leader election service for the shuffle manager. + * + * @return Leader election service for the shuffle manager election. + */ + LeaderElectionService createLeaderElectionService(); + + /** + * Closes the high availability services (releasing all resources) and deletes all data stored + * by these services in external stores. + * + *

If an exception occurs during cleanup, this method will attempt to continue the cleanup + * and report exceptions only after all cleanup steps have been attempted. + * + * @throws Exception Thrown, if an exception occurred while closing these services or cleaning + * up data stored by them. + */ + void closeAndCleanupAllData() throws Exception; + + /** + * Type of leader information receptor which will retrieve the {@link LeaderInformation} of + * shuffle manager. + */ + enum LeaderReceptor { + SHUFFLE_CLIENT, + SHUFFLE_WORKER + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/HaServicesFactory.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/HaServicesFactory.java new file mode 100644 index 00000000..a705f650 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/HaServicesFactory.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability; + +import com.alibaba.flink.shuffle.common.config.Configuration; + +/** Factory interface for {@link HaServices}. */ +public interface HaServicesFactory { + + HaServices createHAServices(Configuration configuration) throws Exception; +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderContender.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderContender.java new file mode 100644 index 00000000..9578e6cb --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderContender.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
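FACTORY_CLASS mode ties the previous files together: createCustomHAServices reflectively instantiates the class named by the high-availability property, which therefore must implement HaServicesFactory and expose a public no-argument constructor (it is created via newInstance()). A hedged sketch of such a factory; the class name is hypothetical, and EmbeddedHaServices merely stands in for a real externally-backed implementation.

import com.alibaba.flink.shuffle.common.config.Configuration;
import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices;
import com.alibaba.flink.shuffle.coordinator.highavailability.HaServicesFactory;
import com.alibaba.flink.shuffle.coordinator.highavailability.embeded.EmbeddedHaServices;

/** Loaded reflectively when high-availability is set to com.example.MyHaServicesFactory. */
public class MyHaServicesFactory implements HaServicesFactory {

    @Override
    public HaServices createHAServices(Configuration configuration) throws Exception {
        // A production factory would build services backed by external storage;
        // the single-process EmbeddedHaServices is used here only for brevity.
        return new EmbeddedHaServices(Runnable::run);
    }
}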
+ */ + +package com.alibaba.flink.shuffle.coordinator.highavailability; + +import java.util.UUID; + +/** + * Interface which has to be implemented to take part in the leader election process of the {@link + * LeaderElectionService}. + * + *

This class is copied from Apache Flink + * (org.apache.flink.runtime.leaderelection.LeaderContender). + */ +public interface LeaderContender { + + /** + * Callback method which is called by the {@link LeaderElectionService} upon selecting this + * instance as the new leader. The method is called with the new leader session ID. + * + * @param leaderSessionID New leader session ID. + */ + void grantLeadership(UUID leaderSessionID); + + /** + * Callback method which is called by the {@link LeaderElectionService} upon revoking the + * leadership of a former leader. This might happen in case that multiple contenders have been + * granted leadership. + */ + void revokeLeadership(); + + /** + * Callback method which is called by {@link LeaderElectionService} in case of an error in the + * service thread. + */ + void handleError(Throwable throwable); + + /** + * Returns the description of the {@link LeaderContender} for logging purposes. + * + * @return Description of this contender. + */ + default String getDescription() { + return "LeaderContender: " + getClass().getSimpleName(); + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderElectionDriver.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderElectionDriver.java new file mode 100644 index 00000000..7492a357 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderElectionDriver.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability; + +/** + * A {@link LeaderElectionDriver} is responsible for performing the leader election and storing the + * leader information. All the leader internal state is guarded by lock in {@link + * LeaderElectionService}. Different driver implementations do not need to care about the lock. And + * it should use {@link LeaderElectionEventHandler} if it wants to respond to the leader change + * events. + * + *

Important: The {@link LeaderElectionDriver} cannot guarantee that no + * {@link LeaderElectionEventHandler} callbacks happen after {@link #close()}. + + *

This class is copied from Apache Flink + * (org.apache.flink.runtime.leaderelection.LeaderElectionDriver). + */ +public interface LeaderElectionDriver { + + /** + * Write the current leader information to external persistent storage(e.g. Zookeeper, + * Kubernetes ConfigMap). This is a blocking IO operation. The write operation takes effect only + * when the driver still has the leadership. + * + * @param leaderInformation current leader information. It could be {@link + * LeaderInformation#empty()}, which means the caller want to clear the leader information + * on external storage. Please remember that the clear operation should only happen before a + * new leader is elected and has written his leader information on the storage. Otherwise, + * we may have a risk to wrongly update the storage with empty leader information. + */ + void writeLeaderInformation(LeaderInformation leaderInformation); + + /** + * Check whether the driver still have the leadership in the distributed coordination system. + * + * @return Return true if the driver has leadership. + */ + boolean hasLeadership(); + + /** Close the services used for leader election. */ + void close() throws Exception; +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderElectionDriverFactory.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderElectionDriverFactory.java new file mode 100644 index 00000000..191fd8f7 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderElectionDriverFactory.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability; + +import com.alibaba.flink.shuffle.common.handler.FatalErrorHandler; + +/** + * Factory for creating {@link LeaderElectionDriver} with different implementation. + * + *

This class is copied from Apache Flink + * (org.apache.flink.runtime.leaderelection.LeaderElectionDriverFactory). + */ +public interface LeaderElectionDriverFactory { + + /** + * Create a specific {@link LeaderElectionDriver} and start the necessary services. For example, + * LeaderLatch and NodeCache in Zookeeper, KubernetesLeaderElector and ConfigMap watcher in + * Kubernetes. + * + * @param leaderEventHandler handler for the leader election driver to process leader events. + * @param fatalErrorHandler fatal error handler + * @param leaderContenderDescription leader contender description. + * @throws Exception if creating the specific {@link LeaderElectionDriver} implementation or + * starting the necessary services fails. + */ + LeaderElectionDriver createLeaderElectionDriver( + LeaderElectionEventHandler leaderEventHandler, + FatalErrorHandler fatalErrorHandler, + String leaderContenderDescription) + throws Exception; +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderElectionEventHandler.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderElectionEventHandler.java new file mode 100644 index 00000000..0924ee46 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderElectionEventHandler.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability; + +/** + * Interface which should be implemented to respond to leader changes in {@link + * LeaderElectionDriver}. + + *

Important: The {@link LeaderElectionDriver} cannot guarantee that no + * {@link LeaderElectionEventHandler} callbacks happen after {@link + * LeaderElectionDriver#close()}. This means that the implementor of {@link + * LeaderElectionEventHandler} is responsible for filtering out spurious callbacks (e.g. after + * close has been called on {@link LeaderElectionDriver}). + + *

This class is copied from Apache Flink + * (org.apache.flink.runtime.leaderelection.LeaderElectionEventHandler). + */ +public interface LeaderElectionEventHandler { + + /** Called by specific {@link LeaderElectionDriver} when the leadership is granted. */ + void onGrantLeadership(); + + /** Called by specific {@link LeaderElectionDriver} when the leadership is revoked. */ + void onRevokeLeadership(); + + /** + * Called by specific {@link LeaderElectionDriver} when the leader information is changed. Then + * the {@link LeaderElectionService} could write the leader information again if necessary. This + * method should only be called when {@link LeaderElectionDriver#hasLeadership()} is true. + * Duplicated leader change events could happen, so the implementation should check whether the + * passed leader information is really different with internal confirmed leader information. + * + * @param leaderInfo leader information which contains leader session id and leader address. + */ + void onLeaderInformationChange(LeaderInformation leaderInfo); +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderElectionService.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderElectionService.java new file mode 100644 index 00000000..fa32c0f3 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderElectionService.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability; + +import java.util.UUID; + +/** + * Interface for a service which allows to elect a leader among a group of contenders. + * + *

Prior to using this service, it has to be started by calling the start method. The start method + * takes the contender as a parameter. If there are multiple contenders, then each contender has to + * instantiate its own leader election service. + + *

Once a contender has been granted leadership, it has to confirm the received leader session ID + * by calling the method {@link #confirmLeadership(LeaderInformation)}. This will notify the + * leader election service that the contender has accepted the leadership and that the + * leader session ID as well as the leader address can now be published to the leader retrieval + * services. + + *

This class is copied and modified from Apache Flink + * (org.apache.flink.runtime.leaderelection.LeaderElectionService). + */ +public interface LeaderElectionService { + + /** + * Starts the leader election service. This method can only be called once. + * + * @param contender LeaderContender which applies for the leadership. + */ + void start(LeaderContender contender) throws Exception; + + /** Stops the leader election service. */ + void stop() throws Exception; + + /** + * Confirms that the {@link LeaderContender} has accepted the leadership identified by the given + * leader session id. It also publishes the leader address under which the leader is reachable. + * + *

The rationale behind this method is to establish an order between setting the new leader + * session ID in the {@link LeaderContender} and publishing the new leader session ID as well as + * the leader address to the leader retrieval services. + */ + void confirmLeadership(LeaderInformation leaderInfo); + + /** + * Returns true if the {@link LeaderContender} with which the service has been started + * currently owns the leadership under the given leader session id. + * + * @param leaderSessionId identifying the current leader + * @return true if the associated {@link LeaderContender} is the leader, otherwise false + */ + boolean hasLeadership(UUID leaderSessionId); +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderInformation.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderInformation.java new file mode 100644 index 00000000..c1fb918e --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderInformation.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.common.utils.ProtocolUtils; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.Serializable; +import java.util.Objects; +import java.util.Properties; +import java.util.UUID; + +/** + * Information about leader including the confirmed leader session id and leader address. + + *

This class is copied from Apache Flink + * (org.apache.flink.runtime.leaderelection.LeaderInformation). + */ +public class LeaderInformation implements Serializable { + + private static final long serialVersionUID = -5345049443329244907L; + + private static final String IS_EMPTY_KEY = "isEmpty"; + + private static final String PROTOCOL_VERSION_KEY = "protocolVersion"; + + private static final String SUPPORTED_VERSION_KEY = "supportedVersion"; + + private static final String LEADER_ID_KEY = "leaderSessionID"; + + private static final String LEADER_ADDRESS_KEY = "leaderAddress"; + + private final int protocolVersion; + + private final int supportedVersion; + + private final UUID leaderSessionID; + + private final String leaderAddress; + + private static final LeaderInformation EMPTY = + new LeaderInformation(HaServices.DEFAULT_LEADER_ID, ""); + + public LeaderInformation(UUID leaderSessionID, String leaderAddress) { + this( + ProtocolUtils.currentProtocolVersion(), + ProtocolUtils.compatibleVersion(), + leaderSessionID, + leaderAddress); + } + + public LeaderInformation( + int protocolVersion, int supportedVersion, UUID leaderSessionID, String leaderAddress) { + CommonUtils.checkArgument(leaderSessionID != null, "Must be not null."); + CommonUtils.checkArgument(leaderAddress != null, "Must be not null."); + + this.protocolVersion = protocolVersion; + this.supportedVersion = supportedVersion; + this.leaderSessionID = leaderSessionID; + this.leaderAddress = leaderAddress; + } + + public int getProtocolVersion() { + return protocolVersion; + } + + public int getSupportedVersion() { + return supportedVersion; + } + + public UUID getLeaderSessionID() { + return leaderSessionID; + } + + public String getLeaderAddress() { + return leaderAddress; + } + + public boolean isEmpty() { + return this == EMPTY; + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + + if (!(obj instanceof LeaderInformation)) { + return false; + } + + LeaderInformation that = (LeaderInformation) obj; + return Objects.equals(this.leaderSessionID, that.leaderSessionID) + && Objects.equals(this.leaderAddress, that.leaderAddress); + } + + @Override + public int hashCode() { + return Objects.hash(leaderSessionID, leaderAddress); + } + + public static LeaderInformation empty() { + return EMPTY; + } + + public static LeaderInformation fromByteArray(byte[] bytes) throws Exception { + try (ByteArrayInputStream input = new ByteArrayInputStream(bytes)) { + Properties properties = new Properties(); + properties.load(input); + + boolean isEmpty = Boolean.parseBoolean(properties.getProperty(IS_EMPTY_KEY)); + if (isEmpty) { + return EMPTY; + } + + return new LeaderInformation( + Integer.parseInt(properties.getProperty(PROTOCOL_VERSION_KEY)), + Integer.parseInt(properties.getProperty(SUPPORTED_VERSION_KEY)), + UUID.fromString(properties.getProperty(LEADER_ID_KEY)), + properties.getProperty(LEADER_ADDRESS_KEY)); + } + } + + public byte[] toByteArray() throws Exception { + Properties properties = new Properties(); + properties.setProperty(IS_EMPTY_KEY, String.valueOf(isEmpty())); + properties.setProperty(PROTOCOL_VERSION_KEY, String.valueOf(protocolVersion)); + properties.setProperty(SUPPORTED_VERSION_KEY, String.valueOf(supportedVersion)); + properties.setProperty(LEADER_ID_KEY, leaderSessionID.toString()); + properties.setProperty(LEADER_ADDRESS_KEY, leaderAddress); + + try (ByteArrayOutputStream output = new ByteArrayOutputStream()) { + properties.store(output, null); + return 
output.toByteArray(); + } + } + + @Override + public String toString() { + return String.format( + "LeaderInformation{leaderSessionID=%s, leaderAddress=%s, isEmpty=%s, " + + "protocolVersion=%d, supportedVersion=%d}", + leaderSessionID, leaderAddress, isEmpty(), protocolVersion, supportedVersion); + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderRetrievalDriver.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderRetrievalDriver.java new file mode 100644 index 00000000..eb5a90ac --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderRetrievalDriver.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability; + +/** + * A {@link LeaderRetrievalDriver} is responsible for retrieving the current leader which has been + * elected by the {@link LeaderElectionDriver}. + * + *

Important: The {@link LeaderRetrievalDriver} cannot guarantee that no + * {@link LeaderRetrievalEventHandler} callbacks happen after {@link #close()}. + + *

This class is copied from Apache Flink + * (org.apache.flink.runtime.leaderretrieval.LeaderRetrievalDriver). + */ +public interface LeaderRetrievalDriver extends AutoCloseable {} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderRetrievalDriverFactory.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderRetrievalDriverFactory.java new file mode 100644 index 00000000..a2687b69 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderRetrievalDriverFactory.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability; + +import com.alibaba.flink.shuffle.common.handler.FatalErrorHandler; + +/** + * Factory for creating {@link LeaderRetrievalDriver} with different implementations. + * + *

This class is copied from Apache Flink + * (org.apache.flink.runtime.leaderretrieval.LeaderRetrievalDriverFactory). + */ +public interface LeaderRetrievalDriverFactory { + + /** + * Create a specific {@link LeaderRetrievalDriver} and start the necessary services. For + * example, NodeCache in Zookeeper, ConfigMap watcher in Kubernetes. They could get the leader + * information change events and need to notify the leader listener by {@link + * LeaderRetrievalEventHandler}. + * + * @param leaderEventHandler handler for the leader retrieval driver to notify leader change + * events. + * @param fatalErrorHandler fatal error handler + * @throws Exception if creating the specific {@link LeaderRetrievalDriver} implementation or + * starting the necessary services fails. + */ + LeaderRetrievalDriver createLeaderRetrievalDriver( + LeaderRetrievalEventHandler leaderEventHandler, FatalErrorHandler fatalErrorHandler) + throws Exception; +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderRetrievalEventHandler.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderRetrievalEventHandler.java new file mode 100644 index 00000000..7f8fb3f6 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderRetrievalEventHandler.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability; + +/** + * Interface which should be implemented to be notified of {@link LeaderInformation} changes in a + * {@link LeaderRetrievalDriver}. + + *

Important: The {@link LeaderRetrievalDriver} cannot guarantee that no + * {@link LeaderRetrievalEventHandler} callbacks happen after {@link + * LeaderRetrievalDriver#close()}. This means that the implementor of {@link + * LeaderRetrievalEventHandler} is responsible for filtering out spurious callbacks (e.g. after + * close has been called on {@link LeaderRetrievalDriver}). + + *

This class is copied from Apache Flink + * (org.apache.flink.runtime.leaderretrieval.LeaderRetrievalEventHandler). + */ +public interface LeaderRetrievalEventHandler { + + /** + * Called by specific {@link LeaderRetrievalDriver} to notify leader address. + * + *

Duplicated leader change events could happen, so the implementation should check whether + * the passed leader information really differs from the last stored leader information. + * + * @param leaderInfo the new leader information to notify {@link LeaderRetrievalService}. It + * could be {@link LeaderInformation#empty()} if the leader address does not exist in the + * external storage. + */ + void notifyLeaderAddress(LeaderInformation leaderInfo); +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderRetrievalListener.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderRetrievalListener.java new file mode 100644 index 00000000..84f2b6d9 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderRetrievalListener.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability; + +/** + * Classes which want to be notified about a changing leader by the {@link LeaderRetrievalService} + * have to implement this interface. + + *

This class is copied from Apache Flink + * (org.apache.flink.runtime.leaderretrieval.LeaderRetrievalListener). + */ +public interface LeaderRetrievalListener { + + /** + * This method is called by the {@link LeaderRetrievalService} when a new leader is elected. + * + *

If the given leader information is {@link LeaderInformation#empty()}, it signals that leadership + * was revoked without a new leader having been elected. + */ + void notifyLeaderAddress(LeaderInformation leaderInfo); + + /** + * This method is called by the {@link LeaderRetrievalService} in case of an exception. This + * assures that the {@link LeaderRetrievalListener} is aware of any problems occurring in the + * {@link LeaderRetrievalService} thread. + */ + void handleError(Exception exception); +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderRetrievalService.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderRetrievalService.java new file mode 100644 index 00000000..6bd8948d --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/LeaderRetrievalService.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability; + +/** + * This interface has to be implemented by a service which retrieves the current leader and + * notifies a listener about it. + + *

Prior to using this service, it has to be started by calling the start method. The start method + * also takes the {@link LeaderRetrievalListener} as an argument. The service can only be started + * once. + + *

The service should be stopped by calling the stop method. + * + *

This class is copied from Apache Flink + * (org.apache.flink.runtime.leaderretrieval.LeaderRetrievalService). + */ +public interface LeaderRetrievalService { + + /** + * Starts the leader retrieval service with the given listener to listen for new leaders. This + * method can only be called once. + * + * @param listener The leader retrieval listener which will be notified about new leaders. + */ + void start(LeaderRetrievalListener listener) throws Exception; + + /** Stops the leader retrieval service. */ + void stop() throws Exception; +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/embeded/EmbeddedHaServices.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/embeded/EmbeddedHaServices.java new file mode 100644 index 00000000..bd4c3326 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/embeded/EmbeddedHaServices.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability.embeded; + +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionService; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalService; + +import javax.annotation.Nonnull; + +import java.util.concurrent.Executor; + +/** + * An implementation of the {@link HaServices} for the non-high-availability case where all + * participants run in the same process. + * + *

This implementation has no dependencies on any external services. + * + *

This class is copied from Apache Flink + * (org.apache.flink.runtime.highavailability.nonha.embedded.EmbeddedHaServices). + */ +public class EmbeddedHaServices implements HaServices { + + protected final Object lock = new Object(); + + private final EmbeddedLeaderService embeddedLeaderService; + + private boolean shutdown; + + public EmbeddedHaServices(Executor executor) { + this.embeddedLeaderService = createEmbeddedLeaderService(executor); + } + + @Override + public LeaderRetrievalService createLeaderRetrievalService(LeaderReceptor receptor) { + return embeddedLeaderService.createLeaderRetrievalService(); + } + + @Override + public LeaderElectionService createLeaderElectionService() { + return embeddedLeaderService.createLeaderElectionService(); + } + + @Nonnull + private EmbeddedLeaderService createEmbeddedLeaderService(Executor executor) { + return new EmbeddedLeaderService(executor); + } + + @Override + public void close() { + synchronized (lock) { + if (!shutdown) { + shutdown = true; + } + + embeddedLeaderService.shutdown(); + } + } + + @Override + public void closeAndCleanupAllData() { + this.close(); + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/embeded/EmbeddedLeaderService.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/embeded/EmbeddedLeaderService.java new file mode 100644 index 00000000..aa8eb4a6 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/embeded/EmbeddedLeaderService.java @@ -0,0 +1,567 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
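Because the embedded services keep all state in one process, they make a compact test harness: election and retrieval observe the same in-memory EmbeddedLeaderService. A hedged sketch; contender and listener stand for any implementations of the interfaces above, and the direct executor (Runnable::run) is an illustrative choice.

import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices;
import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderContender;
import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionService;
import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalListener;
import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalService;
import com.alibaba.flink.shuffle.coordinator.highavailability.embeded.EmbeddedHaServices;

public final class EmbeddedHaExample {
    static void runOnce(LeaderContender contender, LeaderRetrievalListener listener)
            throws Exception {
        EmbeddedHaServices ha = new EmbeddedHaServices(Runnable::run); // direct executor

        LeaderElectionService election = ha.createLeaderElectionService();
        LeaderRetrievalService retrieval =
                ha.createLeaderRetrievalService(HaServices.LeaderReceptor.SHUFFLE_WORKER);

        election.start(contender); // the contender is granted leadership and must confirm it
        retrieval.start(listener); // the listener then sees the confirmed leader information

        election.stop();
        retrieval.stop();
        ha.closeAndCleanupAllData();
    }
}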
+ */ + +package com.alibaba.flink.shuffle.coordinator.highavailability.embeded; + +import com.alibaba.flink.shuffle.common.utils.FutureUtils; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderContender; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionService; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalListener; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalService; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import javax.annotation.concurrent.GuardedBy; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; + +/** + * A simple leader election service, which selects a leader among contenders and notifies listeners. + * + *

An election service for contenders can be created via {@link #createLeaderElectionService()}, + * and a listener service for leader observers can be created via {@link + * #createLeaderRetrievalService()}. + + *

This class is copied from Apache Flink + * (org.apache.flink.runtime.highavailability.nonha.embedded.EmbeddedLeaderService). + */ +public class EmbeddedLeaderService { + + private static final Logger LOG = LoggerFactory.getLogger(EmbeddedLeaderService.class); + + private final Object lock = new Object(); + + private final Executor notificationExecutor; + + private final Set<EmbeddedLeaderElectionService> allLeaderContenders; + + private final Set<EmbeddedLeaderRetrievalService> listeners; + + /** proposed leader, which has been notified of leadership grant, but has not confirmed. */ + private EmbeddedLeaderService.EmbeddedLeaderElectionService currentLeaderProposed; + + /** actual leader that has confirmed leadership and of which listeners have been notified. */ + private EmbeddedLeaderService.EmbeddedLeaderElectionService currentLeaderConfirmed; + + private volatile LeaderInformation currentLeaderInfo; + + /** flag marking the service as terminated. */ + private boolean shutdown; + + // ------------------------------------------------------------------------ + + public EmbeddedLeaderService(Executor notificationsDispatcher) { + this.notificationExecutor = checkNotNull(notificationsDispatcher); + this.allLeaderContenders = new HashSet<>(); + this.listeners = new HashSet<>(); + } + + // ------------------------------------------------------------------------ + // shutdown and errors + // ------------------------------------------------------------------------ + + /** + * Shuts down this leader election service. + + *

This method does not perform a clean revocation of the leader status and no notification + * to any leader listeners. It simply notifies all contenders and listeners that the service is + * no longer available. + */ + public void shutdown() { + synchronized (lock) { + shutdownInternally(new Exception("Leader election service is shutting down")); + } + } + + public boolean isShutdown() { + synchronized (lock) { + return shutdown; + } + } + + private void fatalError(Throwable error) { + LOG.error( + "Embedded leader election service encountered a fatal error. Shutting down service.", + error); + + synchronized (lock) { + shutdownInternally( + new Exception( + "Leader election service is shutting down after a fatal error", error)); + } + } + + @GuardedBy("lock") + private void shutdownInternally(Exception exceptionForHandlers) { + assert Thread.holdsLock(lock); + + if (!shutdown) { + // clear all leader status + currentLeaderProposed = null; + currentLeaderConfirmed = null; + currentLeaderInfo = LeaderInformation.empty(); + + // fail all registered listeners + for (EmbeddedLeaderService.EmbeddedLeaderElectionService service : + allLeaderContenders) { + service.shutdown(exceptionForHandlers); + } + allLeaderContenders.clear(); + + // fail all registered listeners + for (EmbeddedLeaderService.EmbeddedLeaderRetrievalService service : listeners) { + service.shutdown(exceptionForHandlers); + } + listeners.clear(); + + shutdown = true; + } + } + + // ------------------------------------------------------------------------ + // creating contenders and listeners + // ------------------------------------------------------------------------ + + public LeaderElectionService createLeaderElectionService() { + checkState(!shutdown, "leader election service is shut down"); + return new EmbeddedLeaderService.EmbeddedLeaderElectionService(); + } + + public LeaderRetrievalService createLeaderRetrievalService() { + checkState(!shutdown, "leader election service is shut down"); + return new EmbeddedLeaderService.EmbeddedLeaderRetrievalService(); + } + + // ------------------------------------------------------------------------ + // adding and removing contenders & listeners + // ------------------------------------------------------------------------ + + /** Callback from leader contenders when they start their service. */ + private void addContender( + EmbeddedLeaderService.EmbeddedLeaderElectionService service, + LeaderContender contender) { + synchronized (lock) { + checkState(!shutdown, "leader election service is shut down"); + checkState(!service.running, "leader election service is already started"); + + try { + if (!allLeaderContenders.add(service)) { + throw new IllegalStateException( + "leader election service was added to this service multiple times"); + } + + service.contender = contender; + service.running = true; + + updateLeader() + .whenComplete( + (aVoid, throwable) -> { + if (throwable != null) { + fatalError(throwable); + } + }); + } catch (Throwable t) { + fatalError(t); + } + } + } + + /** Callback from leader contenders when they stop their service. 
*/ + private void removeContender(EmbeddedLeaderService.EmbeddedLeaderElectionService service) { + synchronized (lock) { + // if the service was not even started, simply do nothing + if (!service.running || shutdown) { + return; + } + + try { + if (!allLeaderContenders.remove(service)) { + throw new IllegalStateException( + "leader election service does not belong to this service"); + } + + // stop the service + service.contender = null; + service.running = false; + service.isLeader = false; + + // if that was the current leader, unset its status + if (currentLeaderConfirmed == service) { + currentLeaderConfirmed = null; + currentLeaderInfo = LeaderInformation.empty(); + } + if (currentLeaderProposed == service) { + currentLeaderProposed = null; + currentLeaderInfo = LeaderInformation.empty(); + } + + updateLeader() + .whenComplete( + (aVoid, throwable) -> { + if (throwable != null) { + fatalError(throwable); + } + }); + } catch (Throwable t) { + fatalError(t); + } + } + } + + /** Callback from leader contenders when they confirm a leader grant. */ + private void confirmLeader( + EmbeddedLeaderService.EmbeddedLeaderElectionService service, + LeaderInformation leaderInfo) { + synchronized (lock) { + // if the service was shut down in the meantime, ignore this confirmation + if (!service.running || shutdown) { + return; + } + + try { + // check if the confirmation is for the same grant, or whether it is a stale grant + if (service == currentLeaderProposed + && currentLeaderInfo + .getLeaderSessionID() + .equals(leaderInfo.getLeaderSessionID())) { + LOG.info("Received confirmation of leadership {}.", leaderInfo); + + // mark leadership + currentLeaderConfirmed = service; + currentLeaderInfo = leaderInfo; + currentLeaderProposed = null; + + // notify all listeners + notifyAllListeners(leaderInfo); + } else { + LOG.debug( + "Received confirmation of leadership for a stale leadership grant. 
Ignoring."); + } + } catch (Throwable t) { + fatalError(t); + } + } + } + + private CompletableFuture notifyAllListeners(LeaderInformation leaderInfo) { + final List> notifyListenerFutures = + new ArrayList<>(listeners.size()); + + for (EmbeddedLeaderService.EmbeddedLeaderRetrievalService listener : listeners) { + notifyListenerFutures.add(notifyListener(leaderInfo, listener.listener)); + } + + return FutureUtils.waitForAll(notifyListenerFutures); + } + + @GuardedBy("lock") + private CompletableFuture updateLeader() { + // this must be called under the lock + assert Thread.holdsLock(lock); + + if (currentLeaderConfirmed == null && currentLeaderProposed == null) { + // we need a new leader + if (allLeaderContenders.isEmpty()) { + // no new leader available, tell everyone that there is no leader currently + return notifyAllListeners(LeaderInformation.empty()); + } else { + // propose a leader and ask it + final UUID leaderSessionId = UUID.randomUUID(); + EmbeddedLeaderService.EmbeddedLeaderElectionService leaderService = + allLeaderContenders.iterator().next(); + + currentLeaderInfo = new LeaderInformation(leaderSessionId, ""); + currentLeaderProposed = leaderService; + currentLeaderProposed.isLeader = true; + + LOG.info( + "Proposing leadership to contender {}", + leaderService.contender.getDescription()); + + return execute( + new EmbeddedLeaderService.GrantLeadershipCall( + leaderService.contender, leaderSessionId, LOG)); + } + } else { + return CompletableFuture.completedFuture(null); + } + } + + private CompletableFuture notifyListener( + LeaderInformation leaderInfo, LeaderRetrievalListener listener) { + return CompletableFuture.runAsync( + new EmbeddedLeaderService.NotifyOfLeaderCall(leaderInfo, listener, LOG), + notificationExecutor); + } + + private void addListener( + EmbeddedLeaderService.EmbeddedLeaderRetrievalService service, + LeaderRetrievalListener listener) { + synchronized (lock) { + checkState(!shutdown, "leader election service is shut down"); + checkState(!service.running, "leader retrieval service is already started"); + + try { + if (!listeners.add(service)) { + throw new IllegalStateException( + "leader retrieval service was added to this service multiple times"); + } + + service.listener = listener; + service.running = true; + + // if we already have a leader, immediately notify this new listener + if (currentLeaderConfirmed != null) { + notifyListener(currentLeaderInfo, listener); + } + } catch (Throwable t) { + fatalError(t); + } + } + } + + private void removeListener(EmbeddedLeaderService.EmbeddedLeaderRetrievalService service) { + synchronized (lock) { + // if the service was not even started, simply do nothing + if (!service.running || shutdown) { + return; + } + + try { + if (!listeners.remove(service)) { + throw new IllegalStateException( + "leader retrieval service does not belong to this service"); + } + + // stop the service + service.listener = null; + service.running = false; + } catch (Throwable t) { + fatalError(t); + } + } + } + + CompletableFuture grantLeadership() { + synchronized (lock) { + if (shutdown) { + return getShutDownFuture(); + } + + return updateLeader(); + } + } + + private CompletableFuture getShutDownFuture() { + return FutureUtils.completedExceptionally( + new Exception("EmbeddedLeaderService has been shut down.")); + } + + CompletableFuture revokeLeadership() { + synchronized (lock) { + if (shutdown) { + return getShutDownFuture(); + } + + if (currentLeaderProposed != null || currentLeaderConfirmed != null) { + final 
EmbeddedLeaderService.EmbeddedLeaderElectionService leaderService; + + if (currentLeaderConfirmed != null) { + leaderService = currentLeaderConfirmed; + } else { + leaderService = currentLeaderProposed; + } + + LOG.info("Revoking leadership of {}.", leaderService.contender); + leaderService.isLeader = false; + CompletableFuture revokeLeadershipCallFuture = + execute( + new EmbeddedLeaderService.RevokeLeadershipCall( + leaderService.contender)); + + CompletableFuture notifyAllListenersFuture = + notifyAllListeners(LeaderInformation.empty()); + + currentLeaderProposed = null; + currentLeaderConfirmed = null; + currentLeaderInfo = LeaderInformation.empty(); + + return CompletableFuture.allOf( + revokeLeadershipCallFuture, notifyAllListenersFuture); + } else { + return CompletableFuture.completedFuture(null); + } + } + } + + private CompletableFuture execute(Runnable runnable) { + return CompletableFuture.runAsync(runnable, notificationExecutor); + } + + // ------------------------------------------------------------------------ + // election and retrieval service implementations + // ------------------------------------------------------------------------ + + private class EmbeddedLeaderElectionService implements LeaderElectionService { + + volatile LeaderContender contender; + + volatile boolean isLeader; + + volatile boolean running; + + @Override + public void start(LeaderContender contender) throws Exception { + checkNotNull(contender); + addContender(this, contender); + } + + @Override + public void stop() throws Exception { + removeContender(this); + } + + @Override + public void confirmLeadership(LeaderInformation leaderInfo) { + confirmLeader(this, leaderInfo); + } + + @Override + public boolean hasLeadership(@Nonnull UUID leaderSessionId) { + return isLeader && leaderSessionId.equals(currentLeaderInfo.getLeaderSessionID()); + } + + void shutdown(Exception cause) { + if (running) { + running = false; + isLeader = false; + contender.revokeLeadership(); + contender = null; + } + } + } + + // ------------------------------------------------------------------------ + + private class EmbeddedLeaderRetrievalService implements LeaderRetrievalService { + + volatile LeaderRetrievalListener listener; + + volatile boolean running; + + @Override + public void start(LeaderRetrievalListener listener) throws Exception { + checkNotNull(listener); + addListener(this, listener); + } + + @Override + public void stop() throws Exception { + removeListener(this); + } + + public void shutdown(Exception cause) { + if (running) { + running = false; + listener = null; + } + } + } + + // ------------------------------------------------------------------------ + // asynchronous notifications + // ------------------------------------------------------------------------ + + private static class NotifyOfLeaderCall implements Runnable { + + private final LeaderInformation leaderInfo; // empty if leader revoked without new leader + + private final LeaderRetrievalListener listener; + private final Logger logger; + + NotifyOfLeaderCall( + LeaderInformation leaderInfo, LeaderRetrievalListener listener, Logger logger) { + + this.leaderInfo = checkNotNull(leaderInfo); + this.listener = checkNotNull(listener); + this.logger = checkNotNull(logger); + } + + @Override + public void run() { + try { + listener.notifyLeaderAddress(leaderInfo); + } catch (Throwable t) { + logger.warn("Error notifying leader listener about new leader", t); + listener.handleError(t instanceof Exception ? 
(Exception) t : new Exception(t)); + } + } + } + + // ------------------------------------------------------------------------ + + private static class GrantLeadershipCall implements Runnable { + + private final LeaderContender contender; + private final UUID leaderSessionId; + private final Logger logger; + + GrantLeadershipCall(LeaderContender contender, UUID leaderSessionId, Logger logger) { + + this.contender = checkNotNull(contender); + this.leaderSessionId = checkNotNull(leaderSessionId); + this.logger = checkNotNull(logger); + } + + @Override + public void run() { + try { + contender.grantLeadership(leaderSessionId); + } catch (Throwable t) { + logger.warn("Error granting leadership to contender", t); + contender.handleError(t instanceof Exception ? (Exception) t : new Exception(t)); + } + } + } + + private static class RevokeLeadershipCall implements Runnable { + + @Nonnull private final LeaderContender contender; + + RevokeLeadershipCall(@Nonnull LeaderContender contender) { + this.contender = contender; + } + + @Override + public void run() { + contender.revokeLeadership(); + } + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/standalone/StandaloneHaServices.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/standalone/StandaloneHaServices.java new file mode 100644 index 00000000..b63eb86f --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/standalone/StandaloneHaServices.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability.standalone; + +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionService; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalService; + +import javax.annotation.concurrent.GuardedBy; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; + +/** + * An implementation of the {@link HaServices} for the non-high-availability case. This + * implementation can be used for testing, and for cluster setups that do not tolerate failures of + * the master processes. + * + *
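To make the grant/confirm handshake of the embedded service above concrete, here is a minimal, hypothetical contender (not part of this patch). It assumes LeaderContender declares only the four methods this file actually calls, and relies on the behavior implemented in confirmLeader above: a confirmation carrying any session ID other than the granted one is ignored as stale.

```java
import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderContender;
import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionService;
import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation;

import java.util.UUID;

/** Hypothetical contender that echoes the granted session ID back to the service. */
public class EchoingContender implements LeaderContender {

    private final LeaderElectionService electionService; // e.g. an EmbeddedLeaderElectionService
    private final String address; // assumed RPC address of this contender

    public EchoingContender(LeaderElectionService electionService, String address) {
        this.electionService = electionService;
        this.address = address;
    }

    @Override
    public void grantLeadership(UUID leaderSessionID) {
        // Echo the granted session ID; this is what moves the contender from
        // currentLeaderProposed to currentLeaderConfirmed and notifies listeners.
        electionService.confirmLeadership(new LeaderInformation(leaderSessionID, address));
    }

    @Override
    public void revokeLeadership() {
        // Stop acting as leader; the service has already cleared its leader state.
    }

    @Override
    public String getDescription() {
        return "EchoingContender(" + address + ")";
    }

    @Override
    public void handleError(Exception exception) {
        throw new RuntimeException("Leader election error", exception);
    }
}
```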
<p>
This implementation has no dependencies on any external services. It returns a fixed, + * pre-configured ShuffleManager, and stores checkpoints and metadata simply on the heap or on a + * local file system, and therefore in storage without guarantees. + * + *
<p>
This class is copied from Apache Flink + * (org.apache.flink.runtime.highavailability.nonha.standalone.StandaloneHaServices). + */ +public class StandaloneHaServices implements HaServices { + + protected final Object lock = new Object(); + + /** The fix address of the ShuffleManager. */ + private final String shuffleManagerAddress; + + private boolean shutdown; + + public StandaloneHaServices(String shuffleManagerAddress) { + this.shuffleManagerAddress = checkNotNull(shuffleManagerAddress); + } + + @Override + public LeaderRetrievalService createLeaderRetrievalService(LeaderReceptor receptor) { + synchronized (lock) { + checkNotShutdown(); + + return new StandaloneLeaderRetrievalService( + new LeaderInformation(DEFAULT_LEADER_ID, shuffleManagerAddress)); + } + } + + @Override + public LeaderElectionService createLeaderElectionService() { + synchronized (lock) { + checkNotShutdown(); + + return new StandaloneLeaderElectionService(); + } + } + + @GuardedBy("lock") + protected void checkNotShutdown() { + checkState(!shutdown, "high availability services are shut down"); + } + + @Override + public void close() { + synchronized (lock) { + if (!shutdown) { + shutdown = true; + } + } + } + + @Override + public void closeAndCleanupAllData() { + this.close(); + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/standalone/StandaloneLeaderElectionService.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/standalone/StandaloneLeaderElectionService.java new file mode 100644 index 00000000..36593a86 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/standalone/StandaloneLeaderElectionService.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability.standalone; + +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderContender; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionService; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; + +import javax.annotation.Nonnull; + +import java.util.UUID; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** + * A Standalone implementation of the {@link LeaderElectionService} interface. The standalone + * implementation assumes that there is only a single {@link LeaderContender} and thus directly + * grants him the leadership upon start up. Furthermore, there is no communication needed between + * multiple standalone leader election services. + * + *
<p>
This class is copied from Apache Flink + * (org.apache.flink.runtime.leaderelection.StandaloneLeaderElectionService). + */ +public class StandaloneLeaderElectionService implements LeaderElectionService { + + private LeaderContender contender = null; + + @Override + public void start(LeaderContender newContender) throws Exception { + if (contender != null) { + // Service was already started + throw new IllegalArgumentException( + "Leader election service cannot be started multiple times."); + } + + contender = checkNotNull(newContender); + + // directly grant leadership to the given contender + contender.grantLeadership(HaServices.DEFAULT_LEADER_ID); + } + + @Override + public void stop() { + if (contender != null) { + contender.revokeLeadership(); + contender = null; + } + } + + @Override + public void confirmLeadership(LeaderInformation leaderInfo) {} + + @Override + public boolean hasLeadership(@Nonnull UUID leaderSessionId) { + return (contender != null && HaServices.DEFAULT_LEADER_ID.equals(leaderSessionId)); + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/standalone/StandaloneLeaderRetrievalService.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/standalone/StandaloneLeaderRetrievalService.java new file mode 100644 index 00000000..f63172a6 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/standalone/StandaloneLeaderRetrievalService.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability.standalone; + +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalListener; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalService; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; + +/** + * A Standalone implementation of the {@link LeaderRetrievalService}. This implementation assumes + * that there is only a single contender for leadership (e.g., a single ShuffleManager process) and + * that this process is reachable under a constant address. + * + *
<p>
As soon as this service is started, it immediately notifies the leader listener of the + * pre-configured leader address. + * + *
<p>
This class is copied from Apache Flink + * (org.apache.flink.runtime.leaderretrieval.StandaloneLeaderRetrievalService). + */ +public class StandaloneLeaderRetrievalService implements LeaderRetrievalService { + + private final Object startStopLock = new Object(); + + /** Leader information including leader address and leader ID. */ + private final LeaderInformation leaderInfo; + + /** Flag whether this service is started. */ + private boolean started; + + /** Creates a StandaloneLeaderRetrievalService with the given leader address. */ + public StandaloneLeaderRetrievalService(LeaderInformation leaderInfo) { + this.leaderInfo = checkNotNull(leaderInfo); + } + + // ------------------------------------------------------------------------ + + @Override + public void start(LeaderRetrievalListener listener) { + checkArgument(listener != null, "Listener must not be null."); + + synchronized (startStopLock) { + checkState(!started, "StandaloneLeaderRetrievalService can only be started once."); + started = true; + + // directly notify the listener, because we already know the leading ShuffleManager's + // address + listener.notifyLeaderAddress(leaderInfo); + } + } + + @Override + public void stop() { + synchronized (startStopLock) { + started = false; + } + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/zookeeper/ZooKeeperHaServices.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/zookeeper/ZooKeeperHaServices.java new file mode 100644 index 00000000..b08b3b3d --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/zookeeper/ZooKeeperHaServices.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionService; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalService; + +import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFramework; +import org.apache.flink.shaded.curator4.org.apache.curator.utils.ZKPaths; +import org.apache.flink.shaded.zookeeper3.org.apache.zookeeper.KeeperException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** + * An implementation of the {@link HaServices} using Apache ZooKeeper. The services store data in + * ZooKeeper's nodes. 
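As a rough usage sketch (not part of this patch) of how the class defined just below is assembled: a Curator client is started via the ZooKeeperUtils helper that appears later in this diff, and each role picks the service it needs. How the Configuration is populated is assumed; it must at least carry the ZooKeeper quorum and the remote shuffle cluster ID.

```java
import com.alibaba.flink.shuffle.common.config.Configuration;
import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices;
import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionService;
import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalService;

import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFramework;

public class HaBootstrapSketch {
    static void bootstrapHa(Configuration conf) throws Exception {
        CuratorFramework client = ZooKeeperUtils.startCuratorFramework(conf);
        HaServices haServices = new ZooKeeperHaServices(conf, client);

        // ShuffleManager side: campaign for leadership of this cluster.
        LeaderElectionService election = haServices.createLeaderElectionService();

        // ShuffleWorker side: watch only this cluster's leader node; SHUFFLE_CLIENT
        // would instead select a compatible leader across all clusters it can see.
        LeaderRetrievalService retrieval =
                haServices.createLeaderRetrievalService(HaServices.LeaderReceptor.SHUFFLE_WORKER);
    }
}
```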
+ */ +public class ZooKeeperHaServices implements HaServices { + + private static final Logger LOG = LoggerFactory.getLogger(ZooKeeperHaServices.class); + + public static final String SHUFFLE_MANAGER_LEADER_LATCH_PATH = ".shuffle_manager_leader_lock"; + + public static final String SHUFFLE_MANAGER_LEADER_RETRIEVAL_PATH = + ".shuffle_manager_leader_info"; + + // ------------------------------------------------------------------------ + + /** The coordinator configuration. */ + protected final Configuration configuration; + + /** The ZooKeeper client to use. */ + private final CuratorFramework client; + + public ZooKeeperHaServices(Configuration configuration, CuratorFramework client) { + this.configuration = checkNotNull(configuration); + this.client = checkNotNull(client); + } + + @Override + public LeaderRetrievalService createLeaderRetrievalService(LeaderReceptor receptor) { + return ZooKeeperUtils.createLeaderRetrievalService( + client, configuration, SHUFFLE_MANAGER_LEADER_RETRIEVAL_PATH, receptor); + } + + @Override + public LeaderElectionService createLeaderElectionService() { + return ZooKeeperUtils.createLeaderElectionService( + client, + configuration, + SHUFFLE_MANAGER_LEADER_LATCH_PATH, + SHUFFLE_MANAGER_LEADER_RETRIEVAL_PATH); + } + + @Override + public void close() throws Exception { + LOG.info("Close {}.", getClass().getSimpleName()); + + try { + client.close(); + } catch (Throwable throwable) { + LOG.error("Could not properly close {}.", getClass().getSimpleName(), throwable); + ExceptionUtils.rethrowException(throwable); + } + } + + @Override + public void closeAndCleanupAllData() throws Exception { + LOG.info("Close and clean up all data for {}.", getClass().getSimpleName()); + + Throwable exception = null; + + try { + cleanupZooKeeperPaths(); + } catch (Throwable throwable) { + exception = throwable; + } + + try { + close(); + } catch (Throwable throwable) { + exception = exception == null ? throwable : exception; + } + + if (exception != null) { + LOG.error( + "Could not properly close and clean up all data for {}.", + getClass().getSimpleName(), + exception); + ExceptionUtils.rethrowException(exception); + } + + LOG.info("Finished cleaning up the high availability data."); + } + + // ------------------------------------------------------------------------ + // Utilities + // ------------------------------------------------------------------------ + + /** Cleans up leftover ZooKeeper paths. */ + private void cleanupZooKeeperPaths() throws Exception { + deleteOwnedZNode(); + tryDeleteEmptyParentZNodes(); + } + + private void deleteOwnedZNode() throws Exception { + // delete the HA_CLUSTER_ID znode which is owned by this cluster + + // Since we are using Curator version 2.12 there is a bug in deleting the children + // if there is a concurrent delete operation. Therefore we need to add this retry + // logic. See https://issues.apache.org/jira/browse/CURATOR-430 for more information. + // The retry logic can be removed once we upgrade to Curator version >= 4.0.1. + boolean zNodeDeleted = false; + while (!zNodeDeleted) { + try { + client.delete().deletingChildrenIfNeeded().forPath("/"); + zNodeDeleted = true; + } catch (KeeperException.NoNodeException ignored) { + // concurrent delete operation. Try again. + LOG.debug( + "Retrying to delete owned znode because of other concurrent delete operation."); + } + } + } + + /** + * Tries to delete empty parent znodes. 
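The method documented here walks upward one path segment at a time using the ZKPaths helpers defined at the bottom of this class. A worked trace, assuming a hypothetical namespace "cluster-1/ha":

```java
import org.apache.flink.shaded.curator4.org.apache.curator.utils.ZKPaths;

// The upward walk performed by tryDeleteEmptyParentZNodes():
String p0 = ZKPaths.makePath("cluster-1/ha", ""); // "/cluster-1/ha" (normalized)
String p1 = ZKPaths.getPathAndNode(p0).getPath(); // "/cluster-1"
String p2 = ZKPaths.getPathAndNode(p1).getPath(); // "/" -> isRootPath(p2), walk stops
```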
+ * + * @throws Exception if the deletion fails for other reason than {@link + * KeeperException.NotEmptyException} + */ + private void tryDeleteEmptyParentZNodes() throws Exception { + // try to delete the parent znodes if they are empty + String remainingPath = getParentPath(getNormalizedPath(client.getNamespace())); + final CuratorFramework nonNamespaceClient = client.usingNamespace(null); + + while (!isRootPath(remainingPath)) { + try { + nonNamespaceClient.delete().forPath(remainingPath); + } catch (KeeperException.NotEmptyException ignored) { + // We can only delete empty znodes + break; + } + + remainingPath = getParentPath(remainingPath); + } + } + + private static boolean isRootPath(String remainingPath) { + return ZKPaths.PATH_SEPARATOR.equals(remainingPath); + } + + private static String getNormalizedPath(String path) { + return ZKPaths.makePath(path, ""); + } + + private static String getParentPath(String path) { + return ZKPaths.getPathAndNode(path).getPath(); + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/zookeeper/ZooKeeperLeaderElectionDriver.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/zookeeper/ZooKeeperLeaderElectionDriver.java new file mode 100644 index 00000000..0831c93b --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/zookeeper/ZooKeeperLeaderElectionDriver.java @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper; + +import com.alibaba.flink.shuffle.common.handler.FatalErrorHandler; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionDriver; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionEventHandler; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; + +import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFramework; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.api.UnhandledErrorListener; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.recipes.cache.ChildData; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.recipes.cache.NodeCache; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.recipes.cache.NodeCacheListener; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.recipes.leader.LeaderLatch; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.recipes.leader.LeaderLatchListener; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.state.ConnectionState; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.state.ConnectionStateListener; +import org.apache.flink.shaded.zookeeper3.org.apache.zookeeper.CreateMode; +import org.apache.flink.shaded.zookeeper3.org.apache.zookeeper.KeeperException; +import org.apache.flink.shaded.zookeeper3.org.apache.zookeeper.data.Stat; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; + +/** + * {@link LeaderElectionDriver} implementation for Zookeeper. The leading ShuffleManager is elected + * using ZooKeeper. The current leader's address as well as its leader session ID is published via + * ZooKeeper. + */ +public class ZooKeeperLeaderElectionDriver + implements LeaderElectionDriver, + LeaderLatchListener, + NodeCacheListener, + UnhandledErrorListener { + + private static final Logger LOG = LoggerFactory.getLogger(ZooKeeperLeaderElectionDriver.class); + + /** Client to the ZooKeeper quorum. */ + private final CuratorFramework client; + + /** Curator recipe for leader election. */ + private final LeaderLatch leaderLatch; + + /** Curator recipe to watch a given ZooKeeper node for changes. */ + private final NodeCache nodeCache; + + /** ZooKeeper path of the node which stores the current leader information. 
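The driver combines two stock Curator recipes, declared in the fields above and just below: a LeaderLatch for the election itself and a NodeCache watching the published leader information. A standalone sketch of the same composition (paths hypothetical, client assumed already started):

```java
import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFramework;
import org.apache.flink.shaded.curator4.org.apache.curator.framework.recipes.cache.NodeCache;
import org.apache.flink.shaded.curator4.org.apache.curator.framework.recipes.leader.LeaderLatch;
import org.apache.flink.shaded.curator4.org.apache.curator.framework.recipes.leader.LeaderLatchListener;

public class ElectionSketch {
    static void run(CuratorFramework client) throws Exception {
        // Election: whichever participant acquires the latch becomes leader.
        LeaderLatch latch = new LeaderLatch(client, "/cluster-1.shuffle_manager_leader_lock");
        latch.addListener(
                new LeaderLatchListener() {
                    @Override
                    public void isLeader() {
                        // the driver forwards this to onGrantLeadership()
                    }

                    @Override
                    public void notLeader() {
                        // the driver forwards this to onRevokeLeadership()
                    }
                });
        latch.start();

        // Publication: watch the znode carrying the serialized LeaderInformation.
        NodeCache cache = new NodeCache(client, "/cluster-1.shuffle_manager_leader_info");
        cache.getListenable().addListener(() -> {
            // the driver forwards this to nodeChanged()
        });
        cache.start();
    }
}
```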
*/ + private final String leaderPath; + + private final ConnectionStateListener listener = + (client, newState) -> handleStateChange(newState); + + private final LeaderElectionEventHandler leaderElectionEventHandler; + + private final FatalErrorHandler fatalErrorHandler; + + private final String leaderContenderDescription; + + private volatile boolean running = true; + + public ZooKeeperLeaderElectionDriver( + CuratorFramework client, + String latchPath, + String leaderPath, + LeaderElectionEventHandler leaderElectionEventHandler, + FatalErrorHandler fatalErrorHandler, + String leaderContenderDescription) + throws Exception { + this.client = checkNotNull(client); + this.leaderPath = checkNotNull(leaderPath); + this.leaderElectionEventHandler = checkNotNull(leaderElectionEventHandler); + this.fatalErrorHandler = checkNotNull(fatalErrorHandler); + this.leaderContenderDescription = checkNotNull(leaderContenderDescription); + this.leaderLatch = new LeaderLatch(client, checkNotNull(latchPath)); + this.nodeCache = new NodeCache(client, leaderPath); + + client.getUnhandledErrorListenable().addListener(this); + leaderLatch.addListener(this); + leaderLatch.start(); + nodeCache.getListenable().addListener(this); + nodeCache.start(); + client.getConnectionStateListenable().addListener(listener); + } + + @Override + public void close() throws Exception { + if (!running) { + return; + } + running = false; + + LOG.info("Closing {}", this); + + client.getUnhandledErrorListenable().removeListener(this); + client.getConnectionStateListenable().removeListener(listener); + + Exception exception = null; + + try { + nodeCache.close(); + } catch (Exception e) { + exception = e; + } + + try { + leaderLatch.close(); + } catch (Exception e) { + exception = exception == null ? e : exception; + } + + if (exception != null) { + throw new Exception( + "Could not properly stop the ZooKeeperLeaderElectionDriver.", exception); + } + } + + @Override + public boolean hasLeadership() { + checkState(running, "Not in running state."); + return leaderLatch.hasLeadership(); + } + + @Override + public void isLeader() { + leaderElectionEventHandler.onGrantLeadership(); + } + + @Override + public void notLeader() { + leaderElectionEventHandler.onRevokeLeadership(); + } + + @Override + public void nodeChanged() throws Exception { + if (!leaderLatch.hasLeadership()) { + return; + } + + ChildData childData = nodeCache.getCurrentData(); + if (childData == null) { + leaderElectionEventHandler.onLeaderInformationChange(LeaderInformation.empty()); + return; + } + + byte[] data = childData.getData(); + if (data == null || data.length <= 0) { + leaderElectionEventHandler.onLeaderInformationChange(LeaderInformation.empty()); + return; + } + + LeaderInformation leaderInfo = LeaderInformation.fromByteArray(data); + leaderElectionEventHandler.onLeaderInformationChange(leaderInfo); + } + + /** Writes the current leader's address as well the given leader session ID to ZooKeeper. */ + @Override + public void writeLeaderInformation(LeaderInformation leaderInfo) { + checkState(running, "Not in running state."); + // this method does not have to be synchronized because the curator framework client + // is thread-safe. We do not write the empty data to ZooKeeper here. Because + // check-leadership-and-update is not a transactional operation. We may wrongly clear the + // data written by new leader. 
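+        // Concretely, the loop below (1) creates the ephemeral leader node if it is
+        // absent, (2) overwrites it if it is owned by this client's own ZooKeeper
+        // session, and (3) deletes it if another session owns it, retrying on
+        // NodeExistsException / NoNodeException races while leadership is still held.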
+ LOG.info("Writing leader information: {}.", leaderInfo); + if (leaderInfo.isEmpty()) { + return; + } + + try { + byte[] leaderInfoBytes = leaderInfo.toByteArray(); + while (leaderLatch.hasLeadership()) { + Stat stat = client.checkExists().forPath(leaderPath); + + if (stat == null) { + try { + client.create() + .creatingParentsIfNeeded() + .withMode(CreateMode.EPHEMERAL) + .forPath(leaderPath, leaderInfoBytes); + break; + } catch (KeeperException.NodeExistsException nodeExists) { + // node has been created in the meantime --> try again + } + continue; + } + + long owner = stat.getEphemeralOwner(); + long sessionID = client.getZookeeperClient().getZooKeeper().getSessionId(); + if (owner == sessionID) { + try { + client.setData().forPath(leaderPath, leaderInfoBytes); + break; + } catch (KeeperException.NoNodeException noNode) { + // node was deleted in the meantime + } + } else { + try { + client.delete().forPath(leaderPath); + } catch (KeeperException.NoNodeException noNode) { + // node was deleted in the meantime --> try again + } + } + } + LOG.info("Successfully wrote leader information: {}.", leaderInfo); + } catch (Throwable throwable) { + fatalErrorHandler.onFatalError( + new Exception( + "Could not write leader address and leader session ID to ZooKeeper.", + throwable)); + } + } + + private void handleStateChange(ConnectionState newState) { + switch (newState) { + case CONNECTED: + LOG.info("Connected to ZooKeeper quorum. Leader election can start."); + break; + case SUSPENDED: + LOG.warn( + "Connection to ZooKeeper suspended. The contender {} no longer participates" + + " in the leader election.", + leaderContenderDescription); + break; + case RECONNECTED: + LOG.info( + "Connection to ZooKeeper was reconnected. Leader election can be restarted."); + break; + case LOST: + // Maybe we have to throw an exception here to terminate the ShuffleManager + LOG.warn( + "Connection to ZooKeeper lost. The contender {} no longer participates in " + + "the leader election.", + leaderContenderDescription); + break; + } + } + + @Override + public void unhandledError(String message, Throwable throwable) { + fatalErrorHandler.onFatalError( + new Exception( + String.format( + "Unhandled error in ZooKeeperLeaderElectionDriver: %s.", message), + throwable)); + } + + @Override + public String toString() { + return "ZooKeeperLeaderElectionDriver{" + "leaderPath=" + leaderPath + "}"; + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/zookeeper/ZooKeeperLeaderElectionDriverFactory.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/zookeeper/ZooKeeperLeaderElectionDriverFactory.java new file mode 100644 index 00000000..93f9fdc2 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/zookeeper/ZooKeeperLeaderElectionDriverFactory.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper; + +import com.alibaba.flink.shuffle.common.handler.FatalErrorHandler; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionDriverFactory; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionEventHandler; + +import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFramework; + +/** + * {@link LeaderElectionDriverFactory} implementation for Zookeeper. + * + *
<p>
This class is copied from Apache Flink + * (org.apache.flink.runtime.leaderretrieval.ZooKeeperLeaderElectionDriverFactory). + */ +public class ZooKeeperLeaderElectionDriverFactory implements LeaderElectionDriverFactory { + + private final CuratorFramework client; + + private final String latchPath; + + private final String leaderPath; + + public ZooKeeperLeaderElectionDriverFactory( + CuratorFramework client, String latchPath, String leaderPath) { + this.client = client; + this.latchPath = latchPath; + this.leaderPath = leaderPath; + } + + @Override + public ZooKeeperLeaderElectionDriver createLeaderElectionDriver( + LeaderElectionEventHandler leaderEventHandler, + FatalErrorHandler fatalErrorHandler, + String leaderContenderDescription) + throws Exception { + return new ZooKeeperLeaderElectionDriver( + client, + latchPath, + leaderPath, + leaderEventHandler, + fatalErrorHandler, + leaderContenderDescription); + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/zookeeper/ZooKeeperLeaderRetrievalDriverFactory.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/zookeeper/ZooKeeperLeaderRetrievalDriverFactory.java new file mode 100644 index 00000000..c345828f --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/zookeeper/ZooKeeperLeaderRetrievalDriverFactory.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.common.handler.FatalErrorHandler; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalDriver; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalDriverFactory; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalEventHandler; + +import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFramework; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument; + +/** + * {@link LeaderRetrievalDriverFactory} implementation for Zookeeper. + * + *
<p>
This class is copied from Apache Flink + * (org.apache.flink.runtime.leaderretrieval.ZooKeeperLeaderRetrievalDriverFactory). + */ +public class ZooKeeperLeaderRetrievalDriverFactory implements LeaderRetrievalDriverFactory { + + private final CuratorFramework client; + + private final String clusterID; + + private final String retrievalPathSuffix; + + private final HaServices.LeaderReceptor receptor; + + public ZooKeeperLeaderRetrievalDriverFactory( + CuratorFramework client, + String clusterID, + String retrievalPathSuffix, + HaServices.LeaderReceptor receptor) { + checkArgument(client != null, "Must be not null."); + checkArgument(clusterID != null, "Must be not null."); + checkArgument(retrievalPathSuffix != null, "Must be not null."); + checkArgument(receptor != null, "Must be not null."); + + this.client = client; + this.clusterID = clusterID; + this.retrievalPathSuffix = retrievalPathSuffix; + this.receptor = receptor; + } + + @Override + public LeaderRetrievalDriver createLeaderRetrievalDriver( + LeaderRetrievalEventHandler leaderEventHandler, FatalErrorHandler fatalErrorHandler) + throws Exception { + switch (receptor) { + case SHUFFLE_CLIENT: + return new ZooKeeperMultiLeaderRetrievalDriver( + client, retrievalPathSuffix, leaderEventHandler, fatalErrorHandler); + case SHUFFLE_WORKER: + return new ZooKeeperSingleLeaderRetrievalDriver( + client, + clusterID + retrievalPathSuffix, + leaderEventHandler, + fatalErrorHandler); + default: + throw new ShuffleException("Unknown leader receptor type: " + receptor); + } + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/zookeeper/ZooKeeperMultiLeaderRetrievalDriver.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/zookeeper/ZooKeeperMultiLeaderRetrievalDriver.java new file mode 100644 index 00000000..aefc8369 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/zookeeper/ZooKeeperMultiLeaderRetrievalDriver.java @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper; + +import com.alibaba.flink.shuffle.common.handler.FatalErrorHandler; +import com.alibaba.flink.shuffle.common.utils.ProtocolUtils; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalDriver; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalEventHandler; + +import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFramework; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.api.UnhandledErrorListener; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.recipes.cache.ChildData; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.recipes.cache.PathChildrenCache; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.recipes.cache.PathChildrenCacheEvent; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.recipes.cache.PathChildrenCacheListener; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument; + +/** + * A {@link LeaderRetrievalDriver} implementation which can select a leader from multiple remote + * shuffle clusters. + */ +public class ZooKeeperMultiLeaderRetrievalDriver + implements LeaderRetrievalDriver, UnhandledErrorListener, PathChildrenCacheListener { + + private static final Logger LOG = + LoggerFactory.getLogger(ZooKeeperMultiLeaderRetrievalDriver.class); + + private final CuratorFramework client; + + private final String retrievalPathSuffix; + + private final LeaderRetrievalEventHandler leaderListener; + + private final PathChildrenCache pathChildrenCache; + + private final FatalErrorHandler fatalErrorHandler; + + private final AtomicReference currentLeaderPath = new AtomicReference<>(); + + private final AtomicBoolean running = new AtomicBoolean(true); + + public ZooKeeperMultiLeaderRetrievalDriver( + CuratorFramework client, + String retrievalPathSuffix, + LeaderRetrievalEventHandler leaderListener, + FatalErrorHandler fatalErrorHandler) + throws Exception { + checkArgument(client != null, "Must be not null."); + checkArgument(retrievalPathSuffix != null, "Must be not null."); + checkArgument(leaderListener != null, "Must be not null."); + checkArgument(fatalErrorHandler != null, "Must be not null."); + + this.client = client; + this.retrievalPathSuffix = retrievalPathSuffix; + this.leaderListener = leaderListener; + this.fatalErrorHandler = fatalErrorHandler; + + client.getUnhandledErrorListenable().addListener(this); + this.pathChildrenCache = new PathChildrenCache(client, "/", true); + pathChildrenCache.getListenable().addListener(this); + pathChildrenCache.start(PathChildrenCache.StartMode.BUILD_INITIAL_CACHE); + mayUpdateLeader(pathChildrenCache.getCurrentData()); + } + + private void mayUpdateLeader(List childDataList) { + if (childDataList == null || childDataList.isEmpty() || currentLeaderPath.get() != null) { + return; + } + + LeaderInformation selectedLeaderInfo = null; + String selectedLeaderPath = null; + try { + for (ChildData childData : childDataList) { + if (childData == null || !childData.getPath().endsWith(retrievalPathSuffix)) { + continue; + } + + LeaderInformation leaderInfo = 
deserializeLeaderInfo(childData); + if (leaderInfo != null + && (selectedLeaderInfo == null + || selectedLeaderInfo.getProtocolVersion() + < leaderInfo.getProtocolVersion())) { + selectedLeaderInfo = leaderInfo; + selectedLeaderPath = childData.getPath(); + } + } + + if (selectedLeaderInfo != null + && currentLeaderPath.compareAndSet(null, selectedLeaderPath)) { + notifyNewLeaderInfo(selectedLeaderInfo); + } + } catch (Throwable throwable) { + fatalErrorHandler.onFatalError( + new Exception("FATAL: Failed to retrieve the leader information.", throwable)); + } + } + + private LeaderInformation deserializeLeaderInfo(ChildData childData) { + try { + if (childData == null) { + return null; + } + + byte[] dataBytes = childData.getData(); + if (dataBytes == null || dataBytes.length == 0) { + return null; + } + + LeaderInformation leaderInfo = LeaderInformation.fromByteArray(dataBytes); + if (!ProtocolUtils.isServerProtocolCompatible( + leaderInfo.getProtocolVersion(), leaderInfo.getSupportedVersion())) { + LOG.info("Ignore incompatible leader {}.", leaderInfo); + return null; + } + return leaderInfo; + } catch (Throwable throwable) { + fatalErrorHandler.onFatalError( + new Exception("FATAL: Failed to deserialize leader information.", throwable)); + return null; + } + } + + private void notifyNewLeaderInfo(LeaderInformation leaderInfo) { + LOG.info("Notify new leader information: {} : {}.", leaderInfo, currentLeaderPath.get()); + leaderListener.notifyLeaderAddress(leaderInfo); + } + + @Override + public void childEvent( + CuratorFramework curatorFramework, PathChildrenCacheEvent pathChildrenCacheEvent) { + String leaderPath = currentLeaderPath.get(); + ChildData childData = pathChildrenCacheEvent.getData(); + LeaderInformation leaderInfo; + + switch (pathChildrenCacheEvent.getType()) { + case INITIALIZED: + LOG.info("Children cache initialized, begin to retrieve leader information."); + break; + case CHILD_ADDED: + LOG.info("New child node added: {}.", childData.getPath()); + mayUpdateLeader(Collections.singletonList(childData)); + break; + case CHILD_REMOVED: + LOG.info("New child node added: {}.", childData.getPath()); + if (leaderPath == null || !leaderPath.equals(childData.getPath())) { + return; + } + + if (currentLeaderPath.compareAndSet(leaderPath, null)) { + notifyNewLeaderInfo(LeaderInformation.empty()); + } + mayUpdateLeader(pathChildrenCache.getCurrentData()); + break; + case CHILD_UPDATED: + LOG.info("Child node data updated: {}.", childData.getPath()); + if (leaderPath == null || !leaderPath.equals(childData.getPath())) { + return; + } + + leaderInfo = deserializeLeaderInfo(childData); + notifyNewLeaderInfo(leaderInfo != null ? leaderInfo : LeaderInformation.empty()); + break; + case CONNECTION_SUSPENDED: + LOG.warn("Connection to ZooKeeper suspended. Can no longer retrieve the leader."); + leaderListener.notifyLeaderAddress(LeaderInformation.empty()); + break; + case CONNECTION_RECONNECTED: + LOG.info("Connection to ZooKeeper was reconnected. Restart leader retrieval."); + currentLeaderPath.compareAndSet(leaderPath, null); + mayUpdateLeader(pathChildrenCache.getCurrentData()); + break; + case CONNECTION_LOST: + LOG.warn("Connection to ZooKeeper lost. 
Can no longer retrieve the leader."); + leaderListener.notifyLeaderAddress(LeaderInformation.empty()); + break; + default: + // this should never happen + fatalErrorHandler.onFatalError( + new Exception( + "Unknown zookeeper event: " + pathChildrenCacheEvent.getType())); + } + } + + @Override + public void close() throws Exception { + if (running.compareAndSet(true, false)) { + LOG.info("Closing {}.", this); + + client.getUnhandledErrorListenable().removeListener(this); + + try { + pathChildrenCache.close(); + } catch (Throwable throwable) { + throw new Exception( + "Could not properly stop the ZooKeeperLeaderRetrievalDriver.", throwable); + } + } + } + + @Override + public void unhandledError(String message, Throwable throwable) { + fatalErrorHandler.onFatalError( + new Exception( + String.format( + "Unhandled error in ZooKeeperMultiLeaderRetrievalDriver: %s.", + message), + throwable)); + } + + @Override + public String toString() { + return "ZooKeeperMultiLeaderRetrievalDriver{" + + "retrievalPathSuffix=" + + retrievalPathSuffix + + "}"; + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/zookeeper/ZooKeeperSingleLeaderRetrievalDriver.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/zookeeper/ZooKeeperSingleLeaderRetrievalDriver.java new file mode 100644 index 00000000..694a8da9 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/zookeeper/ZooKeeperSingleLeaderRetrievalDriver.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
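For the multi-cluster driver that ends just above: when several shuffle clusters publish leader nodes under the same namespace, mayUpdateLeader() keeps the compatible candidate with the highest protocol version and pins its path atomically. A hypothetical trace:

```java
// Children seen by the PathChildrenCache, retrievalPathSuffix = ".shuffle_manager_leader_info":
//   /cluster-a.shuffle_manager_leader_info   (protocolVersion = 1)
//   /cluster-b.shuffle_manager_leader_info   (protocolVersion = 2)  <- selected
//   /cluster-b.shuffle_manager_leader_lock                          <- ignored, wrong suffix
//
// The winner is pinned via currentLeaderPath.compareAndSet(null, selectedLeaderPath),
// so later children cannot displace an already-selected leader until it is removed.
```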
+ */ + +package com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper; + +import com.alibaba.flink.shuffle.common.handler.FatalErrorHandler; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalDriver; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalEventHandler; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalService; + +import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFramework; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.api.UnhandledErrorListener; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.recipes.cache.ChildData; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.recipes.cache.NodeCache; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.recipes.cache.NodeCacheListener; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.state.ConnectionState; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.state.ConnectionStateListener; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.concurrent.atomic.AtomicBoolean; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument; + +/** + * The counterpart to the {@link ZooKeeperLeaderElectionDriver}. {@link LeaderRetrievalService} + * implementation for Zookeeper. It retrieves the current leader which has been elected by the + * {@link ZooKeeperLeaderElectionDriver}. The leader address as well as the current leader session + * ID is retrieved from ZooKeeper. + * + *
<p>
This class is copied from Apache Flink + * (org.apache.flink.runtime.leaderretrieval.ZooKeeperLeaderRetrievalDriver). + */ +public class ZooKeeperSingleLeaderRetrievalDriver + implements LeaderRetrievalDriver, NodeCacheListener, UnhandledErrorListener { + + private static final Logger LOG = + LoggerFactory.getLogger(ZooKeeperSingleLeaderRetrievalDriver.class); + + /** Connection to the used ZooKeeper quorum. */ + private final CuratorFramework client; + + /** Curator recipe to watch changes of a specific ZooKeeper node. */ + private final NodeCache nodeCache; + + private final String retrievalPath; + + private final ConnectionStateListener connectionStateListener = + (client, newState) -> handleStateChange(newState); + + private final LeaderRetrievalEventHandler leaderListener; + + private final FatalErrorHandler fatalErrorHandler; + + private final AtomicBoolean running = new AtomicBoolean(true); + + public ZooKeeperSingleLeaderRetrievalDriver( + CuratorFramework client, + String retrievalPath, + LeaderRetrievalEventHandler leaderListener, + FatalErrorHandler fatalErrorHandler) + throws Exception { + checkArgument(client != null, "CuratorFramework client must be not null."); + checkArgument(retrievalPath != null, "Retrieval path must be not null."); + checkArgument(leaderListener != null, "Event handler must be not null."); + checkArgument(fatalErrorHandler != null, "Fatal error handler must be not null."); + + this.client = client; + this.nodeCache = new NodeCache(client, retrievalPath); + this.retrievalPath = retrievalPath; + this.leaderListener = leaderListener; + this.fatalErrorHandler = fatalErrorHandler; + + client.getUnhandledErrorListenable().addListener(this); + nodeCache.getListenable().addListener(this); + nodeCache.start(); + client.getConnectionStateListenable().addListener(connectionStateListener); + } + + @Override + public void close() throws Exception { + if (running.compareAndSet(true, false)) { + LOG.info("Closing {}.", this); + + client.getUnhandledErrorListenable().removeListener(this); + client.getConnectionStateListenable().removeListener(connectionStateListener); + + try { + nodeCache.close(); + } catch (Throwable throwable) { + throw new Exception( + "Could not properly stop the ZooKeeperLeaderRetrievalDriver.", throwable); + } + } + } + + @Override + public void nodeChanged() { + LOG.info("Leader node has changed."); + retrieveLeaderInformationFromZooKeeper(); + } + + private void retrieveLeaderInformationFromZooKeeper() { + try { + ChildData childData = nodeCache.getCurrentData(); + if (childData == null) { + leaderListener.notifyLeaderAddress(LeaderInformation.empty()); + return; + } + + byte[] data = childData.getData(); + if (data == null || data.length <= 0) { + leaderListener.notifyLeaderAddress(LeaderInformation.empty()); + return; + } + + LeaderInformation leaderInfo = LeaderInformation.fromByteArray(data); + leaderListener.notifyLeaderAddress(leaderInfo); + } catch (Throwable throwable) { + fatalErrorHandler.onFatalError( + new Exception("Could not handle node changed event.", throwable)); + } + } + + private void handleStateChange(ConnectionState newState) { + switch (newState) { + case CONNECTED: + LOG.info("Connected to ZooKeeper quorum. Leader retrieval can start."); + break; + case SUSPENDED: + LOG.warn("Connection to ZooKeeper suspended. Can no longer retrieve the leader."); + leaderListener.notifyLeaderAddress(LeaderInformation.empty()); + break; + case RECONNECTED: + LOG.info("Connection to ZooKeeper was reconnected. 
Restart leader retrieval."); + retrieveLeaderInformationFromZooKeeper(); + break; + case LOST: + LOG.warn("Connection to ZooKeeper lost. Can no longer retrieve the leader."); + leaderListener.notifyLeaderAddress(LeaderInformation.empty()); + break; + } + } + + @Override + public void unhandledError(String message, Throwable throwable) { + fatalErrorHandler.onFatalError( + new Exception( + String.format( + "Unhandled error in ZooKeeperSingleLeaderRetrievalDriver: %s.", + message), + throwable)); + } + + @Override + public String toString() { + return "ZooKeeperSingleLeaderRetrievalDriver{" + "retrievalPath=" + retrievalPath + "}"; + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/zookeeper/ZooKeeperUtils.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/zookeeper/ZooKeeperUtils.java new file mode 100644 index 00000000..10a42793 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/highavailability/zookeeper/ZooKeeperUtils.java @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.exception.ConfigurationException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.coordinator.highavailability.DefaultLeaderElectionService; +import com.alibaba.flink.shuffle.coordinator.highavailability.DefaultLeaderRetrievalService; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionDriverFactory; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalDriverFactory; +import com.alibaba.flink.shuffle.core.config.ClusterOptions; +import com.alibaba.flink.shuffle.core.config.HighAvailabilityOptions; + +import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFramework; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFrameworkFactory; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.imps.DefaultACLProvider; +import org.apache.flink.shaded.curator4.org.apache.curator.retry.ExponentialBackoffRetry; + +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** Class containing helper functions to interact with ZooKeeper. 
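Both retrieval drivers above reduce a znode payload to a LeaderInformation through the same byte[] round trip that the election driver uses when publishing. A hedged sketch of that contract (address and port hypothetical):

```java
import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation;

import java.util.UUID;

public class RoundTripSketch {
    static LeaderInformation roundTrip() {
        // Serialize what the election driver writes to the leader znode...
        LeaderInformation original =
                new LeaderInformation(UUID.randomUUID(), "shuffle-manager-host:23123");
        byte[] payload = original.toByteArray();

        // ...and restore it the way nodeChanged()/childEvent() do on the watching side.
        return LeaderInformation.fromByteArray(payload);
    }
}
```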
*/ +public class ZooKeeperUtils { + + private static final Logger LOG = LoggerFactory.getLogger(ZooKeeperUtils.class); + + /** + * Starts a {@link CuratorFramework} instance and connects it to the given ZooKeeper quorum. + * + * @param configuration {@link Configuration} object containing the configuration values. + */ + public static CuratorFramework startCuratorFramework(Configuration configuration) { + checkNotNull(configuration); + String zkQuorum = configuration.getString(HighAvailabilityOptions.HA_ZOOKEEPER_QUORUM); + + if (zkQuorum == null || StringUtils.isBlank(zkQuorum)) { + throw new RuntimeException( + "No valid ZooKeeper quorum has been specified. " + + "You can specify the quorum via the configuration key '" + + HighAvailabilityOptions.HA_ZOOKEEPER_QUORUM.key() + + "'."); + } + + long sessionTimeout = + configuration + .getDuration(HighAvailabilityOptions.ZOOKEEPER_SESSION_TIMEOUT) + .toMillis(); + long connectionTimeout = + configuration + .getDuration(HighAvailabilityOptions.ZOOKEEPER_CONNECTION_TIMEOUT) + .toMillis(); + long retryWait = + configuration.getDuration(HighAvailabilityOptions.ZOOKEEPER_RETRY_WAIT).toMillis(); + int maxRetryAttempts = + configuration.getInteger(HighAvailabilityOptions.ZOOKEEPER_MAX_RETRY_ATTEMPTS); + + String root = configuration.getString(HighAvailabilityOptions.HA_ZOOKEEPER_ROOT); + LOG.info("Using '{}' as Zookeeper namespace.", root); + + CuratorFramework curatorFramework = + CuratorFrameworkFactory.builder() + .connectString(zkQuorum) + .sessionTimeoutMs(CommonUtils.checkedDownCast(sessionTimeout)) + .connectionTimeoutMs(CommonUtils.checkedDownCast(connectionTimeout)) + .retryPolicy( + new ExponentialBackoffRetry( + CommonUtils.checkedDownCast(retryWait), maxRetryAttempts)) + // Curator prepends a '/' manually and throws an Exception if the + // namespace starts with a '/'. + .namespace(root.startsWith("/") ? root.substring(1) : root) + .aclProvider(new DefaultACLProvider()) + .build(); + curatorFramework.start(); + return curatorFramework; + } + + /** + * Returns the configured ZooKeeper quorum (and removes whitespace, because ZooKeeper does not + * tolerate it). + */ + public static String getZooKeeperEnsemble(Configuration conf) { + String zkQuorum = conf.getString(HighAvailabilityOptions.HA_ZOOKEEPER_QUORUM); + if (zkQuorum == null || StringUtils.isBlank(zkQuorum)) { + throw new ConfigurationException("No ZooKeeper quorum specified in config."); + } + // Remove all whitespace + return zkQuorum.replaceAll("\\s+", ""); + } + + /** + * Creates a {@link DefaultLeaderRetrievalService} instance with {@link + * ZooKeeperSingleLeaderRetrievalDriver}. + * + * @param client The {@link CuratorFramework} ZooKeeper client to use + * @param configuration {@link Configuration} object containing the configuration values + * @param retrievalPathSuffix The path suffix of the leader retrieval node + * @param receptor Type of the leader information receptor + * @return {@link DefaultLeaderRetrievalService} instance. + */ + public static DefaultLeaderRetrievalService createLeaderRetrievalService( + final CuratorFramework client, + final Configuration configuration, + final String retrievalPathSuffix, + final HaServices.LeaderReceptor receptor) { + return new DefaultLeaderRetrievalService( + createLeaderRetrievalDriverFactory( + client, configuration, retrievalPathSuffix, receptor)); + } + + /** + * Creates a {@link LeaderRetrievalDriverFactory} implemented by ZooKeeper. 
+ * + * @param client The {@link CuratorFramework} ZooKeeper client to use + * @param configuration {@link Configuration} object containing the configuration values + * @param retrievalPathSuffix The path suffix of the leader retrieval node + * @param receptor Type of the leader information receptor + * @return {@link LeaderRetrievalDriverFactory} instance. + */ + public static ZooKeeperLeaderRetrievalDriverFactory createLeaderRetrievalDriverFactory( + final CuratorFramework client, + final Configuration configuration, + final String retrievalPathSuffix, + final HaServices.LeaderReceptor receptor) { + return new ZooKeeperLeaderRetrievalDriverFactory( + client, generateZookeeperClusterId(configuration), retrievalPathSuffix, receptor); + } + + /** + * Creates a {@link DefaultLeaderElectionService} instance with {@link + * ZooKeeperLeaderElectionDriver}. + * + * @param client The {@link CuratorFramework} ZooKeeper client to use + * @param configuration {@link Configuration} object containing the configuration values + * @param latchPathSuffix The path suffix of the leader latch node + * @param retrievalPathSuffix The path suffix of the leader retrieval node + * @return {@link DefaultLeaderElectionService} instance. + */ + public static DefaultLeaderElectionService createLeaderElectionService( + final CuratorFramework client, + final Configuration configuration, + final String latchPathSuffix, + final String retrievalPathSuffix) { + return new DefaultLeaderElectionService( + createLeaderElectionDriverFactory( + client, configuration, latchPathSuffix, retrievalPathSuffix)); + } + + /** + * Creates a {@link LeaderElectionDriverFactory} implemented by ZooKeeper. + * + * @param client The {@link CuratorFramework} ZooKeeper client to use + * @param configuration {@link Configuration} object containing the configuration values + * @param latchPathSuffix The path suffix of the leader latch node + * @param retrievalPathSuffix The path suffix of the leader retrieval node + * @return {@link LeaderElectionDriverFactory} instance. + */ + public static ZooKeeperLeaderElectionDriverFactory createLeaderElectionDriverFactory( + final CuratorFramework client, + final Configuration configuration, + final String latchPathSuffix, + final String retrievalPathSuffix) { + return new ZooKeeperLeaderElectionDriverFactory( + client, + getPath(configuration, latchPathSuffix), + getPath(configuration, retrievalPathSuffix)); + } + + private static String getPath(Configuration configuration, String pathSuffix) { + return generateZookeeperClusterId(configuration) + pathSuffix; + } + + private static String generateZookeeperClusterId(Configuration configuration) { + String clusterId = configuration.getString(ClusterOptions.REMOTE_SHUFFLE_CLUSTER_ID); + if (clusterId.trim().isEmpty()) { + throw new ConfigurationException( + String.format( + "Illegal config value for %s. Must be not empty.", + ClusterOptions.REMOTE_SHUFFLE_CLUSTER_ID.key())); + } + + if (!clusterId.startsWith("/")) { + clusterId = '/' + clusterId; + } + + if (clusterId.endsWith("/")) { + clusterId = clusterId.substring(0, clusterId.length() - 1); + } + + return clusterId; + } + + /** Private constructor to prevent instantiation. 
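The private generateZookeeperClusterId method above normalizes the configured cluster ID into a znode-path prefix: exactly one leading '/' and no trailing '/'. A standalone restatement of just that rule, with two worked inputs:

```java
// Restates only the normalization rule applied by generateZookeeperClusterId(...);
// the configuration lookup and the empty-value check are omitted here.
public class ClusterIdNormalizationSketch {

    static String normalize(String clusterId) {
        if (!clusterId.startsWith("/")) {
            clusterId = '/' + clusterId; // ensure a leading slash
        }
        if (clusterId.endsWith("/")) {
            clusterId = clusterId.substring(0, clusterId.length() - 1); // strip a trailing slash
        }
        return clusterId;
    }

    public static void main(String[] args) {
        System.out.println(normalize("my-cluster"));   // -> /my-cluster
        System.out.println(normalize("/my-cluster/")); // -> /my-cluster
    }
}
```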
*/ + private ZooKeeperUtils() {} +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/DataPartitionCoordinate.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/DataPartitionCoordinate.java new file mode 100644 index 00000000..0c8fac49 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/DataPartitionCoordinate.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.manager; + +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; + +import java.io.Serializable; +import java.util.Objects; + +/** + * The coordinate of a data partition. A data partition could be fully identified by the DataSetID + * and the DataPartitionID. + */ +public class DataPartitionCoordinate implements Serializable { + + private static final long serialVersionUID = 6488556324861837102L; + + private final DataSetID dataSetId; + + private final DataPartitionID dataPartitionId; + + public DataPartitionCoordinate(DataSetID dataSetId, DataPartitionID dataPartitionId) { + this.dataSetId = dataSetId; + this.dataPartitionId = dataPartitionId; + } + + public DataSetID getDataSetId() { + return dataSetId; + } + + public DataPartitionID getDataPartitionId() { + return dataPartitionId; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DataPartitionCoordinate that = (DataPartitionCoordinate) o; + return Objects.equals(dataSetId, that.dataSetId) + && Objects.equals(dataPartitionId, that.dataPartitionId); + } + + @Override + public int hashCode() { + return Objects.hash(dataSetId, dataPartitionId); + } + + @Override + public String toString() { + return "DataPartitionCoordinate{" + + "dataSetId=" + + dataSetId + + ", dataPartitionId=" + + dataPartitionId + + '}'; + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/DataPartitionStatus.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/DataPartitionStatus.java new file mode 100644 index 00000000..f019e1e0 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/DataPartitionStatus.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
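Because DataPartitionCoordinate implements value-based equals()/hashCode() over both IDs, it can serve directly as a map key, which is how a tracker can resolve a data partition to the worker holding it. A small sketch against the types from this diff; the diff does not show how the ID objects are constructed, so the sketch deliberately takes them as parameters instead of creating them.

```java
import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionCoordinate;
import com.alibaba.flink.shuffle.coordinator.manager.ShuffleWorkerDescriptor;
import com.alibaba.flink.shuffle.core.ids.DataPartitionID;
import com.alibaba.flink.shuffle.core.ids.DataSetID;

import java.util.HashMap;
import java.util.Map;

public class CoordinateLookupSketch {

    // Keyed by the composite coordinate (DataSetID + DataPartitionID).
    private final Map<DataPartitionCoordinate, ShuffleWorkerDescriptor> locations =
            new HashMap<>();

    public void remember(
            DataSetID dataSet, DataPartitionID partition, ShuffleWorkerDescriptor worker) {
        locations.put(new DataPartitionCoordinate(dataSet, partition), worker);
    }

    public ShuffleWorkerDescriptor lookup(DataSetID dataSet, DataPartitionID partition) {
        // A freshly built coordinate with equal IDs compares equal to the stored key,
        // so the original entry is found.
        return locations.get(new DataPartitionCoordinate(dataSet, partition));
    }
}
```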
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.manager; + +import com.alibaba.flink.shuffle.core.ids.JobID; + +import java.io.Serializable; +import java.util.Objects; + +/** The DataPartition's state. */ +public class DataPartitionStatus implements Serializable, Cloneable { + + private static final long serialVersionUID = 1912308953069806999L; + + private final JobID jobId; + + private final DataPartitionCoordinate coordinate; + + private boolean isReleasing; + + public DataPartitionStatus(JobID jobId, DataPartitionCoordinate coordinate) { + this(jobId, coordinate, false); + } + + public DataPartitionStatus( + JobID jobId, DataPartitionCoordinate coordinate, boolean isReleasing) { + this.jobId = jobId; + this.coordinate = coordinate; + this.isReleasing = isReleasing; + } + + public JobID getJobId() { + return jobId; + } + + public DataPartitionCoordinate getCoordinate() { + return coordinate; + } + + public boolean isReleasing() { + return isReleasing; + } + + public void setReleasing(boolean releasing) { + isReleasing = releasing; + } + + @Override + public DataPartitionStatus clone() { + return new DataPartitionStatus(jobId, coordinate, isReleasing); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + DataPartitionStatus that = (DataPartitionStatus) o; + return isReleasing == that.isReleasing + && Objects.equals(jobId, that.jobId) + && Objects.equals(coordinate, that.coordinate); + } + + @Override + public int hashCode() { + return Objects.hash(jobId, coordinate, isReleasing); + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/DefaultShuffleResource.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/DefaultShuffleResource.java new file mode 100644 index 00000000..da506f3b --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/DefaultShuffleResource.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.coordinator.manager; + +import com.alibaba.flink.shuffle.core.storage.DataPartition; + +import org.apache.commons.lang3.StringUtils; + +import java.util.Objects; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; + +/** Shuffle Resource representation for the data partition. */ +public class DefaultShuffleResource implements ShuffleResource { + + private static final long serialVersionUID = -8562771913795553025L; + + /** The addresses of the allocated shuffle resource for the data partition. */ + private final ShuffleWorkerDescriptor[] shuffleWorkerDescriptors; + + /** The type of the data partition. */ + private final DataPartition.DataPartitionType dataPartitionType; + + public DefaultShuffleResource( + ShuffleWorkerDescriptor[] shuffleWorkerDescriptors, + DataPartition.DataPartitionType dataPartitionType) { + checkArgument(shuffleWorkerDescriptors.length > 0, "Must be positive."); + checkArgument( + dataPartitionType == DataPartition.DataPartitionType.REDUCE_PARTITION + || shuffleWorkerDescriptors.length == 1, + "Illegal number of shuffle worker descriptors."); + + this.shuffleWorkerDescriptors = shuffleWorkerDescriptors; + this.dataPartitionType = dataPartitionType; + } + + @Override + public ShuffleWorkerDescriptor[] getReducePartitionLocations() { + checkState(dataPartitionType.equals(DataPartition.DataPartitionType.REDUCE_PARTITION)); + return shuffleWorkerDescriptors; + } + + @Override + public ShuffleWorkerDescriptor getMapPartitionLocation() { + checkState(dataPartitionType.equals(DataPartition.DataPartitionType.MAP_PARTITION)); + return shuffleWorkerDescriptors[0]; + } + + public DataPartition.DataPartitionType getDataPartitionType() { + return dataPartitionType; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DefaultShuffleResource that = (DefaultShuffleResource) o; + if (shuffleWorkerDescriptors.length != that.shuffleWorkerDescriptors.length) { + return false; + } + + if (!dataPartitionType.equals(that.dataPartitionType)) { + return false; + } + + for (int i = 0; i < shuffleWorkerDescriptors.length; i++) { + if (!Objects.equals(shuffleWorkerDescriptors[i], that.shuffleWorkerDescriptors[i])) { + return false; + } + } + return true; + } + + @Override + public int hashCode() { + int result = + StringUtils.isBlank(dataPartitionType.toString()) + ? 
0
+                        : dataPartitionType.hashCode();
+        for (ShuffleWorkerDescriptor shuffleWorkerDescriptor : shuffleWorkerDescriptors) {
+            result = result * 31 + Objects.hash(shuffleWorkerDescriptor);
+        }
+        return result;
+    }
+
+    @Override
+    public String toString() {
+        StringBuilder sb = new StringBuilder("{");
+        for (int i = 0; i < shuffleWorkerDescriptors.length; i++) {
+            sb.append(shuffleWorkerDescriptors[i].toString());
+            if (i < shuffleWorkerDescriptors.length - 1) {
+                sb.append(",");
+            }
+        }
+
+        if (!StringUtils.isBlank(dataPartitionType.toString())) {
+            sb.append(",").append("dataPartitionType=").append(dataPartitionType);
+        }
+        sb.append("}");
+
+        return sb.toString();
+    }
+}
diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/JobDataPartitionDistribution.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/JobDataPartitionDistribution.java
new file mode 100644
index 00000000..d2634e53
--- /dev/null
+++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/JobDataPartitionDistribution.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.manager;
+
+import java.io.Serializable;
+import java.util.Map;
+
+/** The distribution of the data partitions of a job across the registered shuffle workers. */
+public class JobDataPartitionDistribution implements Serializable {
+
+    private static final long serialVersionUID = -6643141147307585686L;
+
+    private final Map<DataPartitionCoordinate, ShuffleWorkerRegistration>
+            dataPartitionDistribution;
+
+    public JobDataPartitionDistribution(
+            Map<DataPartitionCoordinate, ShuffleWorkerRegistration> dataPartitionDistribution) {
+        this.dataPartitionDistribution = dataPartitionDistribution;
+    }
+
+    public Map<DataPartitionCoordinate, ShuffleWorkerRegistration> getDataPartitionDistribution() {
+        return dataPartitionDistribution;
+    }
+}
diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ManagerToJobHeartbeatPayload.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ManagerToJobHeartbeatPayload.java
new file mode 100644
index 00000000..6996e4bf
--- /dev/null
+++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ManagerToJobHeartbeatPayload.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.manager;
+
+import com.alibaba.flink.shuffle.coordinator.manager.assignmenttracker.ChangedWorkerStatus;
+import com.alibaba.flink.shuffle.core.ids.InstanceID;
+
+import java.io.Serializable;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull;
+
+/** The payload for the heartbeat between the shuffle manager and a job. */
+public class ManagerToJobHeartbeatPayload implements Serializable {
+
+    private static final long serialVersionUID = 7306522388105576112L;
+
+    /** The instance ID of the shuffle manager. */
+    private final InstanceID managerID;
+
+    /** The workers that have changed for this job. */
+    private final ChangedWorkerStatus changedWorkerStatus;
+
+    public ManagerToJobHeartbeatPayload(
+            InstanceID managerID, ChangedWorkerStatus changedWorkerStatus) {
+        this.managerID = checkNotNull(managerID);
+        this.changedWorkerStatus = checkNotNull(changedWorkerStatus);
+    }
+
+    public InstanceID getManagerID() {
+        return managerID;
+    }
+
+    public ChangedWorkerStatus getJobChangedWorkerStatus() {
+        return changedWorkerStatus;
+    }
+}
diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/RegistrationSuccess.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/RegistrationSuccess.java
new file mode 100644
index 00000000..d2828ae9
--- /dev/null
+++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/RegistrationSuccess.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.manager;
+
+import com.alibaba.flink.shuffle.coordinator.registration.RegistrationResponse;
+import com.alibaba.flink.shuffle.core.ids.InstanceID;
+
+/** The registration response containing the instance ID of the registration target. */
+public class RegistrationSuccess extends RegistrationResponse.Success {
+
+    private static final long serialVersionUID = -4143369389748407647L;
+
+    /** The instance ID of the registration target.
*/ + private final InstanceID instanceID; + + public RegistrationSuccess(InstanceID instanceID) { + this.instanceID = instanceID; + } + + public InstanceID getInstanceID() { + return instanceID; + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleManager.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleManager.java new file mode 100644 index 00000000..77fa6654 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleManager.java @@ -0,0 +1,885 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.manager; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.common.handler.FatalErrorHandler; +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; +import com.alibaba.flink.shuffle.common.utils.FutureUtils; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatListener; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatManager; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatServices; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatTarget; +import com.alibaba.flink.shuffle.coordinator.heartbeat.NoOpHeartbeatManager; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderContender; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionService; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.manager.assignmenttracker.AssignmentTracker; +import com.alibaba.flink.shuffle.coordinator.manager.assignmenttracker.ChangedWorkerStatus; +import com.alibaba.flink.shuffle.coordinator.registration.RegistrationResponse; +import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerGateway; +import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerMetrics; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.ids.RegistrationID; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleFencedRpcEndpoint; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcService; +import com.alibaba.flink.shuffle.rpc.message.Acknowledge; +import com.alibaba.flink.shuffle.rpc.utils.AkkaRpcServiceUtils; + +import org.apache.commons.lang3.tuple.Pair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + 
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.UUID;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executor;
+import java.util.concurrent.TimeoutException;
+import java.util.stream.Collectors;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull;
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState;
+
+/**
+ * The ShuffleManager is responsible for: 1. Allocating shuffle workers to a FLINK job for
+ * writing/reading shuffle data; 2. Managing all the ShuffleWorkers.
+ */
+public class ShuffleManager extends RemoteShuffleFencedRpcEndpoint<UUID>
+        implements ShuffleManagerGateway, LeaderContender {
+
+    private static final Logger LOG = LoggerFactory.getLogger(ShuffleManager.class);
+
+    public static final String SHUFFLE_MANAGER_NAME = "shufflemanager";
+
+    /** Unique ID of the shuffle manager. */
+    private final InstanceID managerID;
+
+    /** The HA services. */
+    private final HaServices haServices;
+
+    /** The service used to elect a shuffle manager leader. */
+    private final LeaderElectionService leaderElectionService;
+
+    /** Fatal error handler. */
+    private final FatalErrorHandler fatalErrorHandler;
+
+    /** The executor responsible for time-consuming operations. */
+    private final Executor ioExecutor;
+
+    /** The heartbeat service between the shuffle manager and jobs. */
+    private final HeartbeatServices jobHeartbeatServices;
+
+    /** The heartbeat service between the shuffle manager and shuffle workers. */
+    private final HeartbeatServices workerHeartbeatServices;
+
+    /**
+     * The period for the shuffle manager to wait for workers to register before it starts to
+     * notify clients to remove partitions.
+     */
+    private final long workerStatusSyncPeriodMillis;
+
+    /**
+     * The heartbeat manager responsible for timing the heartbeat between the shuffle manager and
+     * jobs.
+     */
+    private HeartbeatManager<Void, Void> jobHeartbeatManager;
+
+    /**
+     * The heartbeat manager responsible for timing the heartbeat between the shuffle manager and
+     * shuffle workers.
+     */
+    private HeartbeatManager<WorkerToManagerHeartbeatPayload, Void> workerHeartbeatManager;
+
+    /** The client instance ID for each job. */
+    private final Map<JobID, InstanceID> registeredClients;
+
+    /** The component that tracks the data partition assignments. */
+    private final AssignmentTracker assignmentTracker;
+
+    /** All currently registered shuffle workers. */
+    private final Map<InstanceID, ShuffleWorkerRegistrationInstance> shuffleWorkers;
+
+    /** All registering shuffle workers.
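The two HeartbeatManager fields above implement a monitor/renew/timeout cycle: a monitored target counts as dead once no heartbeat arrives within the configured timeout, and workerStatusSyncPeriodMillis is set to twice the worker heartbeat timeout in the constructor that follows. The real heartbeat classes live in the heartbeat package and are not shown in this section, so the following is only a generic re-enactment of the cancel-and-re-arm timer pattern, using JDK types only.

```java
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;

public class HeartbeatTimeoutSketch {

    public static void main(String[] args) throws InterruptedException {
        ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
        long timeoutMillis = 200;
        Runnable onTimeout = () -> System.out.println("heartbeat timed out -> unmonitor target");

        // monitorTarget: arm a timer that fires unless a heartbeat resets it first.
        ScheduledFuture<?> timeout =
                scheduler.schedule(onTimeout, timeoutMillis, TimeUnit.MILLISECONDS);

        // receiveHeartbeat: cancel the pending timer and re-arm it.
        timeout.cancel(false);
        timeout = scheduler.schedule(onTimeout, timeoutMillis, TimeUnit.MILLISECONDS);

        Thread.sleep(300); // no further heartbeats arrive, so the timeout fires once
        scheduler.shutdown();
    }
}
```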
*/ + private final Map> + shuffleWorkerGatewayFutures; + + private long targetWorkStatusSyncTimeMillis = 0; + + private boolean firstLeaderShipGrant = true; + + public ShuffleManager( + RemoteShuffleRpcService rpcService, + InstanceID managerID, + HaServices haServices, + FatalErrorHandler fatalErrorHandler, + Executor ioExecutor, + HeartbeatServices jobHeartbeatServices, + HeartbeatServices workerHeartbeatServices, + AssignmentTracker assignmentTracker) { + super(rpcService, AkkaRpcServiceUtils.createRandomName(SHUFFLE_MANAGER_NAME), null); + + this.managerID = managerID; + this.haServices = checkNotNull(haServices); + this.leaderElectionService = haServices.createLeaderElectionService(); + this.fatalErrorHandler = fatalErrorHandler; + this.ioExecutor = ioExecutor; + + this.jobHeartbeatServices = jobHeartbeatServices; + this.workerHeartbeatServices = workerHeartbeatServices; + this.jobHeartbeatManager = NoOpHeartbeatManager.getInstance(); + this.workerHeartbeatManager = NoOpHeartbeatManager.getInstance(); + + this.registeredClients = new HashMap<>(); + this.assignmentTracker = assignmentTracker; + + this.workerStatusSyncPeriodMillis = 2 * workerHeartbeatServices.getHeartbeatTimeout(); + + this.shuffleWorkers = new HashMap<>(); + this.shuffleWorkerGatewayFutures = new HashMap<>(); + } + + @Override + public CompletableFuture registerWorker( + ShuffleWorkerRegistration workerRegistration) { + CompletableFuture shuffleWorkerGatewayFuture = + getRpcService() + .connectTo(workerRegistration.getRpcAddress(), ShuffleWorkerGateway.class); + shuffleWorkerGatewayFutures.put( + workerRegistration.getWorkerID(), shuffleWorkerGatewayFuture); + + LOG.info("Shuffle worker {} is registering", workerRegistration); + + return shuffleWorkerGatewayFuture.handleAsync( + (ShuffleWorkerGateway shuffleWorkerGateway, Throwable throwable) -> { + final InstanceID workerID = workerRegistration.getWorkerID(); + if (shuffleWorkerGatewayFuture == shuffleWorkerGatewayFutures.get(workerID)) { + shuffleWorkerGatewayFutures.remove(workerID); + if (throwable != null) { + return new RegistrationResponse.Decline(throwable.getMessage()); + } else { + return registerShuffleWorkerInternal( + shuffleWorkerGateway, workerRegistration); + } + } else { + LOG.debug( + "Ignoring outdated ShuffleWorkerGateway connection for {}.", + workerID); + return new RegistrationResponse.Decline( + "Decline outdated shuffle worker registration."); + } + }, + getRpcMainThreadScheduledExecutor()); + } + + @Override + public CompletableFuture workerReportDataPartitionReleased( + InstanceID workerID, + RegistrationID registrationID, + JobID jobID, + DataSetID dataSetID, + DataPartitionID dataPartitionID) { + if (!assignmentTracker.isWorkerRegistered(registrationID)) { + LOG.warn("Received report from unmanaged Shuffle Worker {}", workerID); + return CompletableFuture.completedFuture(Acknowledge.get()); + } + + assignmentTracker.workerReportDataPartitionReleased( + registrationID, jobID, dataSetID, dataPartitionID); + + return CompletableFuture.completedFuture(Acknowledge.get()); + } + + @Override + public CompletableFuture reportDataPartitionStatus( + InstanceID workerID, + RegistrationID registrationID, + List dataPartitionStatuses) { + if (!assignmentTracker.isWorkerRegistered(registrationID)) { + LOG.warn("Received report from unmanaged Shuffle Worker {}", workerID); + return CompletableFuture.completedFuture(Acknowledge.get()); + } + + dataPartitionStatuses.forEach( + dataStatus -> { + JobID jobId = dataStatus.getJobId(); + + if 
(!assignmentTracker.isJobRegistered(jobId)) { + internalRegisterClient(jobId); + } + }); + + assignmentTracker.synchronizeWorkerDataPartitions(registrationID, dataPartitionStatuses); + + return CompletableFuture.completedFuture(Acknowledge.get()); + } + + @Override + public void heartbeatFromWorker(InstanceID workerID, WorkerToManagerHeartbeatPayload payload) { + workerHeartbeatManager.receiveHeartbeat(workerID, payload); + } + + @Override + public CompletableFuture disconnectWorker(InstanceID workerID, Exception cause) { + closeShuffleWorkerConnection(workerID, cause); + return CompletableFuture.completedFuture(Acknowledge.get()); + } + + @Override + public CompletableFuture registerClient( + JobID jobID, InstanceID clientID) { + + InstanceID oldInstanceId = registeredClients.get(jobID); + if (oldInstanceId != null && !oldInstanceId.equals(clientID)) { + LOG.warn( + "The job {} with instance id {} is replaced with another instance {}", + jobID, + oldInstanceId, + clientID); + } + registeredClients.put(jobID, clientID); + + internalRegisterClient(jobID); + return CompletableFuture.completedFuture(new RegistrationSuccess(managerID)); + } + + private void internalRegisterClient(JobID jobId) { + if (!assignmentTracker.isJobRegistered(jobId)) { + LOG.info("Registering session {}", jobId); + + assignmentTracker.registerJob(jobId); + + jobHeartbeatManager.monitorTarget( + instanceIDFromJobID(jobId), + new HeartbeatTarget() { + @Override + public void receiveHeartbeat( + InstanceID heartbeatOrigin, Void heartbeatPayload) { + // We do not need to notify session side + } + + @Override + public void requestHeartbeat( + InstanceID requestOrigin, Void heartbeatPayload) { + // should not do it + } + }); + } + } + + @Override + public CompletableFuture unregisterClient(JobID jobID, InstanceID clientID) { + try { + checkInstanceIdConsistent(jobID, clientID, "Unregister client"); + } catch (Exception e) { + return FutureUtils.completedExceptionally(e); + } + + LOG.info("Client {} unregister actively, will do nothing now", jobID); + registeredClients.remove(jobID); + + // Do nothing here since the job might recover. + // Let's trust the timeout to ensure the final cleanup. 
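registerWorker(...) earlier in this class guards against racing asynchronous connects: a gateway future is honored only if it is still the one stored in shuffleWorkerGatewayFutures for that worker, so a newer registration attempt silently invalidates any older one still in flight. The guard reduced to plain CompletableFuture code; class and method names here are illustrative, not from this diff.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

public class LatestRequestWinsSketch {

    // Maps a worker ID to its most recent pending connect attempt.
    private final Map<String, CompletableFuture<String>> pending = new HashMap<>();

    CompletableFuture<String> startConnect(String workerId) {
        CompletableFuture<String> attempt = new CompletableFuture<>();
        pending.put(workerId, attempt); // a newer attempt overwrites the older one
        return attempt;
    }

    String onConnected(String workerId, CompletableFuture<String> attempt, String gateway) {
        if (attempt == pending.get(workerId)) {
            pending.remove(workerId);
            return "registered via " + gateway;
        }
        return "declined: outdated attempt"; // an older connect lost the race
    }

    public static void main(String[] args) {
        LatestRequestWinsSketch sketch = new LatestRequestWinsSketch();
        CompletableFuture<String> first = sketch.startConnect("worker-1");
        CompletableFuture<String> second = sketch.startConnect("worker-1");
        System.out.println(sketch.onConnected("worker-1", first, "gateway-A"));  // declined
        System.out.println(sketch.onConnected("worker-1", second, "gateway-B")); // registered
    }
}
```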
+ + return CompletableFuture.completedFuture(Acknowledge.get()); + } + + private void unregisterClientInternal(JobID jobId) { + if (assignmentTracker.isJobRegistered(jobId)) { + LOG.info("Unregister session {}", jobId); + assignmentTracker.unregisterJob(jobId); + jobHeartbeatManager.unmonitorTarget(instanceIDFromJobID(jobId)); + } + } + + @Override + public CompletableFuture heartbeatFromClient( + JobID jobID, InstanceID clientID, Set cachedWorkerList) { + try { + checkInstanceIdConsistent(jobID, clientID, "heartbeatFromClient"); + } catch (Exception e) { + return FutureUtils.completedExceptionally(e); + } + + jobHeartbeatManager.receiveHeartbeat(instanceIDFromJobID(jobID), null); + + ChangedWorkerStatus changedWorkerStatus = + assignmentTracker.computeChangedWorkers( + jobID, + cachedWorkerList, + System.nanoTime() / 1000000 >= targetWorkStatusSyncTimeMillis); + return CompletableFuture.completedFuture( + new ManagerToJobHeartbeatPayload(this.managerID, changedWorkerStatus)); + } + + @Override + public CompletableFuture requestShuffleResource( + JobID jobID, + InstanceID clientID, + DataSetID dataSetID, + MapPartitionID mapPartitionID, + int numberOfConsumers, + String dataPartitionFactoryName) { + + try { + checkInstanceIdConsistent(jobID, clientID, "Allocate shuffle resource"); + } catch (Exception e) { + return FutureUtils.completedExceptionally(e); + } + + LOG.info( + "Request resource for session {}, dataset {}, producer {} and total consumer {}", + jobID, + dataSetID, + mapPartitionID, + numberOfConsumers); + + try { + ShuffleResource allocatedShuffleResource = + assignmentTracker.requestShuffleResource( + jobID, + dataSetID, + mapPartitionID, + numberOfConsumers, + dataPartitionFactoryName); + return CompletableFuture.completedFuture(allocatedShuffleResource); + } catch (Exception e) { + LOG.error("Request new Shuffle Resource failed.", e); + return FutureUtils.completedExceptionally(e); + } + } + + @Override + public CompletableFuture releaseShuffleResource( + JobID jobID, InstanceID clientID, DataSetID dataSetID, MapPartitionID mapPartitionID) { + + try { + checkInstanceIdConsistent(jobID, clientID, "Release shuffle resource"); + } catch (Exception e) { + return FutureUtils.completedExceptionally(e); + } + + LOG.info( + "Remove resource for session {}, dataset {}, producer {}", + jobID, + dataSetID, + mapPartitionID); + assignmentTracker.releaseShuffleResource(jobID, dataSetID, mapPartitionID); + return CompletableFuture.completedFuture(Acknowledge.get()); + } + + @Override + public CompletableFuture getNumberOfRegisteredWorkers() { + return CompletableFuture.supplyAsync( + shuffleWorkers::size, getRpcMainThreadScheduledExecutor()); + } + + @Override + public CompletableFuture> getShuffleWorkerMetrics() { + + List>> workerMetrics = + shuffleWorkers.values().stream() + .map( + worker -> + worker.shuffleWorkerGateway + .getWorkerMetrics() + .thenApply( + metric -> + Pair.of( + worker.shuffleWorkerId, + metric))) + .collect(Collectors.toList()); + + return FutureUtils.combineAll(workerMetrics) + .thenApply( + pairs -> + pairs.stream() + .collect(Collectors.toMap(Pair::getKey, Pair::getValue))); + } + + @Override + public CompletableFuture> listJobs() { + return CompletableFuture.completedFuture(assignmentTracker.listJobs()); + } + + @Override + public CompletableFuture getJobDataPartitionDistribution( + JobID jobID) { + Map jobDataPartitionDistribution = + assignmentTracker.getDataPartitionDistribution(jobID).entrySet().stream() + .collect( + Collectors.toMap( + Map.Entry::getKey, + 
e -> shuffleWorkers.get(e.getValue()).getRegistration())); + return CompletableFuture.completedFuture( + new JobDataPartitionDistribution(jobDataPartitionDistribution)); + } + + private void checkInstanceIdConsistent( + JobID jobID, InstanceID requestId, String loggedOperation) { + InstanceID oldInstanceId = registeredClients.get(jobID); + checkState( + Objects.equals(oldInstanceId, requestId), + String.format( + "%s requests with inconsistent instance id for job %s, current id is %s and requested is %s", + loggedOperation, jobID, oldInstanceId, requestId)); + } + + // ------------------------------------------------------------------------ + // RPC lifecycle methods + // ------------------------------------------------------------------------ + + @Override + protected void onStart() throws Exception { + try { + startShuffleManagerServices(); + } catch (Throwable t) { + final Exception exception = + new Exception(String.format("Could not start %s", getAddress()), t); + onFatalError(exception); + throw exception; + } + } + + private void startShuffleManagerServices() throws Exception { + try { + leaderElectionService.start(this); + } catch (Exception e) { + handleStartShuffleManagerServicesException(e); + } + } + + private void handleStartShuffleManagerServicesException(Exception e) throws Exception { + try { + stopShuffleManagerServices(); + } catch (Exception inner) { + e.addSuppressed(inner); + } + + throw e; + } + + @Override + public final CompletableFuture onStop() { + try { + stopShuffleManagerServices(); + } catch (Exception exception) { + return FutureUtils.completedExceptionally( + new ShuffleException( + "Could not properly shut down the ShuffleManager.", exception)); + } + + return CompletableFuture.completedFuture(null); + } + + private void stopShuffleManagerServices() throws Exception { + Exception exception = null; + + stopHeartbeatServices(); + + try { + leaderElectionService.stop(); + } catch (Exception e) { + exception = e; + } + + try { + haServices.closeAndCleanupAllData(); + } catch (Exception e) { + exception = exception == null ? e : exception; + } + + clearStateInternal(); + + if (exception != null) { + ExceptionUtils.rethrowException(exception); + } + } + + // ------------------------------------------------------------------------ + // Error Handling + // ------------------------------------------------------------------------ + + /** + * Notifies the ShuffleManager that a fatal error has occurred and it cannot proceed. + * + * @param t The exception describing the fatal error + */ + protected void onFatalError(Throwable t) { + try { + LOG.error("Fatal error occurred in ShuffleManager.", t); + } catch (Throwable ignored) { + } + + // The fatal error handler implementation should make sure that this call is non-blocking + fatalErrorHandler.onFatalError(t); + } + + // ------------------------------------------------------------------------ + // Leader Contender + // ------------------------------------------------------------------------ + + /** + * Callback method when current ShuffleManager is granted leadership. 
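The leadership callbacks that follow are fenced: every grant carries a session UUID, and an action is applied only while that UUID still matches the current fencing token, so callbacks from a revoked or superseded session are dropped. The core of that check, reduced to a few lines with illustrative names (this is not the project's LeaderElectionService API):

```java
import java.util.UUID;

public class FencingTokenSketch {

    private UUID fencingToken; // null while this node is not the leader

    boolean tryAccept(UUID grantedSession, UUID currentLeaderSession) {
        if (!grantedSession.equals(currentLeaderSession)) {
            return false; // the grant was superseded before we could accept it
        }
        fencingToken = grantedSession;
        return true;
    }

    boolean isFencedBy(UUID sessionId) {
        return sessionId.equals(fencingToken);
    }

    public static void main(String[] args) {
        FencingTokenSketch sketch = new FencingTokenSketch();
        UUID granted = UUID.randomUUID();
        System.out.println(sketch.tryAccept(granted, granted));           // true
        System.out.println(sketch.isFencedBy(granted));                   // true
        System.out.println(sketch.tryAccept(granted, UUID.randomUUID())); // false: superseded
    }
}
```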
+ * + * @param newLeaderSessionID unique leadershipID + */ + @Override + public void grantLeadership(final UUID newLeaderSessionID) { + final CompletableFuture acceptLeadershipFuture = + CompletableFuture.supplyAsync( + () -> tryAcceptLeadership(newLeaderSessionID), + getUnfencedMainThreadExecutor()); + + final CompletableFuture confirmationFuture = + acceptLeadershipFuture.thenAcceptAsync( + (acceptLeadership) -> { + if (acceptLeadership) { + // confirming the leader session ID might be blocking, + leaderElectionService.confirmLeadership( + new LeaderInformation(newLeaderSessionID, getAddress())); + } + }, + ioExecutor); + + confirmationFuture.whenCompleteAsync( + (Void ignored, Throwable throwable) -> { + if (throwable != null) { + onFatalError(throwable); + } + targetWorkStatusSyncTimeMillis = + System.nanoTime() / 1000000 + workerStatusSyncPeriodMillis; + }, + getUnfencedMainThreadExecutor()); + } + + @Override + public void revokeLeadership() { + runAsyncWithoutFencing( + () -> { + LOG.info( + "ShuffleManager {} was revoked leadership. Clearing fencing token {}.", + getAddress(), + getFencingToken()); + + clearStateInternal(); + + setFencingToken(null); + + // We force increase this deadline to avoid there are long time period between + // the revoke and re-grant. + targetWorkStatusSyncTimeMillis = + System.nanoTime() / 1000000 + Integer.MAX_VALUE; + }); + } + + @Override + public void handleError(Throwable throwable) { + onFatalError(throwable); + } + + private boolean tryAcceptLeadership(UUID newLeaderSessionID) { + if (leaderElectionService.hasLeadership(newLeaderSessionID)) { + LOG.info( + "ShuffleManager {} was granted leadership with fencing token {}", + getAddress(), + newLeaderSessionID); + + // clear the state if we've been the leader before + if (getFencingToken() != null) { + clearStateInternal(); + } + + setFencingToken(newLeaderSessionID); + + startHeartbeatServices(); + + firstLeaderShipGrant = false; + + return true; + } + + return false; + } + + // ------------------------------------------------------------------------ + // Heartbeat Service + // ------------------------------------------------------------------------ + private void startHeartbeatServices() { + if (firstLeaderShipGrant) { + LOG.info("Initialize the heartbeat services."); + jobHeartbeatManager = + jobHeartbeatServices.createHeartbeatManager( + managerID, + new JobHeartbeatListener(), + getRpcMainThreadScheduledExecutor(), + log); + + workerHeartbeatManager = + workerHeartbeatServices.createHeartbeatManagerSender( + managerID, + new WorkerHeartbeatListener(), + getRpcMainThreadScheduledExecutor(), + log); + } + } + + private void stopHeartbeatServices() { + jobHeartbeatManager.stop(); + workerHeartbeatManager.stop(); + } + + // ------------------------------------------------------------------------ + // Clear the internal state + // ------------------------------------------------------------------------ + + private void clearStateInternal() { + LOG.info("Currently, we would not clear the state to avoid large-scale restarting."); + } + + // ------------------------------------------------------------------------ + // ShuffleWorker related action + // ------------------------------------------------------------------------ + + /** + * Registers a new ShuffleWorker. 
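Both grantLeadership and revokeLeadership above derive the worker-status sync deadline from System.nanoTime() / 1000000 rather than System.currentTimeMillis(): nanoTime is monotonic, so the deadline cannot jump when the wall clock is adjusted. The pattern in isolation:

```java
public class MonotonicDeadlineSketch {

    public static void main(String[] args) throws InterruptedException {
        long periodMillis = 100;
        // "Now" in milliseconds on the monotonic clock, as in the manager code above.
        long deadline = System.nanoTime() / 1_000_000 + periodMillis;

        Thread.sleep(150);

        boolean expired = System.nanoTime() / 1_000_000 >= deadline;
        System.out.println("Deadline expired: " + expired); // true
    }
}
```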
+ * + * @param shuffleWorkerRegistration shuffle worker registration parameters + * @return RegistrationResponse + */ + private RegistrationResponse registerShuffleWorkerInternal( + ShuffleWorkerGateway shuffleWorkerGateway, + ShuffleWorkerRegistration shuffleWorkerRegistration) { + + final InstanceID workerID = shuffleWorkerRegistration.getWorkerID(); + final ShuffleWorkerRegistrationInstance oldRegistration = shuffleWorkers.remove(workerID); + if (oldRegistration != null) { + // TODO :: suggest old ShuffleWorker to stop itself + log.info( + "Replacing old registration of ShuffleWorker {}: {}.", + workerID, + oldRegistration.getShuffleWorkerRegisterId()); + + // remove old shuffle worker registration from assignment tracker. + assignmentTracker.unregisterWorker(oldRegistration.getShuffleWorkerRegisterId()); + } + + final String workerAddress = shuffleWorkerRegistration.getHostname(); + ShuffleWorkerRegistrationInstance newRecord = + new ShuffleWorkerRegistrationInstance( + workerID, shuffleWorkerGateway, shuffleWorkerRegistration); + + log.info( + "Registering ShuffleWorker with ID {} ({}, {}) at ShuffleManager", + workerID, + workerAddress, + newRecord.getShuffleWorkerRegisterId()); + shuffleWorkers.put(workerID, newRecord); + + workerHeartbeatManager.monitorTarget( + workerID, + new HeartbeatTarget() { + @Override + public void receiveHeartbeat(InstanceID instanceID, Void payload) { + // the ShuffleManager will always send heartbeat requests to the + // ShuffleWorker + } + + @Override + public void requestHeartbeat(InstanceID instanceID, Void payload) { + shuffleWorkerGateway.heartbeatFromManager(instanceID); + } + }); + + assignmentTracker.registerWorker( + workerID, + newRecord.getShuffleWorkerRegisterId(), + shuffleWorkerGateway, + shuffleWorkerRegistration.getHostname(), + shuffleWorkerRegistration.getDataPort()); + + LOG.info( + "Stat on register worker: shuffleWorkers.size = {}, assignmentTracker has tracked {} workers", + shuffleWorkers.size(), + assignmentTracker.getNumberOfWorkers()); + + return new ShuffleWorkerRegistrationSuccess( + newRecord.getShuffleWorkerRegisterId(), managerID); + } + + private void closeShuffleWorkerConnection( + final InstanceID shuffleWorkerId, final Exception cause) { + LOG.info("Disconnect ShuffleWorker {} because: {}", shuffleWorkerId, cause.getMessage()); + + workerHeartbeatManager.unmonitorTarget(shuffleWorkerId); + + final ShuffleWorkerRegistrationInstance record = shuffleWorkers.remove(shuffleWorkerId); + if (record != null) { + assignmentTracker.unregisterWorker(record.getShuffleWorkerRegisterId()); + + record.getShuffleWorkerGateway().disconnectManager(cause); + } else { + log.info( + "No open ShuffleWorker connection {}. Ignoring close ShuffleWorker connection. 
Closing reason was: {}", + shuffleWorkerId, + cause.getMessage()); + } + + LOG.info( + "Stat on unregister worker: shuffleWorkers.size = {}, assignmentTracker has tracked {} workers", + shuffleWorkers.size(), + assignmentTracker.getNumberOfWorkers()); + } + + AssignmentTracker getAssignmentTracker() { + return assignmentTracker; + } + + public Map getRegisteredClients() { + return registeredClients; + } + + public Map getShuffleWorkers() { + return shuffleWorkers; + } + + HeartbeatManager getJobHeartbeatManager() { + return jobHeartbeatManager; + } + + HeartbeatManager getWorkerHeartbeatManager() { + return workerHeartbeatManager; + } + + // ------------------------------------------------------------------------ + // Static utility classes + // ------------------------------------------------------------------------ + private class JobHeartbeatListener implements HeartbeatListener { + + @Override + public void notifyHeartbeatTimeout(InstanceID instanceID) { + validateRunsInMainThread(); + LOG.info("The heartbeat of client with id {} timed out.", instanceID); + + unregisterClientInternal(jobIdFromInstanceID(instanceID)); + } + + @Override + public void reportPayload(InstanceID instanceID, Void payload) { + validateRunsInMainThread(); + } + + @Override + public Void retrievePayload(InstanceID instanceID) { + return null; + } + } + + private class WorkerHeartbeatListener + implements HeartbeatListener { + + @Override + public void notifyHeartbeatTimeout(InstanceID instanceID) { + validateRunsInMainThread(); + LOG.info("The heartbeat of shuffle worker with id {} timed out.", instanceID); + + closeShuffleWorkerConnection( + instanceID, + new TimeoutException( + "The heartbeat of ShuffleWorker with id " + + instanceID + + " timed out.")); + } + + @Override + public void reportPayload(InstanceID instanceID, WorkerToManagerHeartbeatPayload payload) { + final ShuffleWorkerRegistrationInstance shuffleWorkerRegistrationInstance = + shuffleWorkers.get(instanceID); + if (shuffleWorkerRegistrationInstance != null) { + reportDataPartitionStatus( + instanceID, + shuffleWorkerRegistrationInstance.getShuffleWorkerRegisterId(), + payload.getDataPartitionStatuses()); + } else { + LOG.warn( + "The shuffle worker with id {} is not registered before but receive the heartbeat.", + instanceID); + } + } + + @Override + public Void retrievePayload(InstanceID instanceID) { + return null; + } + } + + /** This class records a shuffle worker's registration in one success attempt. 
*/ + private static class ShuffleWorkerRegistrationInstance { + + private final InstanceID shuffleWorkerId; + + private final RegistrationID shuffleWorkerRegisterId; + + private final ShuffleWorkerGateway shuffleWorkerGateway; + + private final ShuffleWorkerRegistration registration; + + public ShuffleWorkerRegistrationInstance( + InstanceID shuffleWorkerId, + ShuffleWorkerGateway shuffleWorkerGateway, + ShuffleWorkerRegistration registration) { + this.shuffleWorkerId = shuffleWorkerId; + this.registration = registration; + this.shuffleWorkerRegisterId = new RegistrationID(); + this.shuffleWorkerGateway = shuffleWorkerGateway; + } + + public RegistrationID getShuffleWorkerRegisterId() { + return shuffleWorkerRegisterId; + } + + public ShuffleWorkerGateway getShuffleWorkerGateway() { + return shuffleWorkerGateway; + } + + public InstanceID getShuffleWorkerId() { + return shuffleWorkerId; + } + + public ShuffleWorkerRegistration getRegistration() { + return registration; + } + } + + private static InstanceID instanceIDFromJobID(JobID jobId) { + return new InstanceID(jobId.getId()); + } + + private static JobID jobIdFromInstanceID(InstanceID instanceID) { + return new JobID(instanceID.getId()); + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleManagerGateway.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleManagerGateway.java new file mode 100644 index 00000000..379ab080 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleManagerGateway.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.manager; + +import com.alibaba.flink.shuffle.coordinator.registration.RegistrationResponse; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.RegistrationID; +import com.alibaba.flink.shuffle.rpc.message.Acknowledge; + +import java.util.List; +import java.util.concurrent.CompletableFuture; + +/** The shuffle manager rpc gateway. */ +public interface ShuffleManagerGateway extends ShuffleManagerJobGateway { + + /** + * Registers a shuffle worker to the manager. + * + * @param workerRegistration The information of the shuffle worker. + * @return The response of the registration. + */ + CompletableFuture registerWorker( + ShuffleWorkerRegistration workerRegistration); + + /** + * Worker initiates releasing one data partition, like when the data has failure during writing. 
+ * On calling of this method, the data should have been removed on the worker side and one piece + * of meta is kept for synchronizing the status with manager. + * + * @param jobID The id of the job produces the data partition. + * @param dataSetID The id of the dataset that contains the data partition.. + * @param dataPartitionID The id of the data partition. + */ + CompletableFuture workerReportDataPartitionReleased( + InstanceID workerID, + RegistrationID registrationID, + JobID jobID, + DataSetID dataSetID, + DataPartitionID dataPartitionID); + + /** + * Reports the list of data partitions in the current worker. + * + * @param workerID The InstanceID of the shuffle worker. + * @param registrationID The InstanceId of the shuffle worker. + * @param dataPartitionStatuses The list of data partitions in the shuffle worker. + * @return Whether the report is success. + */ + CompletableFuture reportDataPartitionStatus( + InstanceID workerID, + RegistrationID registrationID, + List dataPartitionStatuses); + + /** + * Receives heartbeat from the shuffle worker. + * + * @param workerID The InstanceID of the worker. + * @param payload The status of the current worker. + */ + void heartbeatFromWorker(InstanceID workerID, WorkerToManagerHeartbeatPayload payload); + + /** + * Disconnects the shuffle worker from the manager. + * + * @param workerID The InstanceID of the shuffle worker. + * @param cause The reason of disconnect. + * @return Whether the disconnect is success. + */ + CompletableFuture disconnectWorker(InstanceID workerID, Exception cause); +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleManagerJobGateway.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleManagerJobGateway.java new file mode 100644 index 00000000..43f2b8ee --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleManagerJobGateway.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.coordinator.manager; + +import com.alibaba.flink.shuffle.coordinator.registration.RegistrationResponse; +import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerMetrics; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleFencedRpcGateway; +import com.alibaba.flink.shuffle.rpc.message.Acknowledge; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; + +/** The rpc gateway used by a FLINK job to allocate or release a {@link ShuffleResource}. */ +public interface ShuffleManagerJobGateway extends RemoteShuffleFencedRpcGateway { + + /** + * Registers a client instance to the ShuffleManager, which represents a job. + * + * @param jobID The FLINK job id. + * @return The response for the registration. + */ + CompletableFuture registerClient(JobID jobID, InstanceID clientID); + + /** + * Unregisters a client instance from the ShuffleManager, which represents a job. + * + * @param jobID The job id. + * @return The response for the registration. + */ + CompletableFuture unregisterClient(JobID jobID, InstanceID clientID); + + /** + * Receives the heartbeat from the client, which represents a job. + * + * @param jobID The job id. + * @return The InstanceID of the client. + */ + CompletableFuture heartbeatFromClient( + JobID jobID, InstanceID clientID, Set cachedWorkerList); + + /** + * Requests a shuffle resource for storing all the data partitions produced by one map task in + * one dataset. + * + * @param jobID The job id. + * @param dataSetID The id of the dataset that contains this partition. + * @param mapPartitionID The id represents the map task. + * @param numberOfConsumers The number of consumers of the partition. + * @param dataPartitionFactoryName The factory name of the data partition. + * @return The allocated shuffle resources. + */ + CompletableFuture requestShuffleResource( + JobID jobID, + InstanceID clientID, + DataSetID dataSetID, + MapPartitionID mapPartitionID, + int numberOfConsumers, + String dataPartitionFactoryName); + + /** + * Releases resources for all the data partitions produced by one map task in one dataset. + * + * @param jobID The job id. + * @param dataSetID The id of the dataset that contains this partition. + * @param mapPartitionID The id represents the map task. + * @return The result for releasing shuffle resource. + */ + CompletableFuture releaseShuffleResource( + JobID jobID, InstanceID clientID, DataSetID dataSetID, MapPartitionID mapPartitionID); + + /** + * Gets num of registered shuffle workers. + * + * @return the num of registered shuffle workers. + */ + CompletableFuture getNumberOfRegisteredWorkers(); + + /** + * Gets shuffle workers metrics. + * + * @return the shuffle workers metrics keyed by {@link InstanceID}. 
+     */
+    CompletableFuture<Map<InstanceID, ShuffleWorkerMetrics>> getShuffleWorkerMetrics();
+
+    CompletableFuture<List<JobID>> listJobs();
+
+    CompletableFuture<Map<DataPartitionCoordinate, InstanceID>> getJobDataPartitionDistribution(
+            JobID jobID);
+}
diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleManagerRunner.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleManagerRunner.java
new file mode 100644
index 00000000..1153c055
--- /dev/null
+++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleManagerRunner.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.manager;
+
+import com.alibaba.flink.shuffle.common.config.Configuration;
+import com.alibaba.flink.shuffle.common.utils.JvmShutdownSafeguard;
+import com.alibaba.flink.shuffle.common.utils.SignalHandler;
+import com.alibaba.flink.shuffle.coordinator.manager.entrypoint.ShuffleManagerEntrypoint;
+import com.alibaba.flink.shuffle.coordinator.utils.ClusterEntrypointUtils;
+import com.alibaba.flink.shuffle.coordinator.utils.EnvironmentInformation;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static com.alibaba.flink.shuffle.coordinator.utils.ClusterEntrypointUtils.STARTUP_FAILURE_RETURN_CODE;
+
+/** Runner for {@link ShuffleManager}.
*/ +public class ShuffleManagerRunner { + + private static final Logger LOG = LoggerFactory.getLogger(ShuffleManagerRunner.class); + + public static void main(String[] args) { + // startup checks and logging + EnvironmentInformation.logEnvironmentInfo(LOG, "Shuffle Manager", args); + SignalHandler.register(LOG); + JvmShutdownSafeguard.installAsShutdownHook(LOG); + + long maxOpenFileHandles = EnvironmentInformation.getOpenFileHandlesLimit(); + + if (maxOpenFileHandles != -1L) { + LOG.info("Maximum number of open file descriptors is {}.", maxOpenFileHandles); + } else { + LOG.info("Cannot determine the maximum number of open file descriptors"); + } + + try { + Configuration configuration = ClusterEntrypointUtils.parseParametersOrExit(args); + ShuffleManagerEntrypoint shuffleManagerEntrypoint = + new ShuffleManagerEntrypoint(configuration); + ShuffleManagerEntrypoint.runShuffleManagerEntrypoint(shuffleManagerEntrypoint); + } catch (Throwable t) { + LOG.error("ShuffleManager initialization failed.", t); + System.exit(STARTUP_FAILURE_RETURN_CODE); + } + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleResource.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleResource.java new file mode 100644 index 00000000..953d74c7 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleResource.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.manager; + +import java.io.Serializable; + +/** + * The interface represents the resource that a task could shuffle-read from or shuffle-write to. + */ +public interface ShuffleResource extends Serializable { + + ShuffleWorkerDescriptor[] getReducePartitionLocations(); + + ShuffleWorkerDescriptor getMapPartitionLocation(); +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleWorkerDescriptor.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleWorkerDescriptor.java new file mode 100644 index 00000000..da873562 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleWorkerDescriptor.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
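To make the client flow concrete, a hedged sketch of one allocate/release round trip against the ShuffleManagerJobGateway declared above, using the ShuffleResource interface just defined. The gateway instance and the factory class name are assumptions, and blocking with get() is only for brevity.

import com.alibaba.flink.shuffle.coordinator.manager.ShuffleManagerJobGateway;
import com.alibaba.flink.shuffle.coordinator.manager.ShuffleResource;
import com.alibaba.flink.shuffle.coordinator.manager.ShuffleWorkerDescriptor;
import com.alibaba.flink.shuffle.core.ids.DataSetID;
import com.alibaba.flink.shuffle.core.ids.InstanceID;
import com.alibaba.flink.shuffle.core.ids.JobID;
import com.alibaba.flink.shuffle.core.ids.MapPartitionID;

class JobClientSketch {

    ShuffleWorkerDescriptor allocateAndRelease(
            ShuffleManagerJobGateway gateway,
            JobID jobID,
            InstanceID clientID,
            DataSetID dataSetID,
            MapPartitionID mapPartitionID)
            throws Exception {

        // Ask the manager to pick a worker for all partitions written by this
        // map task; the factory name here is a placeholder.
        ShuffleResource resource =
                gateway.requestShuffleResource(
                                jobID,
                                clientID,
                                dataSetID,
                                mapPartitionID,
                                /* numberOfConsumers */ 1,
                                "com.example.HypotheticalPartitionFactory")
                        .get();

        // The map task writes through this worker's address and data port.
        ShuffleWorkerDescriptor location = resource.getMapPartitionLocation();

        // Once the dataset is fully consumed, give the resource back.
        gateway.releaseShuffleResource(jobID, clientID, dataSetID, mapPartitionID).get();
        return location;
    }
}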
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.manager; + +import com.alibaba.flink.shuffle.core.ids.InstanceID; + +import java.io.Serializable; +import java.util.Objects; + +/** Describes the address and identification of a shuffle worker. */ +public class ShuffleWorkerDescriptor implements Serializable { + + private static final long serialVersionUID = 7675147692379529718L; + + private final InstanceID workerId; + + private final String workerAddress; + + private final int dataPort; + + public ShuffleWorkerDescriptor(InstanceID workerId, String workerAddress, int dataPort) { + this.workerId = workerId; + this.workerAddress = workerAddress; + this.dataPort = dataPort; + } + + public InstanceID getWorkerId() { + return workerId; + } + + public String getWorkerAddress() { + return workerAddress; + } + + public int getDataPort() { + return dataPort; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + ShuffleWorkerDescriptor that = (ShuffleWorkerDescriptor) o; + return dataPort == that.dataPort + && Objects.equals(workerId, that.workerId) + && Objects.equals(workerAddress, that.workerAddress); + } + + @Override + public int hashCode() { + return Objects.hash(workerId, workerAddress, dataPort); + } + + @Override + public String toString() { + return workerId + "@[" + workerAddress + ":" + dataPort + "]"; + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleWorkerRegistration.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleWorkerRegistration.java new file mode 100644 index 00000000..abbc1e88 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleWorkerRegistration.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.manager; + +import com.alibaba.flink.shuffle.core.ids.InstanceID; + +import java.io.Serializable; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** Information provided by the ShuffleWorker when it registers to the ShuffleManager. 
+ */
+public class ShuffleWorkerRegistration implements Serializable {
+
+    private static final long serialVersionUID = 5666500078135011993L;
+
+    /** The rpc address of the ShuffleWorker that registers. */
+    private final String rpcAddress;
+
+    /** The hostname of the ShuffleWorker that registers. */
+    private final String hostname;
+
+    /** The instance id of the ShuffleWorker that registers. */
+    private final InstanceID workerID;
+
+    /** The port used for data transfer. */
+    private final int dataPort;
+
+    /** The process id of the shuffle worker. Currently, it is only used in e2e tests. */
+    private final int processID;
+
+    public ShuffleWorkerRegistration(
+            final String rpcAddress,
+            final String hostname,
+            final InstanceID workerID,
+            final int dataPort,
+            int processID) {
+        this.rpcAddress = checkNotNull(rpcAddress);
+        this.hostname = checkNotNull(hostname);
+        this.workerID = checkNotNull(workerID);
+        this.dataPort = dataPort;
+        this.processID = processID;
+    }
+
+    public String getRpcAddress() {
+        return rpcAddress;
+    }
+
+    public String getHostname() {
+        return hostname;
+    }
+
+    public InstanceID getWorkerID() {
+        return workerID;
+    }
+
+    public int getDataPort() {
+        return dataPort;
+    }
+
+    public int getProcessID() {
+        return processID;
+    }
+
+    @Override
+    public String toString() {
+        return "ShuffleWorkerRegistration{"
+                + "rpcAddress='"
+                + rpcAddress
+                + '\''
+                + ", hostname='"
+                + hostname
+                + '\''
+                + ", workerID="
+                + workerID
+                + ", dataPort="
+                + dataPort
+                + ", processID="
+                + processID
+                + '}';
+    }
+}
diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleWorkerRegistrationSuccess.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleWorkerRegistrationSuccess.java
new file mode 100644
index 00000000..61670dd5
--- /dev/null
+++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleWorkerRegistrationSuccess.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.manager;
+
+import com.alibaba.flink.shuffle.core.ids.InstanceID;
+import com.alibaba.flink.shuffle.core.ids.RegistrationID;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull;
+
+/** Indicates a successful response from ShuffleManager to ShuffleWorker upon registration. */
+public final class ShuffleWorkerRegistrationSuccess extends RegistrationSuccess {
+
+    private static final long serialVersionUID = -3022884771553508472L;
+
+    /** The registration id of the shuffle worker. */
+    private final RegistrationID registrationID;
+
+    /**
+     * Create a new {@code ShuffleWorkerRegistrationSuccess} message.
+     *
+     * @param registrationID The ID that the ShuffleManager assigned to the registration.
+     * @param managerID The unique ID that identifies the ShuffleManager.
+     */
+    public ShuffleWorkerRegistrationSuccess(RegistrationID registrationID, InstanceID managerID) {
+
+        super(managerID);
+        this.registrationID = checkNotNull(registrationID);
+    }
+
+    /** Gets the ID that the ShuffleManager assigned to the registration. */
+    public RegistrationID getRegistrationID() {
+        return registrationID;
+    }
+
+    @Override
+    public String toString() {
+        return "ShuffleWorkerRegistrationSuccess{"
+                + "registrationID="
+                + registrationID
+                + ", shuffleManagerInstanceID="
+                + getInstanceID()
+                + '}';
+    }
+}
diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/WorkerToManagerHeartbeatPayload.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/WorkerToManagerHeartbeatPayload.java
new file mode 100644
index 00000000..3346fe4e
--- /dev/null
+++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/WorkerToManagerHeartbeatPayload.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.manager;
+
+import java.io.Serializable;
+import java.util.List;
+
+/** Payload for heartbeats sent from the ShuffleWorker to the ShuffleManager. */
+public class WorkerToManagerHeartbeatPayload implements Serializable {
+
+    private static final long serialVersionUID = 5801527514715010018L;
+
+    private final List<DataPartitionStatus> dataPartitionStatuses;
+
+    public WorkerToManagerHeartbeatPayload(List<DataPartitionStatus> dataPartitionStatuses) {
+        this.dataPartitionStatuses = dataPartitionStatuses;
+    }
+
+    public List<DataPartitionStatus> getDataPartitionStatuses() {
+        return dataPartitionStatuses;
+    }
+}
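A small sketch of assembling the payload above from a worker's local view. The (jobId, coordinate) DataPartitionStatus constructor is the one used elsewhere in this patch; the surrounding class is hypothetical.

import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionCoordinate;
import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionStatus;
import com.alibaba.flink.shuffle.coordinator.manager.WorkerToManagerHeartbeatPayload;
import com.alibaba.flink.shuffle.core.ids.DataPartitionID;
import com.alibaba.flink.shuffle.core.ids.DataSetID;
import com.alibaba.flink.shuffle.core.ids.JobID;

import java.util.ArrayList;
import java.util.List;

class HeartbeatPayloadSketch {

    WorkerToManagerHeartbeatPayload buildPayload(
            JobID jobID, DataSetID dataSetID, DataPartitionID dataPartitionID) {
        // One status entry per data partition currently stored on this worker.
        List<DataPartitionStatus> statuses = new ArrayList<>();
        statuses.add(
                new DataPartitionStatus(
                        jobID, new DataPartitionCoordinate(dataSetID, dataPartitionID)));
        return new WorkerToManagerHeartbeatPayload(statuses);
    }
}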
diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/assignmenttracker/AssignmentTracker.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/assignmenttracker/AssignmentTracker.java
new file mode 100644
index 00000000..a5e4b9e3
--- /dev/null
+++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/assignmenttracker/AssignmentTracker.java
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.manager.assignmenttracker;
+
+import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionCoordinate;
+import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionStatus;
+import com.alibaba.flink.shuffle.coordinator.manager.ShuffleResource;
+import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerGateway;
+import com.alibaba.flink.shuffle.core.ids.DataPartitionID;
+import com.alibaba.flink.shuffle.core.ids.DataSetID;
+import com.alibaba.flink.shuffle.core.ids.InstanceID;
+import com.alibaba.flink.shuffle.core.ids.JobID;
+import com.alibaba.flink.shuffle.core.ids.MapPartitionID;
+import com.alibaba.flink.shuffle.core.ids.RegistrationID;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+
+/** Component that tracks the assignment of data partitions to shuffle workers. */
+public interface AssignmentTracker {
+
+    /**
+     * Whether a shuffle worker is registered.
+     *
+     * @param registrationID The registration ID of the shuffle worker.
+     * @return Whether the shuffle worker is registered.
+     */
+    boolean isWorkerRegistered(RegistrationID registrationID);
+
+    /**
+     * Registers a shuffle worker when it connects to the manager.
+     *
+     * @param workerID The ID of the shuffle worker.
+     * @param registrationID The ID of the shuffle worker registration.
+     * @param gateway The gateway of the shuffle worker.
+     * @param externalAddress The address for the tasks to connect to.
+     * @param dataPort The port of the shuffle worker to provide data partition read/write service.
+     */
+    void registerWorker(
+            InstanceID workerID,
+            RegistrationID registrationID,
+            ShuffleWorkerGateway gateway,
+            String externalAddress,
+            int dataPort);
+
+    /**
+     * The worker initiates releasing one data partition, for example when writing the data fails.
+     * When this method is called, the data should have been removed on the worker side and one
+     * piece of meta is kept for synchronizing the status with the manager.
+     *
+     * @param workerRegistrationID The registration ID of the shuffle worker.
+     * @param jobID The id of the job that produces the data partition.
+     * @param dataSetID The id of the dataset that contains the data partition.
+     * @param dataPartitionID The id of the data partition.
+     */
+    void workerReportDataPartitionReleased(
+            RegistrationID workerRegistrationID,
+            JobID jobID,
+            DataSetID dataSetID,
+            DataPartitionID dataPartitionID);
+
+    /**
+     * Synchronizes the list of data partitions on a shuffle worker. The caller of this method
+     * should ensure all the reported jobs have been registered.
+     *
+     * @param workerRegistrationID The registration ID of the shuffle worker.
+     * @param dataPartitionStatuses The statuses of data partitions residing on the shuffle worker.
+     */
+    void synchronizeWorkerDataPartitions(
+            RegistrationID workerRegistrationID, List<DataPartitionStatus> dataPartitionStatuses);
+
+    /**
+     * Unregisters a shuffle worker when it disconnects.
+     *
+     * @param workerRegistrationID The registration ID of the shuffle worker.
+     */
+    void unregisterWorker(RegistrationID workerRegistrationID);
+
+    /**
+     * Whether a job client is registered.
+     *
+     * @param jobID The job id.
+     * @return Whether the job client is registered.
+     */
+    boolean isJobRegistered(JobID jobID);
+
+    /**
+     * Registers a job client.
+     *
+     * @param jobID The job id.
+     */
+    void registerJob(JobID jobID);
+
+    /**
+     * Unregisters a job client.
+     *
+     * @param jobID The job id.
+     */
+    void unregisterJob(JobID jobID);
+
+    /**
+     * Requests resources for all the data partitions produced by one map task in one dataset.
+     *
+     * @param jobID The job id.
+     * @param dataSetID The id of the dataset that contains this partition.
+     * @param mapPartitionID The id representing the map task.
+     * @param numberOfConsumers The number of consumers of the partition.
+     * @param dataPartitionFactoryName The factory name of the data partition.
+     * @return The allocated shuffle resources.
+     */
+    ShuffleResource requestShuffleResource(
+            JobID jobID,
+            DataSetID dataSetID,
+            MapPartitionID mapPartitionID,
+            int numberOfConsumers,
+            String dataPartitionFactoryName)
+            throws ShuffleResourceAllocationException;
+
+    /**
+     * The client releases resources for all the data partitions produced by one map task in one
+     * dataset.
+     *
+     * @param jobID The job id.
+     * @param dataSetID The id of the dataset that contains this partition.
+     * @param mapPartitionID The id representing the map task.
+     */
+    void releaseShuffleResource(JobID jobID, DataSetID dataSetID, MapPartitionID mapPartitionID);
+
+    /**
+     * Computes the list of workers that get changed for the specified job.
+     *
+     * @param jobID The job id.
+     * @param cachedWorkerList The current cached worker list for this job.
+     * @param considerUnrelatedWorkers Whether to also consider the unrelated workers.
+     * @return The status of changed workers.
+     */
+    ChangedWorkerStatus computeChangedWorkers(
+            JobID jobID, Collection<InstanceID> cachedWorkerList, boolean considerUnrelatedWorkers);
+
+    /**
+     * Lists the registered jobs.
+     *
+     * @return The list of current jobs.
+     */
+    List<JobID> listJobs();
+
+    /**
+     * Gets the number of workers.
+     *
+     * @return The number of current workers.
+     */
+    int getNumberOfWorkers();
+
+    /**
+     * Gets the data partition distribution for a specific job.
+     *
+     * @param jobID The id of the job to query.
+     * @return The data partition distribution.
+     */
+    Map<DataPartitionCoordinate, InstanceID> getDataPartitionDistribution(JobID jobID);
+}
diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/assignmenttracker/AssignmentTrackerImpl.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/assignmenttracker/AssignmentTrackerImpl.java
new file mode 100644
index 00000000..ec1e0741
--- /dev/null
+++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/assignmenttracker/AssignmentTrackerImpl.java
@@ -0,0 +1,469 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
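Reading the contract above end to end, a hedged sketch of the expected call order on an AssignmentTracker: the job must be registered before resources are requested, and at least one worker must already be registered or the request fails. The factory class name is a placeholder.

import com.alibaba.flink.shuffle.coordinator.manager.ShuffleResource;
import com.alibaba.flink.shuffle.coordinator.manager.assignmenttracker.AssignmentTracker;
import com.alibaba.flink.shuffle.coordinator.manager.assignmenttracker.ShuffleResourceAllocationException;
import com.alibaba.flink.shuffle.core.ids.DataSetID;
import com.alibaba.flink.shuffle.core.ids.JobID;
import com.alibaba.flink.shuffle.core.ids.MapPartitionID;

class TrackerFlowSketch {

    ShuffleResource allocate(
            AssignmentTracker tracker,
            JobID jobID,
            DataSetID dataSetID,
            MapPartitionID mapPartitionID)
            throws ShuffleResourceAllocationException {

        // requestShuffleResource fails for unregistered jobs, so register
        // first (normally driven by the job client's registration).
        if (!tracker.isJobRegistered(jobID)) {
            tracker.registerJob(jobID);
        }

        // Throws ShuffleResourceAllocationException if no worker is
        // registered or the factory name cannot be resolved.
        return tracker.requestShuffleResource(
                jobID,
                dataSetID,
                mapPartitionID,
                /* numberOfConsumers */ 1,
                "com.example.HypotheticalPartitionFactory");
    }
}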
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.manager.assignmenttracker; + +import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionCoordinate; +import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionStatus; +import com.alibaba.flink.shuffle.coordinator.manager.DefaultShuffleResource; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleResource; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleWorkerDescriptor; +import com.alibaba.flink.shuffle.coordinator.metrics.ClusterMetrics; +import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerGateway; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.ids.RegistrationID; +import com.alibaba.flink.shuffle.core.storage.DataPartition; +import com.alibaba.flink.shuffle.core.storage.DataPartitionFactory; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; + +/** + * Tracks the status of the current jobs, workers and the data partitions. + * + *
<p>A data partition might be in one of three statuses:
+ *
+ * <ol>
+ *   <li>Normal
+ *   <li>Releasing
+ *   <li>Released
+ * </ol>
+ *
+ * <p>
For each job, we only maintain normal data partitions and their assigned workers. For each + * worker, we maintain both the normal and releasing data partitions, and the releasing data + * partitions would be removed once the workers have indeed released the data partitions. + */ +public class AssignmentTrackerImpl implements AssignmentTracker { + + private static final Logger LOG = LoggerFactory.getLogger(AssignmentTrackerImpl.class); + + /** The currently registered jobs. */ + private final Map jobs = new HashMap<>(); + + /** The currently registered workers. */ + private final Map workers = new HashMap<>(); + + public AssignmentTrackerImpl() { + registerMetrics(); + } + + private void registerMetrics() { + ClusterMetrics.registerGaugeForNumJobsServing(jobs::size); + ClusterMetrics.registerGaugeForNumShuffleWorkers(workers::size); + } + + @Override + public boolean isWorkerRegistered(RegistrationID registrationID) { + return workers.containsKey(registrationID); + } + + @Override + public void registerWorker( + InstanceID workerID, + RegistrationID registrationID, + ShuffleWorkerGateway gateway, + String externalAddress, + int dataPort) { + + checkState( + !workers.containsKey(registrationID), + String.format("The worker %s has been registered", registrationID)); + + workers.put( + registrationID, + new WorkerStatus(workerID, registrationID, gateway, externalAddress, dataPort)); + } + + @Override + public void workerReportDataPartitionReleased( + RegistrationID workerRegistrationID, + JobID jobID, + DataSetID dataSetID, + DataPartitionID dataPartitionID) { + + WorkerStatus workerStatus = workers.get(workerRegistrationID); + internalNotifyWorkerToRemoveReleasedDataPartition( + jobID, new DataPartitionCoordinate(dataSetID, dataPartitionID), workerStatus); + } + + @Override + public void synchronizeWorkerDataPartitions( + RegistrationID workerRegistrationID, List reportedDataPartitions) { + + WorkerStatus workerStatus = workers.get(workerRegistrationID); + if (workerStatus == null) { + LOG.warn("Received report from unknown worker {}", workerRegistrationID); + return; + } + + // First ensure the JM have recorded all the possible data partitions. + for (DataPartitionStatus reportedDataPartition : reportedDataPartitions) { + if (!workerStatus + .getDataPartitions() + .containsKey(reportedDataPartition.getCoordinate())) { + internalAddDataPartition(workerStatus, reportedDataPartition); + } + } + + Map reportedStatusMap = new HashMap<>(); + reportedDataPartitions.forEach( + status -> reportedStatusMap.put(status.getCoordinate(), status)); + + List notifyWorkerToReleaseData = new ArrayList<>(); + List notifyWorkerToReleaseMeta = new ArrayList<>(); + List workerHasReleasedMeta = new ArrayList<>(); + + for (Map.Entry entry : + workerStatus.getDataPartitions().entrySet()) { + DataPartitionStatus status = entry.getValue(); + + DataPartitionStatus reportedStatus = reportedStatusMap.get(status.getCoordinate()); + + if (status.isReleasing()) { + if (reportedStatus == null) { + // We have removed the meta for this data partition, then + // the master could also remove the meta + workerHasReleasedMeta.add(status); + } else if (reportedStatus.isReleasing()) { + // Now the manager is aware of the releasing, then it + // asks the worker to remove the remaining meta. + notifyWorkerToReleaseMeta.add(status); + } else if (!reportedStatus.isReleasing()) { + // Worker might still not know we are going to remove the data + // partition. 
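+                // The manager already marked this partition as releasing,
+                // but the worker still reports it as normal; re-sending the
+                // release request makes both sides converge.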
+ notifyWorkerToReleaseData.add(status); + } + } else { + if (reportedStatus != null && reportedStatus.isReleasing()) { + // Worker initiate the releasing, the master then synchronize + // the status and asks the worker to remove the remaining meta. + status.setReleasing(true); + notifyWorkerToReleaseMeta.add(status); + } + } + } + + notifyWorkerToReleaseData.forEach( + status -> + internalReleaseDataPartition( + status.getJobId(), status.getCoordinate(), workerStatus)); + + notifyWorkerToReleaseMeta.forEach( + status -> + internalNotifyWorkerToRemoveReleasedDataPartition( + status.getJobId(), status.getCoordinate(), workerStatus)); + + workerHasReleasedMeta.forEach( + status -> workerStatus.removeReleasedDataPartition(status.getCoordinate())); + } + + @Override + public void unregisterWorker(RegistrationID workerRegistrationID) { + WorkerStatus workerStatus = workers.remove(workerRegistrationID); + + if (workerStatus == null) { + return; + } + + for (DataPartitionStatus status : workerStatus.getDataPartitions().values()) { + JobStatus jobStatus = jobs.get(status.getJobId()); + if (jobStatus != null + && jobStatus.getDataPartitions().containsKey(status.getCoordinate())) { + if (!Objects.equals( + jobStatus + .getDataPartitions() + .get(status.getCoordinate()) + .getRegistrationID(), + workerStatus.getRegistrationID())) { + LOG.warn( + "Inconsistency happens: job think the partition {} is on {}, in fact it is on {}", + status, + jobStatus.getDataPartitions().get(status.getCoordinate()), + workerStatus); + } + + jobStatus.removeDataPartition(status.getCoordinate()); + } + } + } + + @Override + public boolean isJobRegistered(JobID jobID) { + return jobs.containsKey(jobID); + } + + @Override + public void registerJob(JobID jobID) { + jobs.put(jobID, new JobStatus()); + } + + @Override + public void unregisterJob(JobID jobID) { + JobStatus jobStatus = jobs.remove(jobID); + + if (jobStatus == null) { + return; + } + + jobStatus + .getDataPartitions() + .forEach( + (id, workerStatus) -> { + DataPartitionStatus dataPartitionStatus = + workerStatus.getDataPartitions().get(id); + if (dataPartitionStatus != null) { + internalReleaseDataPartition( + dataPartitionStatus.getJobId(), + dataPartitionStatus.getCoordinate(), + workerStatus); + } + }); + } + + @Override + public ShuffleResource requestShuffleResource( + JobID jobID, + DataSetID dataSetID, + MapPartitionID mapPartitionID, + int numberOfConsumers, + String dataPartitionFactoryName) + throws ShuffleResourceAllocationException { + JobStatus jobStatus = jobs.get(jobID); + + if (jobStatus == null) { + throw new ShuffleResourceAllocationException( + "Job is not registered before requesting resources."); + } + + WorkerStatus oldStatus = + jobStatus + .getDataPartitions() + .get(new DataPartitionCoordinate(dataSetID, mapPartitionID)); + if (oldStatus != null) { + ShuffleWorkerDescriptor descriptor = oldStatus.createShuffleWorkerDescriptor(); + LOG.warn( + "The request data partition {}-{}-{} has been allocated on {}", + jobID, + dataSetID, + mapPartitionID, + descriptor); + return new DefaultShuffleResource( + new ShuffleWorkerDescriptor[] {descriptor}, + getDataPartitionType(dataPartitionFactoryName)); + } + + Optional> min = + workers.entrySet().stream() + .min( + Comparator.comparingInt( + entry -> entry.getValue().getDataPartitions().size())); + + if (!min.isPresent()) { + throw new ShuffleResourceAllocationException("No available workers"); + } + + WorkerStatus minWorkerStatus = min.get().getValue(); + internalAddDataPartition( + 
minWorkerStatus, + new DataPartitionStatus( + jobID, new DataPartitionCoordinate(dataSetID, mapPartitionID))); + + return new DefaultShuffleResource( + new ShuffleWorkerDescriptor[] {minWorkerStatus.createShuffleWorkerDescriptor()}, + getDataPartitionType(dataPartitionFactoryName)); + } + + /** Public permissions are used for unit testing. */ + public DataPartition.DataPartitionType getDataPartitionType(String dataPartitionFactoryName) + throws ShuffleResourceAllocationException { + try { + return DataPartitionFactory.getDataPartitionType(dataPartitionFactoryName); + } catch (Throwable throwable) { + LOG.error("Failed to get the data partition type.", throwable); + throw new ShuffleResourceAllocationException( + String.format( + "Could not find the target data partition factory class %s, please " + + "check the class name make sure the class is in the classpath.", + dataPartitionFactoryName), + throwable); + } + } + + @Override + public void releaseShuffleResource( + JobID jobID, DataSetID dataSetID, MapPartitionID mapPartitionID) { + // Try to find the worker serving this data + JobStatus jobStatus = jobs.get(jobID); + if (jobStatus != null) { + DataPartitionCoordinate coordinate = + new DataPartitionCoordinate(dataSetID, mapPartitionID); + WorkerStatus workerStatus = jobStatus.getDataPartitions().get(coordinate); + internalReleaseDataPartition(jobID, coordinate, workerStatus); + } + } + + @Override + public ChangedWorkerStatus computeChangedWorkers( + JobID jobID, + Collection cachedWorkerList, + boolean considerUnrelatedWorkers) { + JobStatus jobStatus = jobs.get(jobID); + + Map remainedWorkers = new HashMap<>(); + + List unrelatedWorkers = new ArrayList<>(); + Map> newlyRelatedWorkers = new HashMap<>(); + + jobStatus + .getDataPartitions() + .forEach( + (dataPartition, worker) -> { + if (worker.getDataPartitions().get(dataPartition).isReleasing()) { + return; + } + + remainedWorkers.put(worker.getWorkerID(), worker.getRegistrationID()); + if (!cachedWorkerList.contains(worker.getWorkerID())) { + newlyRelatedWorkers + .computeIfAbsent(worker.getWorkerID(), k -> new HashSet<>()) + .add(dataPartition); + } + }); + + if (considerUnrelatedWorkers) { + for (InstanceID workerID : cachedWorkerList) { + if (!remainedWorkers.containsKey(workerID)) { + unrelatedWorkers.add(workerID); + } else { + // The following is a safeguard for unexpected situations. 
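+                    // If the job still references this worker but the worker
+                    // is no longer in the registered worker map, treat it as
+                    // unrelated so the client drops it from its cache.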
+ RegistrationID registrationID = remainedWorkers.get(workerID); + if (!workers.containsKey(registrationID)) { + LOG.warn( + "Inconsistency: Remaining partitions on removed worker {}", + workerID); + unrelatedWorkers.add(workerID); + } + } + } + } + + return new ChangedWorkerStatus(unrelatedWorkers, newlyRelatedWorkers); + } + + @Override + public List listJobs() { + return new ArrayList<>(jobs.keySet()); + } + + @Override + public int getNumberOfWorkers() { + return workers.size(); + } + + @Override + public Map getDataPartitionDistribution(JobID jobID) { + return jobs.get(jobID).getDataPartitions().entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getWorkerID())); + } + + public Map getJobs() { + return jobs; + } + + public Map getWorkers() { + return workers; + } + + private void internalReleaseDataPartition( + JobID jobId, DataPartitionCoordinate coordinate, @Nullable WorkerStatus workerStatus) { + + JobStatus jobStatus = jobs.get(jobId); + if (jobStatus != null) { + jobStatus.removeDataPartition(coordinate); + } + + if (workerStatus != null) { + workerStatus.markAsReleasing(jobId, coordinate); + workerStatus + .getGateway() + .releaseDataPartition( + jobId, coordinate.getDataSetId(), coordinate.getDataPartitionId()); + } + } + + private void internalNotifyWorkerToRemoveReleasedDataPartition( + JobID jobId, DataPartitionCoordinate coordinate, @Nullable WorkerStatus workerStatus) { + + JobStatus jobStatus = jobs.get(jobId); + if (jobStatus != null) { + jobStatus.removeDataPartition(coordinate); + } + + if (workerStatus != null) { + workerStatus.markAsReleasing(jobId, coordinate); + workerStatus + .getGateway() + .removeReleasedDataPartitionMeta( + jobId, coordinate.getDataSetId(), coordinate.getDataPartitionId()); + } + } + + private void internalAddDataPartition( + WorkerStatus workerStatus, DataPartitionStatus dataPartitionStatus) { + + checkState( + jobs.containsKey(dataPartitionStatus.getJobId()), + "A data partition is added before job registered."); + + // If this data partition is maintained by another worker, + // we remove the record to keep the data consistent. + // This might happen when the restarted worker registered + // before the original one timeout. + JobStatus jobStatus = jobs.get(dataPartitionStatus.getJobId()); + WorkerStatus oldWorkerStatus = + jobStatus.getDataPartitions().get(dataPartitionStatus.getCoordinate()); + + if (oldWorkerStatus != null) { + oldWorkerStatus.removeReleasedDataPartition(dataPartitionStatus.getCoordinate()); + } + + jobs.get(dataPartitionStatus.getJobId()) + .addDataPartition(dataPartitionStatus.getCoordinate(), workerStatus); + workerStatus.addDataPartition(dataPartitionStatus); + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/assignmenttracker/ChangedWorkerStatus.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/assignmenttracker/ChangedWorkerStatus.java new file mode 100644 index 00000000..dc29ddfa --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/assignmenttracker/ChangedWorkerStatus.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.manager.assignmenttracker;
+
+import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionCoordinate;
+import com.alibaba.flink.shuffle.core.ids.InstanceID;
+
+import java.io.Serializable;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull;
+
+/** The status of workers that get changed compared with the cached worker list. */
+public class ChangedWorkerStatus implements Serializable {
+
+    private static final long serialVersionUID = -3808740665897798476L;
+
+    /** The workers that are no longer relevant to the job. */
+    private final List<InstanceID> irrelevantWorkers;
+
+    /** The newly relevant workers and the data partitions they hold for the job. */
+    private final Map<InstanceID, Set<DataPartitionCoordinate>> relevantWorkers;
+
+    public ChangedWorkerStatus(
+            List<InstanceID> irrelevantWorkers,
+            Map<InstanceID, Set<DataPartitionCoordinate>> relevantWorkers) {
+        this.irrelevantWorkers = checkNotNull(irrelevantWorkers);
+        this.relevantWorkers = checkNotNull(relevantWorkers);
+    }
+
+    public List<InstanceID> getIrrelevantWorkers() {
+        return irrelevantWorkers;
+    }
+
+    public Map<InstanceID, Set<DataPartitionCoordinate>> getRelevantWorkers() {
+        return relevantWorkers;
+    }
+}
diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/assignmenttracker/JobStatus.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/assignmenttracker/JobStatus.java
new file mode 100644
index 00000000..66939185
--- /dev/null
+++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/assignmenttracker/JobStatus.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.manager.assignmenttracker;
+
+import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionCoordinate;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** Records the assignments of data partitions to workers for a job. */
+public class JobStatus {
+
+    private final Map<DataPartitionCoordinate, WorkerStatus> dataPartitions = new HashMap<>();
+
+    public Map<DataPartitionCoordinate, WorkerStatus> getDataPartitions() {
+        return Collections.unmodifiableMap(dataPartitions);
+    }
+
+    public void addDataPartition(DataPartitionCoordinate coordinate, WorkerStatus workerStatus) {
+        dataPartitions.put(coordinate, workerStatus);
+    }
+
+    public void removeDataPartition(DataPartitionCoordinate coordinate) {
+        dataPartitions.remove(coordinate);
+    }
+}
diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/assignmenttracker/ShuffleResourceAllocationException.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/assignmenttracker/ShuffleResourceAllocationException.java
new file mode 100644
index 00000000..fc1228e8
--- /dev/null
+++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/assignmenttracker/ShuffleResourceAllocationException.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.manager.assignmenttracker;
+
+/** Exception thrown when allocating shuffle resources fails. */
+public class ShuffleResourceAllocationException extends Exception {
+
+    private static final long serialVersionUID = 6582098318747729455L;
+
+    public ShuffleResourceAllocationException(String message) {
+        super(message);
+    }
+
+    public ShuffleResourceAllocationException(String message, Throwable cause) {
+        super(message, cause);
+    }
+}
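Tying this back to ChangedWorkerStatus above: a hedged sketch of how a job client might fold the computed delta into its cached worker list; the cache representation is an assumption.

import com.alibaba.flink.shuffle.coordinator.manager.assignmenttracker.ChangedWorkerStatus;
import com.alibaba.flink.shuffle.core.ids.InstanceID;

import java.util.Set;

class WorkerCacheSketch {

    // `cachedWorkers` stands in for whatever per-job cache the client keeps.
    void apply(ChangedWorkerStatus delta, Set<InstanceID> cachedWorkers) {
        // Drop workers that no longer hold partitions for this job.
        cachedWorkers.removeAll(delta.getIrrelevantWorkers());
        // Add workers that newly hold partitions for this job.
        cachedWorkers.addAll(delta.getRelevantWorkers().keySet());
    }
}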
diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/assignmenttracker/WorkerStatus.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/assignmenttracker/WorkerStatus.java
new file mode 100644
index 00000000..f5f24d37
--- /dev/null
+++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/assignmenttracker/WorkerStatus.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.manager.assignmenttracker;
+
+import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionCoordinate;
+import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionStatus;
+import com.alibaba.flink.shuffle.coordinator.manager.ShuffleWorkerDescriptor;
+import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerGateway;
+import com.alibaba.flink.shuffle.core.ids.InstanceID;
+import com.alibaba.flink.shuffle.core.ids.JobID;
+import com.alibaba.flink.shuffle.core.ids.RegistrationID;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull;
+
+/** The status of a shuffle worker. */
+class WorkerStatus {
+
+    private final InstanceID workerID;
+
+    private final RegistrationID registrationID;
+
+    private final ShuffleWorkerGateway gateway;
+
+    private final String dataAddress;
+
+    private final int dataPort;
+
+    private final Map<DataPartitionCoordinate, DataPartitionStatus> dataPartitions =
+            new HashMap<>();
+
+    public WorkerStatus(
+            InstanceID workerID,
+            RegistrationID registrationID,
+            ShuffleWorkerGateway gateway,
+            String dataAddress,
+            int dataPort) {
+
+        this.workerID = checkNotNull(workerID);
+        this.registrationID = checkNotNull(registrationID);
+        this.gateway = checkNotNull(gateway);
+        this.dataAddress = checkNotNull(dataAddress);
+        this.dataPort = dataPort;
+    }
+
+    public InstanceID getWorkerID() {
+        return workerID;
+    }
+
+    public RegistrationID getRegistrationID() {
+        return registrationID;
+    }
+
+    public ShuffleWorkerDescriptor createShuffleWorkerDescriptor() {
+        return new ShuffleWorkerDescriptor(workerID, dataAddress, dataPort);
+    }
+
+    public void addDataPartition(DataPartitionStatus dataPartitionStatus) {
+        dataPartitions.put(dataPartitionStatus.getCoordinate(), dataPartitionStatus);
+    }
+
+    public Map<DataPartitionCoordinate, DataPartitionStatus> getDataPartitions() {
+        return Collections.unmodifiableMap(dataPartitions);
+    }
+
+    public void markAsReleasing(JobID jobId, DataPartitionCoordinate coordinate) {
+        DataPartitionStatus dataPartitionStatus = dataPartitions.get(coordinate);
+
+        if (dataPartitionStatus == null) {
+            dataPartitionStatus = new DataPartitionStatus(jobId, coordinate, true);
+            dataPartitions.put(coordinate, dataPartitionStatus);
+        } else {
+            dataPartitionStatus.setReleasing(true);
+        }
+    }
+
+    public void removeReleasedDataPartition(DataPartitionCoordinate coordinate) {
+        dataPartitions.remove(coordinate);
+    }
+
+    public ShuffleWorkerGateway getGateway() {
+        return gateway;
+    }
+
+    @Override
+    public String toString() {
+        return "WorkerStatus{"
+                + "workerID="
+                + workerID
+                + ", registrationID="
+                + registrationID
+                + ", dataAddress='"
+                + dataAddress
+                + '\''
+                + ", dataPort="
+                + dataPort
+                + '}';
+    }
+}
diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/entrypoint/ShuffleManagerEntrypoint.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/entrypoint/ShuffleManagerEntrypoint.java
new file mode 100644
index 00000000..46fb253a
--- /dev/null
+++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/manager/entrypoint/ShuffleManagerEntrypoint.java
@@ -0,0 +1,284 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.manager.entrypoint; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.common.functions.AutoCloseableAsync; +import com.alibaba.flink.shuffle.common.handler.FatalErrorHandler; +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; +import com.alibaba.flink.shuffle.common.utils.FutureUtils; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatServices; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatServicesUtils; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServiceUtils; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleManager; +import com.alibaba.flink.shuffle.coordinator.manager.assignmenttracker.AssignmentTrackerImpl; +import com.alibaba.flink.shuffle.core.config.ManagerOptions; +import com.alibaba.flink.shuffle.core.executor.ExecutorThreadFactory; +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.metrics.entry.MetricUtils; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcService; +import com.alibaba.flink.shuffle.rpc.utils.AkkaRpcServiceUtils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.lang.reflect.UndeclaredThrowableException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** The entrypoint class for the {@link ShuffleManager}. */ +public class ShuffleManagerEntrypoint implements AutoCloseableAsync, FatalErrorHandler { + + protected static final Logger LOG = LoggerFactory.getLogger(ShuffleManagerEntrypoint.class); + + protected static final int NORMAL_RETURN_CODE = 0; + protected static final int STARTUP_FAILURE_RETURN_CODE = 1; + protected static final int RUNTIME_FAILURE_RETURN_CODE = 2; + + private static final long INITIALIZATION_SHUTDOWN_TIMEOUT = 30000L; + + /** The lock to guard startup / shutdown / manipulation methods. 
*/ + private final Object lock = new Object(); + + private final Configuration configuration; + + private final CompletableFuture terminationFuture; + + private final AtomicBoolean isShutDown = new AtomicBoolean(false); + + private final HaServices haServices; + + private final RemoteShuffleRpcService shuffleRpcService; + + private final ExecutorService ioExecutor; + + private final ShuffleManager shuffleManager; + + public ShuffleManagerEntrypoint(Configuration configuration) throws Exception { + this.configuration = checkNotNull(configuration); + + this.terminationFuture = new CompletableFuture<>(); + + AkkaRpcServiceUtils.loadRpcSystem(configuration); + this.shuffleRpcService = + AkkaRpcServiceUtils.createRemoteRpcService( + configuration, + configuration.getString(ManagerOptions.RPC_ADDRESS), + String.valueOf(configuration.getInteger(ManagerOptions.RPC_PORT)), + configuration.getString(ManagerOptions.RPC_BIND_ADDRESS), + Optional.ofNullable( + configuration.getInteger(ManagerOptions.RPC_BIND_PORT))); + + MetricUtils.startManagerMetricSystem(configuration); + + // update the configuration used to create the high availability services + configuration.setString(ManagerOptions.RPC_ADDRESS, shuffleRpcService.getAddress()); + configuration.setInteger(ManagerOptions.RPC_PORT, shuffleRpcService.getPort()); + + this.ioExecutor = Executors.newFixedThreadPool(1, new ExecutorThreadFactory("cluster-io")); + + this.haServices = HaServiceUtils.createHAServices(configuration); + + HeartbeatServices workerHeartbeatServices = + HeartbeatServicesUtils.createManagerWorkerHeartbeatServices(configuration); + HeartbeatServices jobHeartbeatServices = + HeartbeatServicesUtils.createManagerJobHeartbeatServices(configuration); + + this.shuffleManager = + new ShuffleManager( + shuffleRpcService, + new InstanceID(), + haServices, + this, + ioExecutor, + jobHeartbeatServices, + workerHeartbeatServices, + new AssignmentTrackerImpl()); + } + + // -------------------------------------------------------------------------------------------- + // Lifecycle management + // -------------------------------------------------------------------------------------------- + + public static void runShuffleManagerEntrypoint( + ShuffleManagerEntrypoint shuffleManagerEntrypoint) { + final String clusterEntrypointName = shuffleManagerEntrypoint.getClass().getSimpleName(); + try { + shuffleManagerEntrypoint.start(); + } catch (Exception e) { + LOG.error( + String.format("Could not start cluster entrypoint %s.", clusterEntrypointName), + e); + System.exit(STARTUP_FAILURE_RETURN_CODE); + } + + shuffleManagerEntrypoint + .getTerminationFuture() + .whenComplete( + (ignored, throwable) -> { + int returnCode = + throwable != null + ? 
RUNTIME_FAILURE_RETURN_CODE + : NORMAL_RETURN_CODE; + + LOG.info( + "Terminating cluster entrypoint process {} with exit code {}.", + clusterEntrypointName, + returnCode, + throwable); + System.exit(returnCode); + }); + } + + public void start() throws Exception { + LOG.info("Starting {}.", getClass().getSimpleName()); + + try { + runCluster(); + } catch (Throwable t) { + final Throwable strippedThrowable = + ExceptionUtils.stripException(t, UndeclaredThrowableException.class); + + try { + // clean up any partial state + shutDownAsync(ExceptionUtils.stringifyException(strippedThrowable)) + .get(INITIALIZATION_SHUTDOWN_TIMEOUT, TimeUnit.MILLISECONDS); + } catch (InterruptedException | ExecutionException | TimeoutException e) { + strippedThrowable.addSuppressed(e); + } + + throw new ShuffleException( + String.format( + "Failed to initialize the cluster entrypoint %s.", + getClass().getSimpleName()), + strippedThrowable); + } + } + + private void runCluster() { + synchronized (lock) { + shuffleManager.start(); + + shuffleManager + .getTerminationFuture() + .whenComplete( + (applicationStatus, throwable) -> { + if (throwable != null) { + shutDownAsync(ExceptionUtils.stringifyException(throwable)); + } else { + shutDownAsync(null); + } + }); + } + } + + @Override + public CompletableFuture closeAsync() { + return shutDownAsync("Cluster entrypoint has been closed externally.") + .thenAccept(ignored -> {}); + } + + protected CompletableFuture stopClusterServices() { + synchronized (lock) { + Throwable exception = null; + + final Collection> terminationFutures = new ArrayList<>(3); + + if (haServices != null) { + try { + haServices.close(); + } catch (Throwable throwable) { + exception = throwable; + LOG.error("Failed to close HA service.", throwable); + } + } + + if (ioExecutor != null) { + try { + ioExecutor.shutdown(); + } catch (Throwable throwable) { + exception = exception == null ? throwable : exception; + LOG.error("Failed to close executor service.", throwable); + } + } + + if (shuffleRpcService != null) { + terminationFutures.add(shuffleRpcService.stopService()); + } + + if (exception != null) { + terminationFutures.add(FutureUtils.completedExceptionally(exception)); + } + + MetricUtils.stopMetricSystem(); + + return FutureUtils.completeAll(terminationFutures); + } + } + + @Override + public void onFatalError(Throwable exception) { + LOG.error("Fatal error occurred in the cluster entrypoint.", exception); + System.exit(RUNTIME_FAILURE_RETURN_CODE); + } + + // -------------------------------------------------- + // Internal methods + // -------------------------------------------------- + + public CompletableFuture getTerminationFuture() { + return terminationFuture; + } + + // -------------------------------------------------- + // Helper methods + // -------------------------------------------------- + + private CompletableFuture shutDownAsync(@Nullable String diagnostics) { + if (isShutDown.compareAndSet(false, true)) { + LOG.info("Shutting {} down. 
Diagnostics {}.", getClass().getSimpleName(), diagnostics); + + stopClusterServices() + .whenComplete( + (Void ignored2, Throwable serviceThrowable) -> { + if (serviceThrowable != null) { + terminationFuture.completeExceptionally(serviceThrowable); + } else { + terminationFuture.complete(null); + } + }); + } + + return terminationFuture; + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/metrics/ClusterMetrics.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/metrics/ClusterMetrics.java new file mode 100644 index 00000000..ee7cf5f2 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/metrics/ClusterMetrics.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.metrics; + +import com.alibaba.flink.shuffle.metrics.entry.MetricUtils; + +import com.alibaba.metrics.Gauge; + +import java.util.function.Supplier; + +/** Constants and util methods of CLUSTER metrics. */ +public class ClusterMetrics { + + // Group name + public static final String CLUSTER = "remote-shuffle.cluster"; + + // Number of available shuffle workers. + public static final String NUM_SHUFFLE_WORKERS = CLUSTER + ".num_shuffle_workers"; + + // Number of jobs under serving. + public static final String NUM_JOBS_SERVING = CLUSTER + ".num_jobs_serving"; + + public static void registerGaugeForNumShuffleWorkers(Supplier shuffleWorkerNum) { + MetricUtils.registerMetric( + CLUSTER, + NUM_SHUFFLE_WORKERS, + new Gauge() { + @Override + public Integer getValue() { + return shuffleWorkerNum.get(); + } + + @Override + public long lastUpdateTime() { + return System.currentTimeMillis(); + } + }); + } + + public static void registerGaugeForNumJobsServing(Supplier jobsServingNum) { + MetricUtils.registerMetric( + CLUSTER, + NUM_JOBS_SERVING, + new Gauge() { + @Override + public Integer getValue() { + return jobsServingNum.get(); + } + + @Override + public long lastUpdateTime() { + return System.currentTimeMillis(); + } + }); + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/registration/ConnectingConnection.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/registration/ConnectingConnection.java new file mode 100644 index 00000000..67d8725f --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/registration/ConnectingConnection.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
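For completeness, a tiny sketch of wiring the gauges above from a component that owns the measured collections, mirroring the registerMetrics() call in AssignmentTrackerImpl; the map fields here are placeholders.

import com.alibaba.flink.shuffle.coordinator.metrics.ClusterMetrics;

import java.util.HashMap;
import java.util.Map;

class MetricsWiringSketch {

    private final Map<String, Object> jobs = new HashMap<>();
    private final Map<String, Object> workers = new HashMap<>();

    MetricsWiringSketch() {
        // The gauges sample the suppliers on read, so they always reflect the
        // current sizes of the tracked collections.
        ClusterMetrics.registerGaugeForNumJobsServing(jobs::size);
        ClusterMetrics.registerGaugeForNumShuffleWorkers(workers::size);
    }
}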
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.registration; + +import com.alibaba.flink.shuffle.coordinator.manager.RegistrationSuccess; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcGateway; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcService; + +import org.slf4j.Logger; + +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.function.Function; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** A connection in the process of being established to one RPC address. */ +public class ConnectingConnection<G extends RemoteShuffleRpcGateway, S extends RegistrationSuccess> + extends RegisteredRpcConnection<UUID, G, S> { + + /** The name of the target. */ + private final String targetName; + + /** The class of the RPC gateway. */ + private final Class<G> targetType; + + /** The RPC service used to connect to the RPC address. */ + private final RemoteShuffleRpcService rpcService; + + /** The configuration for registration. */ + private final RetryingRegistrationConfiguration retryingRegistrationConfiguration; + + /** The listener notified when the connection succeeds. */ + private final RegistrationConnectionListener<ConnectingConnection<G, S>, S> + registrationListener; + + private final Function<G, CompletableFuture<S>> registrationFunction; + + public ConnectingConnection( + Logger log, + String targetName, + Class<G> targetType, + RemoteShuffleRpcService rpcService, + RetryingRegistrationConfiguration retryingRegistrationConfiguration, + String targetAddress, + UUID leaderId, + Executor executor, + RegistrationConnectionListener<ConnectingConnection<G, S>, S> registrationListener, + Function<G, CompletableFuture<S>> registrationFunction) { + + super(log, targetAddress, leaderId, executor); + this.targetName = checkNotNull(targetName); + this.targetType = checkNotNull(targetType); + this.rpcService = checkNotNull(rpcService); + this.retryingRegistrationConfiguration = checkNotNull(retryingRegistrationConfiguration); + this.registrationListener = checkNotNull(registrationListener); + this.registrationFunction = checkNotNull(registrationFunction); + } + + @Override + protected RetryingRegistration<UUID, G, S> generateRegistration() { + return new RpcTargetRegistration( + log, + rpcService, + targetName, + targetType, + getTargetAddress(), + getTargetLeaderId(), + retryingRegistrationConfiguration, + registrationFunction); + } + + @Override + protected void onRegistrationSuccess(S success) { + log.info("Successful registration at shuffle manager {}.", getTargetAddress()); + + registrationListener.onRegistrationSuccess(this, success); + } + + @Override + protected void onRegistrationFailure(Throwable failure) { + log.info("Failed to register at shuffle manager {}.", getTargetAddress(), failure); + + registrationListener.onRegistrationFailure(failure); + } + + // ------------------------------------------------------------------------ + // Utilities + // ------------------------------------------------------------------------ + + private class RpcTargetRegistration extends RetryingRegistration<UUID, G, S> { + + private final
Function<G, CompletableFuture<S>> registrationFunction; + + public RpcTargetRegistration( + Logger log, + RemoteShuffleRpcService rpcService, + String targetName, + Class<G> targetType, + String targetAddress, + UUID fencingToken, + RetryingRegistrationConfiguration retryingRegistrationConfiguration, + Function<G, CompletableFuture<S>> registrationFunction) { + + super( + log, + rpcService, + targetName, + targetType, + targetAddress, + fencingToken, + retryingRegistrationConfiguration); + this.registrationFunction = registrationFunction; + } + + @Override + protected CompletableFuture<RegistrationResponse> invokeRegistration( + G gateway, UUID fencingToken) throws Exception { + return registrationFunction.apply(gateway).thenApply(result -> result); + } + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/registration/EstablishedConnection.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/registration/EstablishedConnection.java new file mode 100644 index 00000000..efa543b9 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/registration/EstablishedConnection.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.registration; + +/** An established connection, bundling the gateway and the successful registration response. */ +public class EstablishedConnection<G, S> { + + /** The gateway of the connection. */ + private final G gateway; + + /** The response received when the connection succeeded. */ + private final S response; + + public EstablishedConnection(G gateway, S response) { + this.gateway = gateway; + this.response = response; + } + + public G getGateway() { + return gateway; + } + + public S getResponse() { + return response; + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/registration/RegisteredRpcConnection.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/registration/RegisteredRpcConnection.java new file mode 100644 index 00000000..fa450d14 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/registration/RegisteredRpcConnection.java @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.registration; + +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcGateway; + +import org.apache.commons.lang3.tuple.Pair; +import org.slf4j.Logger; + +import java.io.Serializable; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; + +/** + * This class contains the basis of an RPC connection from one component to another, for + * example the RPC connection from the ShuffleWorker to the ShuffleManager. This {@code + * RegisteredRpcConnection} implements the registration and obtains the target gateway. + * +
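* <p>A hedged usage sketch ({@code ShuffleManagerGateway} is a placeholder gateway type; {@link
+ * ConnectingConnection} in this package is one concrete subclass):
+ *
+ * <pre>{@code
+ * RegisteredRpcConnection<UUID, ShuffleManagerGateway, RegistrationSuccess> connection = ...;
+ * connection.start(); // resolves the target address and registers with retries
+ * if (connection.isConnected()) {
+ *     ShuffleManagerGateway gateway = connection.getTargetGateway();
+ * }
+ * }</pre>
+ * +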
* <p>
The registration gives access to a future that is completed upon successful registration. The + * RPC connection can be closed, for example when the target it tries to register at loses + * leader status. + * + * @param <F> The type of the fencing token. + * @param <G> The type of the gateway to connect to. + * @param <S> The type of the successful registration responses. + */ +public abstract class RegisteredRpcConnection< + F extends Serializable, + G extends RemoteShuffleRpcGateway, + S extends RegistrationResponse.Success> { + + @SuppressWarnings("rawtypes") + private static final AtomicReferenceFieldUpdater<RegisteredRpcConnection, RetryingRegistration> + REGISTRATION_UPDATER = + AtomicReferenceFieldUpdater.newUpdater( + RegisteredRpcConnection.class, + RetryingRegistration.class, + "pendingRegistration"); + + /** The logger for all log messages of this class. */ + protected final Logger log; + + /** The fencing token of the remote component. */ + private final F fencingToken; + + /** The target component address, for example the ShuffleManager address. */ + private final String targetAddress; + + /** + * Execution context to be used to execute the on complete action of the + * ShuffleManagerRegistration. + */ + private final Executor executor; + + /** The registration of this RPC connection. */ + private volatile RetryingRegistration<F, G, S> pendingRegistration; + + /** The gateway to register at; it is null until the registration is completed. */ + private volatile G targetGateway; + + /** Flag indicating that the RPC connection is closed. */ + private volatile boolean closed; + + // ------------------------------------------------------------------------ + + public RegisteredRpcConnection( + Logger log, String targetAddress, F fencingToken, Executor executor) { + this.log = checkNotNull(log); + this.targetAddress = checkNotNull(targetAddress); + this.fencingToken = checkNotNull(fencingToken); + this.executor = checkNotNull(executor); + } + + // ------------------------------------------------------------------------ + // Life cycle + // ------------------------------------------------------------------------ + + public void start() { + checkState(!closed, "The RPC connection is already closed"); + checkState( + !isConnected() && pendingRegistration == null, + "The RPC connection is already started"); + + final RetryingRegistration<F, G, S> newRegistration = createNewRegistration(); + + if (REGISTRATION_UPDATER.compareAndSet(this, null, newRegistration)) { + newRegistration.startRegistration(); + } else { + // concurrent start operation + newRegistration.cancel(); + } + } + + /** + * Tries to reconnect to the {@link #targetAddress} by cancelling the pending registration and + * starting a new pending registration.
+ * + * @return {@code false} if the connection has been closed or a concurrent modification has + * happened; otherwise {@code true} + */ + public boolean tryReconnect() { + checkState(isConnected(), "Cannot reconnect to an unknown destination."); + + if (closed) { + return false; + } else { + final RetryingRegistration<F, G, S> currentPendingRegistration = pendingRegistration; + + if (currentPendingRegistration != null) { + currentPendingRegistration.cancel(); + } + + final RetryingRegistration<F, G, S> newRegistration = createNewRegistration(); + + if (REGISTRATION_UPDATER.compareAndSet( + this, currentPendingRegistration, newRegistration)) { + newRegistration.startRegistration(); + } else { + // concurrent modification + newRegistration.cancel(); + return false; + } + + // double check for concurrent close operations + if (closed) { + newRegistration.cancel(); + + return false; + } else { + return true; + } + } + } + + /** + * This method generates a specific registration, for example the ShuffleWorker registration at + * the ShuffleManager. + */ + protected abstract RetryingRegistration<F, G, S> generateRegistration(); + + /** This method handles the registration response. */ + protected abstract void onRegistrationSuccess(S success); + + /** This method handles the registration failure. */ + protected abstract void onRegistrationFailure(Throwable failure); + + /** Closes the connection. */ + public void close() { + closed = true; + + // make sure we do not keep re-trying forever + if (pendingRegistration != null) { + pendingRegistration.cancel(); + } + } + + public boolean isClosed() { + return closed; + } + + // ------------------------------------------------------------------------ + // Properties + // ------------------------------------------------------------------------ + + public F getTargetLeaderId() { + return fencingToken; + } + + public String getTargetAddress() { + return targetAddress; + } + + /** Gets the registered gateway. This returns null until the registration is completed.
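+ *
+ * <p>Callers therefore typically guard the access, e.g. (sketch):
+ *
+ * <pre>{@code
+ * if (connection.isConnected()) {
+ *     G gateway = connection.getTargetGateway();
+ * }
+ * }</pre> +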
*/ + public G getTargetGateway() { + return targetGateway; + } + + public boolean isConnected() { + return targetGateway != null; + } + + // ------------------------------------------------------------------------ + + @Override + public String toString() { + String connectionInfo = + "(ADDRESS: " + targetAddress + " FENCINGTOKEN: " + fencingToken + ")"; + + if (isConnected()) { + connectionInfo = + "RPC connection to " + + targetGateway.getClass().getSimpleName() + + " " + + connectionInfo; + } else { + connectionInfo = "RPC connection to " + connectionInfo; + } + + if (isClosed()) { + connectionInfo += " is closed"; + } else if (isConnected()) { + connectionInfo += " is established"; + } else { + connectionInfo += " is connecting"; + } + + return connectionInfo; + } + + // ------------------------------------------------------------------------ + // Internal methods + // ------------------------------------------------------------------------ + + private RetryingRegistration<F, G, S> createNewRegistration() { + RetryingRegistration<F, G, S> newRegistration = checkNotNull(generateRegistration()); + + CompletableFuture<Pair<G, S>> future = newRegistration.getFuture(); + + future.whenCompleteAsync( + (Pair<G, S> result, Throwable failure) -> { + if (failure != null) { + if (failure instanceof CancellationException) { + // we ignore cancellation exceptions because they originate from + // cancelling the RetryingRegistration + log.debug( + "Retrying registration towards {} was cancelled.", + targetAddress); + } else { + // this future should only ever fail if there is a bug, not if the + // registration is declined + onRegistrationFailure(failure); + } + } else { + targetGateway = result.getLeft(); + onRegistrationSuccess(result.getRight()); + } + }, + executor); + + return newRegistration; + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/registration/RegistrationConnectionListener.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/registration/RegistrationConnectionListener.java new file mode 100644 index 00000000..03d7eb10 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/registration/RegistrationConnectionListener.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.registration; + +/** + * Classes which want to be notified about the registration result by the {@link + * RegisteredRpcConnection} have to implement this interface. + */ +public interface RegistrationConnectionListener< + T extends RegisteredRpcConnection<?, ?, S>, S extends RegistrationResponse.Success> { + + /** + * This method is called by the {@link RegisteredRpcConnection} when the registration + * succeeds.
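+ *
+ * <p>A hedged sketch of an implementation (the connection type and the stored field are
+ * illustrative placeholders):
+ *
+ * <pre>{@code
+ * public void onRegistrationSuccess(MyConnection connection, RegistrationSuccess success) {
+ *     this.established = new EstablishedConnection<>(connection.getTargetGateway(), success);
+ * }
+ * }</pre>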
+ * + * @param success The concrete response information for successful registration. + * @param connection The instance which established the connection + */ + void onRegistrationSuccess(T connection, S success); + + /** + * This method is called by the {@link RegisteredRpcConnection} when the registration fails. + * + * @param failure The exception which causes the registration failure. + */ + void onRegistrationFailure(Throwable failure); +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/registration/RegistrationResponse.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/registration/RegistrationResponse.java new file mode 100644 index 00000000..292980db --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/registration/RegistrationResponse.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.registration; + +import java.io.Serializable; + +/** Base class for responses given to registration attempts from {@link RetryingRegistration}. */ +public abstract class RegistrationResponse implements Serializable { + + private static final long serialVersionUID = -3764580247222612427L; + + // ---------------------------------------------------------------------------- + + /** + * Base class for a successful registration. Concrete registration implementations will + * typically extend this class to attach more information. + */ + public static class Success extends RegistrationResponse { + private static final long serialVersionUID = 1L; + + @Override + public String toString() { + return "Registration Successful"; + } + } + + // ---------------------------------------------------------------------------- + + /** A rejected (declined) registration. */ + public static final class Decline extends RegistrationResponse { + private static final long serialVersionUID = 1L; + + /** The rejection reason. */ + private final String reason; + + /** + * Creates a new rejection message. + * + * @param reason The reason for the rejection. + */ + public Decline(String reason) { + this.reason = reason != null ? reason : "(unknown)"; + } + + /** Gets the reason for the rejection. 
*/ + public String getReason() { + return reason; + } + + @Override + public String toString() { + return "Registration Declined (" + reason + ')'; + } + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/registration/RetryingRegistration.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/registration/RetryingRegistration.java new file mode 100644 index 00000000..b6142fbb --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/registration/RetryingRegistration.java @@ -0,0 +1,296 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.registration; + +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleFencedRpcGateway; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcGateway; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcService; + +import org.apache.commons.lang3.tuple.Pair; +import org.slf4j.Logger; + +import java.io.Serializable; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** + * This class implements the basis of registering one component at another component, for example + * registering the ShuffleWorker at the ShuffleManager. This {@code RetryingRegistration} implements + * both the initial address resolution and the retries-with-backoff strategy. + * + *
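* <p>A hedged sketch of the contract (the gateway and response type names are placeholders): the
+ * registration is started once, and the resulting (gateway, response) pair is observed through
+ * the completion future:
+ *
+ * <pre>{@code
+ * RetryingRegistration<UUID, ShuffleManagerGateway, RegistrationSuccess> registration = ...;
+ * registration.startRegistration();
+ * registration.getFuture().whenComplete((pair, failure) -> {
+ *     if (failure == null) {
+ *         ShuffleManagerGateway gateway = pair.getLeft();
+ *         RegistrationSuccess response = pair.getRight();
+ *     }
+ * });
+ * }</pre>
+ * +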
* <p>
The registration gives access to a future that is completed upon successful registration. The + * registration can be canceled, for example when the target it tries to register at loses + * leader status. + * + * @param <F> The type of the fencing token. + * @param <G> The type of the gateway to connect to. + * @param <S> The type of the successful registration responses. + */ +public abstract class RetryingRegistration< + F extends Serializable, + G extends RemoteShuffleRpcGateway, + S extends RegistrationResponse.Success> { + + // ------------------------------------------------------------------------ + // Fields + // ------------------------------------------------------------------------ + + private final Logger log; + + private final RemoteShuffleRpcService rpcService; + + private final String targetName; + + private final Class<G> targetType; + + private final String targetAddress; + + private final F fencingToken; + + private final CompletableFuture<Pair<G, S>> completionFuture; + + private final RetryingRegistrationConfiguration retryingRegistrationConfiguration; + + private volatile boolean canceled; + + // ------------------------------------------------------------------------ + + public RetryingRegistration( + Logger log, + RemoteShuffleRpcService rpcService, + String targetName, + Class<G> targetType, + String targetAddress, + F fencingToken, + RetryingRegistrationConfiguration retryingRegistrationConfiguration) { + + this.log = checkNotNull(log); + this.rpcService = checkNotNull(rpcService); + this.targetName = checkNotNull(targetName); + this.targetType = checkNotNull(targetType); + this.targetAddress = checkNotNull(targetAddress); + this.fencingToken = checkNotNull(fencingToken); + this.retryingRegistrationConfiguration = checkNotNull(retryingRegistrationConfiguration); + + this.completionFuture = new CompletableFuture<>(); + } + + // ------------------------------------------------------------------------ + // completion and cancellation + // ------------------------------------------------------------------------ + + public CompletableFuture<Pair<G, S>> getFuture() { + return completionFuture; + } + + /** Cancels the registration procedure. */ + public void cancel() { + canceled = true; + completionFuture.cancel(false); + } + + /** + * Checks if the registration was canceled. + * + * @return True if the registration was canceled, false otherwise. + */ + public boolean isCanceled() { + return canceled; + } + + // ------------------------------------------------------------------------ + // registration + // ------------------------------------------------------------------------ + + protected abstract CompletableFuture<RegistrationResponse> invokeRegistration( + G gateway, F fencingToken) throws Exception; + + /** + * This method resolves the target address to a callable gateway and starts the registration + * after that.
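+ *
+ * <p>The registration attempt itself is delegated to {@code invokeRegistration}; a hedged sketch
+ * of a concrete subclass (the gateway type and the RPC being called are placeholders):
+ *
+ * <pre>{@code
+ * protected CompletableFuture<RegistrationResponse> invokeRegistration(
+ *         ShuffleManagerGateway gateway, UUID leaderId) {
+ *     return gateway.registerWorker(workerDescription); // placeholder registration RPC
+ * }
+ * }</pre>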
+ */ + @SuppressWarnings("unchecked") + public void startRegistration() { + if (canceled) { + // we already got canceled + return; + } + + try { + // trigger resolution of the target address to a callable gateway + final CompletableFuture rpcGatewayFuture; + + if (RemoteShuffleFencedRpcGateway.class.isAssignableFrom(targetType)) { + rpcGatewayFuture = + (CompletableFuture) + rpcService.connectTo( + targetAddress, + fencingToken, + targetType.asSubclass(RemoteShuffleFencedRpcGateway.class)); + } else { + rpcGatewayFuture = rpcService.connectTo(targetAddress, targetType); + } + + // upon success, start the registration attempts + CompletableFuture rpcGatewayAcceptFuture = + rpcGatewayFuture.thenAcceptAsync( + (G rpcGateway) -> { + log.info( + "Resolved {} address, beginning registration.", targetName); + register(rpcGateway, 1); + }, + rpcService.getExecutor()); + + // upon failure, retry, unless this is cancelled + rpcGatewayAcceptFuture.whenCompleteAsync( + (Void v, Throwable failure) -> { + if (failure != null && !canceled) { + final Throwable strippedFailure = + ExceptionUtils.stripException( + failure, CompletionException.class); + if (log.isDebugEnabled()) { + log.debug( + "Could not resolve {} address {}, retrying in {} ms.", + targetName, + targetAddress, + retryingRegistrationConfiguration.getErrorDelayMillis(), + strippedFailure); + } else { + log.info( + "Could not resolve {} address {}, retrying in {} ms: {}", + targetName, + targetAddress, + retryingRegistrationConfiguration.getErrorDelayMillis(), + strippedFailure.getMessage()); + } + + startRegistrationLater( + retryingRegistrationConfiguration.getErrorDelayMillis()); + } + }, + rpcService.getExecutor()); + } catch (Throwable t) { + completionFuture.completeExceptionally(t); + cancel(); + } + } + + /** + * This method performs a registration attempt and triggers either a success notification or a + * retry, depending on the result. + */ + @SuppressWarnings("unchecked") + private void register(final G gateway, final int attempt) { + // eager check for canceling to avoid some unnecessary work + if (canceled) { + return; + } + + try { + log.debug("Registration at {} attempt {}.", targetName, attempt); + CompletableFuture registrationFuture = + invokeRegistration(gateway, fencingToken); + + // if the registration was successful, let the ShuffleWorker know + CompletableFuture registrationAcceptFuture = + registrationFuture.thenAcceptAsync( + (RegistrationResponse result) -> { + if (!isCanceled()) { + if (result instanceof RegistrationResponse.Success) { + // registration successful! 
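+ // the raw response is narrowed to the concrete success type and the
+ // (gateway, success) pair completes the registration future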
+ S success = (S) result; + completionFuture.complete(Pair.of(gateway, success)); + } else { + // registration refused or unknown + if (result instanceof RegistrationResponse.Decline) { + RegistrationResponse.Decline decline = + (RegistrationResponse.Decline) result; + log.info( + "Registration at {} was declined: {}.", + targetName, + decline.getReason()); + } else { + log.error( + "Received unknown response to registration attempt: {}.", + result); + } + + log.info( + "Pausing and re-attempting registration in {} ms.", + retryingRegistrationConfiguration + .getRefusedDelayMillis()); + registerLater( + gateway, + attempt + 1, + retryingRegistrationConfiguration + .getRefusedDelayMillis()); + } + } + }, + rpcService.getExecutor()); + + // upon failure, retry + registrationAcceptFuture.whenCompleteAsync( + (Void v, Throwable failure) -> { + if (failure != null && !isCanceled()) { + if (ExceptionUtils.stripException(failure, CompletionException.class) + instanceof TimeoutException) { + log.debug( + "Registration at {} ({}) attempt {} timed out.", + targetName, + targetAddress, + attempt); + register(gateway, attempt + 1); + } else { + // a serious failure occurred. we still should not give up, but keep + // trying + log.error( + "Registration at {} failed due to an error, pausing and " + + "re-attempting registration in {} ms.", + targetName, + retryingRegistrationConfiguration.getErrorDelayMillis(), + failure); + + registerLater( + gateway, + attempt + 1, + retryingRegistrationConfiguration.getErrorDelayMillis()); + } + } + }, + rpcService.getExecutor()); + } catch (Throwable t) { + completionFuture.completeExceptionally(t); + cancel(); + } + } + + private void registerLater(final G gateway, final int attempt, long delay) { + rpcService.scheduleRunnable(() -> register(gateway, attempt), delay, TimeUnit.MILLISECONDS); + } + + private void startRegistrationLater(final long delay) { + rpcService.scheduleRunnable(this::startRegistration, delay, TimeUnit.MILLISECONDS); + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/registration/RetryingRegistrationConfiguration.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/registration/RetryingRegistrationConfiguration.java new file mode 100644 index 00000000..1c318c73 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/registration/RetryingRegistrationConfiguration.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.coordinator.registration; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.core.config.ClusterOptions; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument; + +/** Configuration for the cluster components. */ +public class RetryingRegistrationConfiguration { + + private final long errorDelayMillis; + + private final long refusedDelayMillis; + + public RetryingRegistrationConfiguration(long errorDelayMillis, long refusedDelayMillis) { + checkArgument(errorDelayMillis >= 0, "delay on error must be non-negative"); + checkArgument( + refusedDelayMillis >= 0, "delay on refused registration must be non-negative"); + + this.errorDelayMillis = errorDelayMillis; + this.refusedDelayMillis = refusedDelayMillis; + } + + public long getErrorDelayMillis() { + return errorDelayMillis; + } + + public long getRefusedDelayMillis() { + return refusedDelayMillis; + } + + public static RetryingRegistrationConfiguration fromConfiguration( + final Configuration configuration) { + long errorDelayMillis = + configuration.getDuration(ClusterOptions.ERROR_REGISTRATION_DELAY).toMillis(); + long refusedDelayMillis = + configuration.getDuration(ClusterOptions.REFUSED_REGISTRATION_DELAY).toMillis(); + + return new RetryingRegistrationConfiguration(errorDelayMillis, refusedDelayMillis); + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/utils/ClusterEntrypointUtils.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/utils/ClusterEntrypointUtils.java new file mode 100644 index 00000000..1ccbafda --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/utils/ClusterEntrypointUtils.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.utils; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.utils.FatalErrorExitUtils; +import com.alibaba.flink.shuffle.core.utils.ConfigurationParserUtils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Utility class for running manager and workers. */ +public final class ClusterEntrypointUtils { + + private static final Logger LOG = LoggerFactory.getLogger(ClusterEntrypointUtils.class); + + public static final int STARTUP_FAILURE_RETURN_CODE = 1; + + private ClusterEntrypointUtils() { + throw new UnsupportedOperationException("This class should not be instantiated."); + } + + /** + * Parses passed String array using the parameter definitions of the passed {@code + * ParserResultFactory}. 
In case of a parsing error, the method will log the failure and terminate the process via + * {@code FatalErrorExitUtils}. + * + * @param args The String array that shall be parsed. + * @return The parsing result. + */ + public static Configuration parseParametersOrExit(String[] args) { + + try { + return ConfigurationParserUtils.loadConfiguration(args); + } catch (Exception e) { + LOG.error("Could not parse command line arguments {}.", args, e); + FatalErrorExitUtils.exitProcessIfNeeded(STARTUP_FAILURE_RETURN_CODE); + } + + return null; + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/utils/ConnectionUtils.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/utils/ConnectionUtils.java new file mode 100644 index 00000000..891397f6 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/utils/ConnectionUtils.java @@ -0,0 +1,539 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.utils; + +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalListener; +import com.alibaba.flink.shuffle.rpc.utils.AkkaRpcServiceUtils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.NetworkInterface; +import java.net.Socket; +import java.net.SocketAddress; +import java.net.UnknownHostException; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.Enumeration; +import java.util.List; + +/** + * Utilities to determine the network interface and address that should be used to bind the + * ShuffleWorker communication to. + * +
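* <p>Typical use is to probe which local address can reach the ShuffleManager before binding the
+ * worker's RPC endpoint; a hedged sketch (host name, port, and timeouts are placeholders):
+ *
+ * <pre>{@code
+ * InetSocketAddress managerAddress = new InetSocketAddress("shuffle-manager-host", 23123);
+ * InetAddress bindAddress = ConnectionUtils.findConnectingAddress(managerAddress, 120_000, 400);
+ * }</pre>
+ * +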
* <p>
Implementation note: This class uses {@code System.nanoTime()} to measure elapsed time, + * because that is not susceptible to clock changes. + * + *
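* <p>For example, the elapsed time below is derived as {@code (System.nanoTime() - startTimeNanos)
+ * / 1_000_000} milliseconds, which stays correct even if the wall clock is adjusted.
+ * +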
* <p>
This class is copied from Apache Flink (org.apache.flink.runtime.net.ConnectionUtils). + */ +public class ConnectionUtils { + + private static final Logger LOG = LoggerFactory.getLogger(ConnectionUtils.class); + + private static final long MIN_SLEEP_TIME = 50; + private static final long MAX_SLEEP_TIME = 20000; + + /** + * The states of address detection mechanism. There is only a state transition if the current + * state failed to determine the address. + */ + private enum AddressDetectionState { + /** Connect from interface returned by InetAddress.getLocalHost(). * */ + LOCAL_HOST(200), + /** Detect own IP address based on the target IP address. Look for common prefix */ + ADDRESS(50), + /** Try to connect on all Interfaces and all their addresses with a low timeout. */ + FAST_CONNECT(50), + /** Try to connect on all Interfaces and all their addresses with a long timeout. */ + SLOW_CONNECT(1000), + /** Choose any non-loopback address. */ + HEURISTIC(0); + + private final int timeout; + + AddressDetectionState(int timeout) { + this.timeout = timeout; + } + + public int getTimeout() { + return timeout; + } + } + + /** + * Finds the local network address from which this machine can connect to the target address. + * This method tries to establish a proper network connection to the given target, so it only + * succeeds if the target socket address actually accepts connections. The method tries various + * strategies multiple times and uses an exponential backoff timer between tries. + * + *
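* <p>Concretely, the pause between passes starts at {@code MIN_SLEEP_TIME} (50 ms) and doubles up
+ * to {@code MAX_SLEEP_TIME} (20 s), bounded by the remaining wait budget, i.e. roughly 50, 100,
+ * 200, 400, ... ms between successive passes.
+ * +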
* <p>
If no connection attempt was successful after the given maximum time, the method will + * choose some address based on heuristics (excluding link-local and loopback addresses). + * +
* <p>
This method will initially not log on info level (to not flood the log while the backoff + * time is still very low). It will start logging after a certain time has passed. + * + * @param targetAddress The address that the method tries to connect to. + * @param maxWaitMillis The maximum time that this method tries to connect, before falling back + * to the heuristics. + * @param startLoggingAfter The time after which the method will log on INFO level. + */ + public static InetAddress findConnectingAddress( + InetSocketAddress targetAddress, long maxWaitMillis, long startLoggingAfter) + throws IOException { + if (targetAddress == null) { + throw new NullPointerException("targetAddress must not be null"); + } + if (maxWaitMillis <= 0) { + throw new IllegalArgumentException("Max wait time must be positive"); + } + + final long startTimeNanos = System.nanoTime(); + + long currentSleepTime = MIN_SLEEP_TIME; + long elapsedTimeMillis = 0; + + final List<AddressDetectionState> strategies = + Collections.unmodifiableList( + Arrays.asList( + AddressDetectionState.LOCAL_HOST, + AddressDetectionState.ADDRESS, + AddressDetectionState.FAST_CONNECT, + AddressDetectionState.SLOW_CONNECT)); + + // loop while there is time left + while (elapsedTimeMillis < maxWaitMillis) { + boolean logging = elapsedTimeMillis >= startLoggingAfter; + if (logging) { + LOG.info("Trying to connect to " + targetAddress); + } + + // Try each strategy in order + for (AddressDetectionState strategy : strategies) { + InetAddress address = findAddressUsingStrategy(strategy, targetAddress, logging); + if (address != null) { + return address; + } + } + + // we have made a pass with all strategies over all interfaces + // sleep for a while before we make the next pass + elapsedTimeMillis = (System.nanoTime() - startTimeNanos) / 1_000_000; + + long toWait = Math.min(maxWaitMillis - elapsedTimeMillis, currentSleepTime); + if (toWait > 0) { + if (logging) { + LOG.info("Could not connect. Waiting for {} msecs before next attempt", toWait); + } else { + LOG.debug( + "Could not connect. Waiting for {} msecs before next attempt", toWait); + } + + try { + Thread.sleep(toWait); + } catch (InterruptedException e) { + throw new IOException("Connection attempts have been interrupted."); + } + } + + // increase the exponential backoff timer + currentSleepTime = Math.min(2 * currentSleepTime, MAX_SLEEP_TIME); + } + + // our attempts timed out, use the heuristic fallback + LOG.warn( + "Could not connect to {}. Selecting a local address using heuristics.", + targetAddress); + InetAddress heuristic = + findAddressUsingStrategy(AddressDetectionState.HEURISTIC, targetAddress, true); + if (heuristic != null) { + return heuristic; + } else { + LOG.warn( + "Could not find any IPv4 address that is not loopback or link-local. Using localhost address."); + return InetAddress.getLocalHost(); + } + } + + /** + * This utility method tries to connect to the ShuffleManager using the InetAddress returned by + * InetAddress.getLocalHost(). The purpose of the utility is to have a final try connecting to + * the target address using the LocalHost before using the address returned. We do a second try + * because the ShuffleManager might have been unavailable during the first check.
@param preliminaryResult The address detected by the heuristic + * @return either the preliminaryResult or the address returned by InetAddress.getLocalHost() + * (if we are able to connect to targetAddress from there) + */ + private static InetAddress tryLocalHostBeforeReturning( + InetAddress preliminaryResult, SocketAddress targetAddress, boolean logging) + throws IOException { + + InetAddress localhostName = InetAddress.getLocalHost(); + + if (preliminaryResult.equals(localhostName)) { + // preliminary result is equal to the local host name + return preliminaryResult; + } else if (tryToConnect( + localhostName, + targetAddress, + AddressDetectionState.SLOW_CONNECT.getTimeout(), + logging)) { + // success, we were able to use local host to connect + LOG.debug( + "Preferring {} (InetAddress.getLocalHost()) for local bind point over previous candidate {}", + localhostName, + preliminaryResult); + return localhostName; + } else { + // we have to make the preliminary result the final result + return preliminaryResult; + } + } + + /** + * Tries to find a local address which allows us to connect to the targetAddress using the given + * strategy. + * + * @param strategy Depending on the strategy, the method will enumerate all interfaces, trying + * to connect to the target address + * @param targetAddress The address we try to connect to + * @param logging Boolean indicating the logging verbosity + * @return null if we could not find an address using this strategy, otherwise the local + * address. + * @throws IOException if an I/O error occurs while probing the network interfaces. + */ + private static InetAddress findAddressUsingStrategy( + AddressDetectionState strategy, InetSocketAddress targetAddress, boolean logging) + throws IOException { + // try LOCAL_HOST strategy independent of the network interfaces + if (strategy == AddressDetectionState.LOCAL_HOST) { + InetAddress localhostName; + try { + localhostName = InetAddress.getLocalHost(); + } catch (UnknownHostException uhe) { + LOG.warn("Could not resolve local hostname to an IP address: {}", uhe.getMessage()); + return null; + } + + if (tryToConnect(localhostName, targetAddress, strategy.getTimeout(), logging)) { + LOG.debug( + "Using InetAddress.getLocalHost() immediately for the connecting address"); + // Here, we are not calling tryLocalHostBeforeReturning() because it is the + // LOCAL_HOST strategy + return localhostName; + } else { + return null; + } + } + + final InetAddress address = targetAddress.getAddress(); + if (address == null) { + return null; + } + final byte[] targetAddressBytes = address.getAddress(); + + // for each network interface + Enumeration<NetworkInterface> e = NetworkInterface.getNetworkInterfaces(); + while (e.hasMoreElements()) { + + NetworkInterface netInterface = e.nextElement(); + + // for each address of the network interface + Enumeration<InetAddress> ee = netInterface.getInetAddresses(); + while (ee.hasMoreElements()) { + InetAddress interfaceAddress = ee.nextElement(); + + switch (strategy) { + case ADDRESS: + if (hasCommonPrefix(targetAddressBytes, interfaceAddress.getAddress())) { + LOG.debug( + "Target address {} and local address {} share prefix - trying to connect.", + targetAddress, + interfaceAddress); + + if (tryToConnect( + interfaceAddress, + targetAddress, + strategy.getTimeout(), + logging)) { + return tryLocalHostBeforeReturning( + interfaceAddress, targetAddress, logging); + } + } + break; + + case FAST_CONNECT: + case SLOW_CONNECT: + LOG.debug( + "Trying to connect to {} from local address {} with timeout {}", + targetAddress, + interfaceAddress, +
strategy.getTimeout()); + + if (tryToConnect( + interfaceAddress, targetAddress, strategy.getTimeout(), logging)) { + return tryLocalHostBeforeReturning( + interfaceAddress, targetAddress, logging); + } + break; + + case HEURISTIC: + if (LOG.isDebugEnabled()) { + LOG.debug( + "Choosing InetAddress.getLocalHost() address as a heuristic."); + } + + return InetAddress.getLocalHost(); + + default: + throw new RuntimeException("Unsupported strategy: " + strategy); + } + } // end for each address of the interface + } // end for each interface + + return null; + } + + /** + * Checks if two addresses have a common prefix (first 2 bytes). Example: 192.168.?.?. This also + * works with IPv6, but probably accepts too many addresses. + */ + private static boolean hasCommonPrefix(byte[] address, byte[] address2) { + return address[0] == address2[0] && address[1] == address2[1]; + } + + /** + * @param fromAddress The address to connect from. + * @param toSocket The socket address to connect to. + * @param timeout The timeout for the connection. + * @param logFailed Flag to indicate whether to log failed attempts on info level (failed + * attempts are always logged on DEBUG level). + * @return True, if the connection was successful, false otherwise. + * @throws IOException Thrown if the socket cleanup fails. + */ + private static boolean tryToConnect( + InetAddress fromAddress, SocketAddress toSocket, int timeout, boolean logFailed) + throws IOException { + if (LOG.isDebugEnabled()) { + LOG.debug( + "Trying to connect to (" + + toSocket + + ") from local address " + + fromAddress + + " with timeout " + + timeout); + } + try (Socket socket = new Socket()) { + // port 0 = let the OS choose the port on this machine + SocketAddress bindP = new InetSocketAddress(fromAddress, 0); + socket.bind(bindP); + socket.connect(toSocket, timeout); + return true; + } catch (Exception ex) { + String message = + "Failed to connect from address '" + fromAddress + "': " + ex.getMessage(); + if (LOG.isDebugEnabled()) { + LOG.debug(message, ex); + } else if (logFailed) { + LOG.info(message); + } + return false; + } + } + + /** + * A {@link LeaderRetrievalListener} that allows retrieving an {@link InetAddress} for the + * current leader.
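+ *
+ * <p>A hedged usage sketch (the leader retrieval service wiring is assumed, not shown here):
+ *
+ * <pre>{@code
+ * LeaderConnectingAddressListener listener = new LeaderConnectingAddressListener();
+ * leaderRetrievalService.start(listener); // feeds leader addresses via notifyLeaderAddress()
+ * InetAddress bindAddress = listener.findConnectingAddress(Duration.ofMinutes(2));
+ * }</pre> +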
+ */ + public static class LeaderConnectingAddressListener implements LeaderRetrievalListener { + + private static final Duration defaultLoggingDelay = Duration.ofMillis(400); + + private enum LeaderRetrievalState { + NOT_RETRIEVED, + RETRIEVED, + NEWLY_RETRIEVED + } + + private final Object retrievalLock = new Object(); + + private String akkaURL; + private LeaderRetrievalState retrievalState = LeaderRetrievalState.NOT_RETRIEVED; + private Exception exception; + + public InetAddress findConnectingAddress(Duration timeout) throws Exception { + return findConnectingAddress(timeout, defaultLoggingDelay); + } + + public InetAddress findConnectingAddress(Duration timeout, Duration startLoggingAfter) + throws Exception { + + final long startTimeNanos = System.nanoTime(); + long currentSleepTime = MIN_SLEEP_TIME; + long elapsedTimeMillis = 0; + InetSocketAddress targetAddress = null; + + try { + while (elapsedTimeMillis < timeout.toMillis()) { + + long maxTimeout = timeout.toMillis() - elapsedTimeMillis; + + synchronized (retrievalLock) { + if (exception != null) { + throw exception; + } + + if (retrievalState == LeaderRetrievalState.NOT_RETRIEVED) { + try { + retrievalLock.wait(maxTimeout); + } catch (InterruptedException e) { + throw new Exception( + "Finding connecting address was interrupted" + + "while waiting for the leader retrieval."); + } + } else if (retrievalState == LeaderRetrievalState.NEWLY_RETRIEVED) { + targetAddress = + AkkaRpcServiceUtils.getInetSocketAddressFromAkkaURL(akkaURL); + + LOG.debug( + "Retrieved new target address {} for akka URL {}.", + targetAddress, + akkaURL); + + retrievalState = LeaderRetrievalState.RETRIEVED; + + currentSleepTime = MIN_SLEEP_TIME; + } else { + currentSleepTime = Math.min(2 * currentSleepTime, MAX_SLEEP_TIME); + } + } + + if (targetAddress != null) { + AddressDetectionState strategy = AddressDetectionState.LOCAL_HOST; + + boolean logging = elapsedTimeMillis >= startLoggingAfter.toMillis(); + if (logging) { + LOG.info("Trying to connect to address {}", targetAddress); + } + + do { + InetAddress address = + findAddressUsingStrategy(strategy, targetAddress, logging); + if (address != null) { + return address; + } + + // pick the next strategy + switch (strategy) { + case LOCAL_HOST: + strategy = AddressDetectionState.ADDRESS; + break; + case ADDRESS: + strategy = AddressDetectionState.FAST_CONNECT; + break; + case FAST_CONNECT: + strategy = AddressDetectionState.SLOW_CONNECT; + break; + case SLOW_CONNECT: + strategy = null; + break; + default: + throw new RuntimeException("Unsupported strategy: " + strategy); + } + } while (strategy != null); + } + + elapsedTimeMillis = (System.nanoTime() - startTimeNanos) / 1_000_000; + + long timeToWait = + Math.min( + Math.max(timeout.toMillis() - elapsedTimeMillis, 0), + currentSleepTime); + + if (timeToWait > 0) { + synchronized (retrievalLock) { + try { + retrievalLock.wait(timeToWait); + } catch (InterruptedException e) { + throw new Exception( + "Finding connecting address was interrupted while pausing."); + } + } + + elapsedTimeMillis = (System.nanoTime() - startTimeNanos) / 1_000_000; + } + } + + InetAddress heuristic = null; + + if (targetAddress != null) { + LOG.warn( + "Could not connect to {}. Selecting a local address using heuristics.", + targetAddress); + heuristic = + findAddressUsingStrategy( + AddressDetectionState.HEURISTIC, targetAddress, true); + } + + if (heuristic != null) { + return heuristic; + } else { + LOG.warn( + "Could not find any IPv4 address that is not loopback or link-local. 
" + + "Using localhost address."); + return InetAddress.getLocalHost(); + } + } catch (Exception e) { + throw new Exception( + String.format( + "Could not retrieve the connecting address to the current leader " + + "with the akka URL %s .", + akkaURL), + e); + } + } + + @Override + public void notifyLeaderAddress(LeaderInformation leaderInfo) { + String leaderAddress = leaderInfo.getLeaderAddress(); + if (leaderAddress != null && !leaderAddress.isEmpty()) { + synchronized (retrievalLock) { + akkaURL = leaderAddress; + retrievalState = + LeaderConnectingAddressListener.LeaderRetrievalState.NEWLY_RETRIEVED; + retrievalLock.notifyAll(); + } + } + } + + @Override + public void handleError(Exception exception) { + synchronized (retrievalLock) { + this.exception = exception; + retrievalLock.notifyAll(); + } + } + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/utils/EnvironmentInformation.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/utils/EnvironmentInformation.java new file mode 100644 index 00000000..552b5373 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/utils/EnvironmentInformation.java @@ -0,0 +1,459 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.utils; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.common.utils.Hardware; +import com.alibaba.flink.shuffle.common.utils.OperatingSystem; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.lang.management.ManagementFactory; +import java.lang.management.RuntimeMXBean; +import java.lang.reflect.Method; +import java.time.Instant; +import java.time.ZoneId; +import java.time.format.DateTimeFormatter; +import java.time.format.DateTimeParseException; +import java.util.List; +import java.util.Properties; + +/** + * Utility class that gives access to the execution environment of the JVM, like the executing user, + * startup options, or the JVM version. + */ +public class EnvironmentInformation { + + private static final Logger LOG = LoggerFactory.getLogger(EnvironmentInformation.class); + + public static final String UNKNOWN = ""; + + // the keys whose values should be hidden + private static final String[] SENSITIVE_KEYS = + new String[] {"password", "secret", "fs.azure.account.key", "apikey"}; + + // the hidden content to be displayed + public static final String HIDDEN_CONTENT = "******"; + + /** + * Returns the version of the code as String. + * + * @return The project version string. 
+ */ + public static String getVersion() { + return getVersionsInstance().projectVersion; + } + + /** @return The Instant this version of the software was built. */ + public static Instant getBuildTime() { + return getVersionsInstance().gitBuildTime; + } + + /** + * @return The Instant this version of the software was built as a String using the + * Europe/Berlin timezone. + */ + public static String getBuildTimeString() { + return getVersionsInstance().gitBuildTimeStr; + } + + /** @return The last known commit id of this version of the software. */ + public static String getGitCommitId() { + return getVersionsInstance().gitCommitId; + } + + /** @return The last known abbreviated commit id of this version of the software. */ + public static String getGitCommitIdAbbrev() { + return getVersionsInstance().gitCommitIdAbbrev; + } + + /** @return The Instant of the last commit of this code. */ + public static Instant getGitCommitTime() { + return getVersionsInstance().gitCommitTime; + } + + /** + * @return The Instant of the last commit of this code as a String using the Europe/Berlin + * timezone. + */ + public static String getGitCommitTimeString() { + return getVersionsInstance().gitCommitTimeStr; + } + + /** + * Returns the code revision (commit and commit date) of the remote shuffle implementation, as + * generated by the Maven builds. + * + * @return The code revision. + */ + public static RevisionInformation getRevisionInformation() { + return new RevisionInformation(getGitCommitIdAbbrev(), getGitCommitTimeString()); + } + + private static final class Versions { + private static final Instant DEFAULT_TIME_INSTANT = Instant.EPOCH; + private static final String DEFAULT_TIME_STRING = "1970-01-01T00:00:00+0000"; + private static final String UNKNOWN_COMMIT_ID = "DecafC0ffeeD0d0F00d"; + private static final String UNKNOWN_COMMIT_ID_ABBREV = "DeadD0d0"; + private String projectVersion = UNKNOWN; + private Instant gitBuildTime = DEFAULT_TIME_INSTANT; + private String gitBuildTimeStr = DEFAULT_TIME_STRING; + private String gitCommitId = UNKNOWN_COMMIT_ID; + private String gitCommitIdAbbrev = UNKNOWN_COMMIT_ID_ABBREV; + private Instant gitCommitTime = DEFAULT_TIME_INSTANT; + private String gitCommitTimeStr = DEFAULT_TIME_STRING; + + private static final String PROP_FILE = ".coordinator.version.properties"; + + private static final String FAIL_MESSAGE = + "The file " + + PROP_FILE + + " has not been generated correctly. 
You MUST run 'mvn generate-sources' in the shuffle-coordinator module."; + + private String getProperty(Properties properties, String key, String defaultValue) { + String value = properties.getProperty(key); + if (value == null || value.charAt(0) == '$') { + return defaultValue; + } + return value; + } + + public Versions() { + ClassLoader classLoader = EnvironmentInformation.class.getClassLoader(); + try (InputStream propFile = classLoader.getResourceAsStream(PROP_FILE)) { + if (propFile != null) { + Properties properties = new Properties(); + properties.load(propFile); + + projectVersion = getProperty(properties, "project.version", UNKNOWN); + + gitCommitId = getProperty(properties, "git.commit.id", UNKNOWN_COMMIT_ID); + gitCommitIdAbbrev = + getProperty( + properties, "git.commit.id.abbrev", UNKNOWN_COMMIT_ID_ABBREV); + + // This is to reliably parse the datetime format configured in the + // git-commit-id-plugin + DateTimeFormatter gitDateTimeFormatter = + DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ssZ"); + + DateTimeFormatter dateTime = + DateTimeFormatter.ISO_OFFSET_DATE_TIME.withZone(ZoneId.of("UTC")); + + try { + String propGitCommitTime = + getProperty(properties, "git.commit.time", DEFAULT_TIME_STRING); + gitCommitTime = + gitDateTimeFormatter.parse(propGitCommitTime, Instant::from); + gitCommitTimeStr = dateTime.format(gitCommitTime); + + String propGitBuildTime = + getProperty(properties, "git.build.time", DEFAULT_TIME_STRING); + gitBuildTime = gitDateTimeFormatter.parse(propGitBuildTime, Instant::from); + gitBuildTimeStr = dateTime.format(gitBuildTime); + } catch (DateTimeParseException dtpe) { + LOG.error("{} : {}", FAIL_MESSAGE, dtpe); + throw new IllegalStateException(FAIL_MESSAGE); + } + } + } catch (IOException ioe) { + LOG.info( + "Cannot determine code revision: Unable to read version property file.: {}", + ioe.getMessage()); + } + } + } + + private static final class VersionsHolder { + static final Versions INSTANCE = new Versions(); + } + + private static Versions getVersionsInstance() { + return VersionsHolder.INSTANCE; + } + + /** + * The maximum JVM heap size, in bytes. + * + *
<p>
This method uses the -Xmx value of the JVM, if set. If not set, it returns (as a + * heuristic) 1/4th of the physical memory size. + * + * @return The maximum JVM heap size, in bytes. + */ + public static long getMaxJvmHeapMemory() { + final long maxMemory = Runtime.getRuntime().maxMemory(); + if (maxMemory != Long.MAX_VALUE) { + // we have the proper max memory + return maxMemory; + } else { + // max JVM heap size is not set - use the heuristic to use 1/4th of the physical memory + final long physicalMemory = Hardware.getSizeOfPhysicalMemory(); + if (physicalMemory != -1) { + // got proper value for physical memory + return physicalMemory / 4; + } else { + throw new RuntimeException( + "Could not determine the amount of free memory.\n" + + "Please set the maximum memory for the JVM, e.g. -Xmx512M for 512 megabytes."); + } + } + } + + /** + * Gets an estimate of the size of the free heap memory. + * + *
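+     * <p>Illustrative usage (hypothetical caller), e.g. before reserving large buffers:
+     * {@code long freeBytes = EnvironmentInformation.getSizeOfFreeHeapMemoryWithDefrag();}
+     *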
<p>
NOTE: This method is heavy-weight. It triggers a garbage collection to reduce + * fragmentation and get a better estimate at the size of free memory. It is typically more + * accurate than the plain version {@link #getSizeOfFreeHeapMemory()}. + * + * @return An estimate of the size of the free heap memory, in bytes. + */ + public static long getSizeOfFreeHeapMemoryWithDefrag() { + // trigger a garbage collection, to reduce fragmentation + System.gc(); + + return getSizeOfFreeHeapMemory(); + } + + /** + * Gets an estimate of the size of the free heap memory. The estimate may vary, depending on the + * current level of memory fragmentation and the number of dead objects. For a better (but more + * heavy-weight) estimate, use {@link #getSizeOfFreeHeapMemoryWithDefrag()}. + * + * @return An estimate of the size of the free heap memory, in bytes. + */ + public static long getSizeOfFreeHeapMemory() { + Runtime r = Runtime.getRuntime(); + return getMaxJvmHeapMemory() - r.totalMemory() + r.freeMemory(); + } + + /** + * Gets the version of the JVM in the form "VM_Name - Vendor - Spec/Version". + * + * @return The JVM version. + */ + public static String getJvmVersion() { + try { + final RuntimeMXBean bean = ManagementFactory.getRuntimeMXBean(); + return bean.getVmName() + + " - " + + bean.getVmVendor() + + " - " + + bean.getSpecVersion() + + '/' + + bean.getVmVersion(); + } catch (Throwable t) { + return UNKNOWN; + } + } + + /** + * Gets the system parameters and environment parameters that were passed to the JVM on startup. + * + * @return The options passed to the JVM on startup. + */ + public static String getJvmStartupOptions() { + try { + final RuntimeMXBean bean = ManagementFactory.getRuntimeMXBean(); + final StringBuilder bld = new StringBuilder(); + + for (String s : bean.getInputArguments()) { + bld.append(s).append(' '); + } + + return bld.toString(); + } catch (Throwable t) { + return UNKNOWN; + } + } + + /** + * Gets the system parameters and environment parameters that were passed to the JVM on startup. + * + * @return The options passed to the JVM on startup. + */ + public static String[] getJvmStartupOptionsArray() { + try { + RuntimeMXBean bean = ManagementFactory.getRuntimeMXBean(); + List options = bean.getInputArguments(); + return options.toArray(new String[options.size()]); + } catch (Throwable t) { + return new String[0]; + } + } + + /** + * Gets the directory for temporary files, as returned by the JVM system property + * "java.io.tmpdir". + * + * @return The directory for temporary files. + */ + public static String getTemporaryFileDirectory() { + return System.getProperty("java.io.tmpdir"); + } + + /** + * Tries to retrieve the maximum number of open file handles. This method will only work on + * UNIX-based operating systems with Sun/Oracle Java versions. + * + *
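+     * <p>Illustrative usage (hypothetical caller; a negative result means the limit is unknown):
+     * {@code long limit = EnvironmentInformation.getOpenFileHandlesLimit();}
+     *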
<p>
If the number of max open file handles cannot be determined, this method returns {@code + * -1}. + * + * @return The limit of open file handles, or {@code -1}, if the limit could not be determined. + */ + public static long getOpenFileHandlesLimit() { + if (OperatingSystem + .isWindows()) { // getMaxFileDescriptorCount method is not available on Windows + return -1L; + } + Class sunBeanClass; + try { + sunBeanClass = Class.forName("com.sun.management.UnixOperatingSystemMXBean"); + } catch (ClassNotFoundException e) { + return -1L; + } + + try { + Method fhLimitMethod = sunBeanClass.getMethod("getMaxFileDescriptorCount"); + Object result = fhLimitMethod.invoke(ManagementFactory.getOperatingSystemMXBean()); + return (Long) result; + } catch (Throwable t) { + LOG.warn("Unexpected error when accessing file handle limit", t); + return -1L; + } + } + + /** + * Logs information about the environment, like code revision, current user, Java version, and + * JVM parameters. + * + * @param log The logger to log the information to. + * @param componentName The component name to mention in the log. + * @param commandLineArgs The arguments accompanying the starting the component. + */ + public static void logEnvironmentInfo( + Logger log, String componentName, String[] commandLineArgs) { + if (log.isInfoEnabled()) { + RevisionInformation rev = getRevisionInformation(); + String version = getVersion(); + + String jvmVersion = getJvmVersion(); + String[] options = getJvmStartupOptionsArray(); + + String javaHome = System.getenv("JAVA_HOME"); + + String inheritedLogs = System.getenv("SHUFFLE_INHERITED_LOGS"); + + long maxHeapMegabytes = getMaxJvmHeapMemory() >>> 20; + + if (inheritedLogs != null) { + log.info( + "--------------------------------------------------------------------------------"); + log.info(" Preconfiguration: "); + log.info(inheritedLogs); + } + + log.info( + "--------------------------------------------------------------------------------"); + log.info( + " Starting " + + componentName + + " (Version: " + + version + + ", " + + "Rev:" + + rev.commitId + + ", " + + "Date:" + + rev.commitDate + + ")"); + log.info(" OS current user: " + System.getProperty("user.name")); + log.info(" JVM: " + jvmVersion); + log.info(" Maximum heap size: " + maxHeapMegabytes + " MiBytes"); + log.info(" JAVA_HOME: " + (javaHome == null ? "(not set)" : javaHome)); + + if (options.length == 0) { + log.info(" JVM Options: (none)"); + } else { + log.info(" JVM Options:"); + for (String s : options) { + log.info(" " + s); + } + } + + if (commandLineArgs == null || commandLineArgs.length == 0) { + log.info(" Program Arguments: (none)"); + } else { + log.info(" Program Arguments:"); + for (String s : commandLineArgs) { + if (isSensitive(s)) { + log.info(" " + HIDDEN_CONTENT + " (sensitive information)"); + } else { + log.info(" " + s); + } + } + } + + log.info(" Classpath: " + System.getProperty("java.class.path")); + + log.info( + "--------------------------------------------------------------------------------"); + } + } + + public static boolean isSensitive(String key) { + CommonUtils.checkArgument(key != null, "Must be not null."); + + final String keyInLower = key.toLowerCase(); + for (String hideKey : SENSITIVE_KEYS) { + if (keyInLower.length() >= hideKey.length() && keyInLower.contains(hideKey)) { + return true; + } + } + return false; + } + + // -------------------------------------------------------------------------------------------- + + /** Don't instantiate this class. 
*/ + private EnvironmentInformation() {} + + // -------------------------------------------------------------------------------------------- + + /** + * Revision information encapsulates information about the source code revision of the flink + * remote shuffle code. + */ + public static class RevisionInformation { + + /** The git commit id (hash). */ + public final String commitId; + + /** The git commit date. */ + public final String commitDate; + + public RevisionInformation(String commitId, String commitDate) { + this.commitId = commitId; + this.commitDate = commitDate; + } + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/utils/LeaderConnectionInfo.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/utils/LeaderConnectionInfo.java new file mode 100644 index 00000000..db72eb01 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/utils/LeaderConnectionInfo.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.utils; + +import java.util.UUID; + +/** + * Wrapper class for a pair of connection address and leader session ID. + * + *
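+ * <p>Illustrative construction (the akka-style address is an example only):
+ * {@code new LeaderConnectionInfo(leaderSessionId, "akka.tcp://manager@host:1234/user/rpc")}
+ *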
<p>
This class is copied from Apache Flink (org.apache.flink.runtime.util.LeaderConnectionInfo). + */ +public class LeaderConnectionInfo { + + private final UUID leaderSessionId; + + private final String address; + + public LeaderConnectionInfo(UUID leaderSessionId, String address) { + this.leaderSessionId = leaderSessionId; + this.address = address; + } + + public UUID getLeaderSessionId() { + return leaderSessionId; + } + + public String getAddress() { + return address; + } + + @Override + public String toString() { + return "LeaderConnectionInfo{" + + "leaderSessionId=" + + leaderSessionId + + ", address='" + + address + + '\'' + + '}'; + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/utils/LeaderRetrievalUtils.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/utils/LeaderRetrievalUtils.java new file mode 100644 index 00000000..6f346cb8 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/utils/LeaderRetrievalUtils.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.utils; + +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalListener; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalService; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.InetAddress; +import java.time.Duration; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +/** + * Utility class to work with {@link LeaderRetrievalService} class. + * + *
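+ * <p>Illustrative usage with a hypothetical {@code service}:
+ * {@code LeaderRetrievalUtils.retrieveLeaderConnectionInfo(service, Duration.ofSeconds(30))}
+ *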
<p>
This class is copied from Apache Flink (org.apache.flink.runtime.util.LeaderRetrievalUtils). + */ +public class LeaderRetrievalUtils { + + private static final Logger LOG = LoggerFactory.getLogger(LeaderRetrievalUtils.class); + + /** + * Retrieves the leader akka url and the current leader session ID. The values are stored in a + * {@link LeaderConnectionInfo} instance. + * + * @param leaderRetrievalService Leader retrieval service to retrieve the leader connection + * information + * @param timeout Timeout when to give up looking for the leader + * @return LeaderConnectionInfo containing the leader's akka URL and the current leader session + * ID + */ + public static LeaderConnectionInfo retrieveLeaderConnectionInfo( + LeaderRetrievalService leaderRetrievalService, Duration timeout) throws Exception { + + LeaderConnectionInfoListener listener = + new LeaderRetrievalUtils.LeaderConnectionInfoListener(); + + try { + leaderRetrievalService.start(listener); + + return listener.getLeaderConnectionInfoFuture() + .get(timeout.toMillis(), TimeUnit.MILLISECONDS); + } catch (Exception e) { + throw new Exception( + "Could not retrieve the leader address and leader " + "session ID.", e); + } finally { + try { + leaderRetrievalService.stop(); + } catch (Exception fe) { + LOG.warn("Could not stop the leader retrieval service.", fe); + } + } + } + + public static InetAddress findConnectingAddress( + LeaderRetrievalService leaderRetrievalService, Duration timeout) throws Exception { + + ConnectionUtils.LeaderConnectingAddressListener listener = + new ConnectionUtils.LeaderConnectingAddressListener(); + + try { + leaderRetrievalService.start(listener); + + LOG.info( + "Trying to select the network interface and address to use " + + "by connecting to the leading ShuffleManager."); + + LOG.info( + "ShuffleWorker will try to connect for " + + timeout + + " before falling back to heuristics"); + + return listener.findConnectingAddress(timeout); + } catch (Exception e) { + throw new Exception( + "Could not find the connecting address by connecting to the current leader.", + e); + } finally { + try { + leaderRetrievalService.stop(); + } catch (Exception fe) { + LOG.warn("Could not stop the leader retrieval service.", fe); + } + } + } + + /** + * Helper class which is used by the retrieveLeaderConnectionInfo method to retrieve the + * leader's akka URL and the current leader session ID. + */ + public static class LeaderConnectionInfoListener implements LeaderRetrievalListener { + private final CompletableFuture connectionInfoFuture = + new CompletableFuture<>(); + + public CompletableFuture getLeaderConnectionInfoFuture() { + return connectionInfoFuture; + } + + @Override + public void notifyLeaderAddress(LeaderInformation leaderInfo) { + String leaderAddress = leaderInfo.getLeaderAddress(); + UUID leaderSessionID = leaderInfo.getLeaderSessionID(); + if (leaderAddress != null + && !leaderAddress.equals("") + && !connectionInfoFuture.isDone()) { + final LeaderConnectionInfo leaderConnectionInfo = + new LeaderConnectionInfo(leaderSessionID, leaderAddress); + connectionInfoFuture.complete(leaderConnectionInfo); + } + } + + @Override + public void handleError(Exception exception) { + connectionInfoFuture.completeExceptionally(exception); + } + } + + // ------------------------------------------------------------------------ + + /** Private constructor to prevent instantiation. 
*/ + private LeaderRetrievalUtils() { + throw new RuntimeException(); + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/HostBindPolicy.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/HostBindPolicy.java new file mode 100644 index 00000000..88e455b4 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/HostBindPolicy.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.worker; + +/** A host binding address mechanism policy. */ +enum HostBindPolicy { + NAME, + IP; + + public static HostBindPolicy fromString(String configValue) { + try { + return HostBindPolicy.valueOf(configValue.toUpperCase()); + } catch (IllegalArgumentException ex) { + throw new IllegalArgumentException("Unknown host bind policy: " + configValue); + } + } + + @Override + public String toString() { + return name().toLowerCase(); + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorker.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorker.java new file mode 100644 index 00000000..112e25c1 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorker.java @@ -0,0 +1,578 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.coordinator.worker; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.common.handler.FatalErrorHandler; +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; +import com.alibaba.flink.shuffle.common.utils.FutureUtils; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatListener; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatManager; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatServices; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatTarget; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalListener; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalService; +import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionCoordinate; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleManagerGateway; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleWorkerRegistration; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleWorkerRegistrationSuccess; +import com.alibaba.flink.shuffle.coordinator.manager.WorkerToManagerHeartbeatPayload; +import com.alibaba.flink.shuffle.coordinator.registration.ConnectingConnection; +import com.alibaba.flink.shuffle.coordinator.registration.EstablishedConnection; +import com.alibaba.flink.shuffle.coordinator.registration.RegistrationConnectionListener; +import com.alibaba.flink.shuffle.coordinator.worker.metastore.Metastore; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.storage.PartitionedDataStore; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcEndpoint; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcService; +import com.alibaba.flink.shuffle.rpc.RpcTargetAddress; +import com.alibaba.flink.shuffle.rpc.message.Acknowledge; +import com.alibaba.flink.shuffle.rpc.utils.AkkaRpcServiceUtils; +import com.alibaba.flink.shuffle.transfer.NettyServer; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import java.util.ArrayList; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; +import static com.alibaba.flink.shuffle.common.utils.ProcessUtils.getProcessID; + +/** The worker that actually manages the data partitions. */ +public class ShuffleWorker extends RemoteShuffleRpcEndpoint implements ShuffleWorkerGateway { + + private static final Logger LOG = LoggerFactory.getLogger(ShuffleWorker.class); + + private static final String SHUFFLE_WORKER_NAME = "shuffleworker"; + + private final ShuffleWorkerConfiguration shuffleWorkerConfiguration; + + private final ShuffleWorkerLocation shuffleWorkerLocation; + + /** The fatal error handler to use in case of a fatal error. 
*/ + private final FatalErrorHandler fatalErrorHandler; + + private final Metastore metaStore; + + private final PartitionedDataStore dataStore; + + private final NettyServer nettyServer; + + // ------------------------------------------------------------------------ + + private final HeartbeatManager heartbeatManager; + + private final LeaderRetrievalService leaderRetrieveService; + + @Nullable private RpcTargetAddress shuffleManagerAddress; + + @Nullable + private ConnectingConnection + connectingConnection; + + @Nullable + private EstablishedConnection + establishedConnection; + + @Nullable private UUID currentRegistrationTimeoutId; + + protected ShuffleWorker( + RemoteShuffleRpcService rpcService, + ShuffleWorkerConfiguration shuffleWorkerConfiguration, + HaServices haServices, + HeartbeatServices heartbeatServices, + FatalErrorHandler fatalErrorHandler, + ShuffleWorkerLocation shuffleWorkerLocation, + Metastore metaStore, + PartitionedDataStore dataStore, + NettyServer nettyServer) { + + super(rpcService, AkkaRpcServiceUtils.createRandomName(SHUFFLE_WORKER_NAME)); + + this.shuffleWorkerConfiguration = checkNotNull(shuffleWorkerConfiguration); + this.fatalErrorHandler = checkNotNull(fatalErrorHandler); + + this.shuffleWorkerLocation = checkNotNull(shuffleWorkerLocation); + + this.metaStore = checkNotNull(metaStore); + metaStore.setPartitionRemovedConsumer(this::onPartitionRemoved); + + this.dataStore = checkNotNull(dataStore); + this.nettyServer = checkNotNull(nettyServer); + + this.leaderRetrieveService = + haServices.createLeaderRetrievalService(HaServices.LeaderReceptor.SHUFFLE_WORKER); + + this.heartbeatManager = + heartbeatServices.createHeartbeatManager( + shuffleWorkerLocation.getWorkerID(), + new ManagerHeartbeatListener(), + getRpcMainThreadScheduledExecutor(), + log); + } + + // ------------------------------------------------------------------------ + // Life Cycle + // ------------------------------------------------------------------------ + + @Override + protected void onStart() throws Exception { + try { + leaderRetrieveService.start(new ShuffleManagerLeaderListener()); + } catch (Exception e) { + handleStartShuffleWorkerServicesException(e); + } + } + + private void handleStartShuffleWorkerServicesException(Exception e) throws Exception { + try { + stopShuffleWorkerServices(); + } catch (Exception inner) { + e.addSuppressed(inner); + } + + throw e; + } + + private void stopShuffleWorkerServices() throws Exception { + Exception exception = null; + + try { + dataStore.shutDown(false); + } catch (Exception e) { + exception = e; + } + + try { + metaStore.close(); + } catch (Exception e) { + exception = exception == null ? e : exception; + } + + try { + nettyServer.shutdown(); + } catch (Exception e) { + exception = exception == null ? e : exception; + } + + try { + heartbeatManager.stop(); + } catch (Exception e) { + exception = exception == null ? e : exception; + } + + try { + leaderRetrieveService.stop(); + } catch (Exception e) { + exception = exception == null ? e : exception; + } + + if (exception != null) { + ExceptionUtils.rethrowException(exception); + } + } + + /** Called to shut down the ShuffleWorker. The method closes all ShuffleWorker services. 
*/ + @Override + public CompletableFuture onStop() { + LOG.info("Stopping ShuffleWorker {}.", getAddress()); + + ShuffleException cause = new ShuffleException("The ShuffleWorker is shutting down."); + + closeShuffleManagerConnection(cause); + + LOG.info("Stopped ShuffleWorker {}.", getAddress()); + + try { + stopShuffleWorkerServices(); + return CompletableFuture.completedFuture(null); + } catch (Exception e) { + LOG.warn("Failed to stop shuffle worker services", e); + return FutureUtils.completedExceptionally(e); + } + } + + @Override + public void heartbeatFromManager(InstanceID managerID) { + heartbeatManager.requestHeartbeat(managerID, null); + } + + // ------------------------------------------------------------------------ + // Internal shuffle manager connection methods + // ------------------------------------------------------------------------ + + private void notifyOfNewShuffleManagerLeader(LeaderInformation leaderInfo) { + shuffleManagerAddress = createShuffleManagerAddress(leaderInfo); + reconnectToShuffleManager( + new ShuffleException( + String.format( + "ShuffleManager leader changed to new address %s", + shuffleManagerAddress))); + } + + @Nullable + private RpcTargetAddress createShuffleManagerAddress(LeaderInformation leaderInfo) { + if (leaderInfo == LeaderInformation.empty()) { + return null; + } + return new RpcTargetAddress(leaderInfo.getLeaderAddress(), leaderInfo.getLeaderSessionID()); + } + + private void reconnectToShuffleManager(Exception cause) { + closeShuffleManagerConnection(cause); + startRegistrationTimeout(); + connectToShuffleManager(); + } + + private void connectToShuffleManager() { + if (shuffleManagerAddress == null) { + return; + } + + checkState(establishedConnection == null, "Must be null."); + checkState(connectingConnection == null, "Must be null."); + + LOG.info("Connecting to ShuffleManager {}.", shuffleManagerAddress); + + connectingConnection = + new ConnectingConnection<>( + LOG, + "ShuffleManager", + ShuffleManagerGateway.class, + getRpcService(), + shuffleWorkerConfiguration.getRetryingRegistrationConfiguration(), + shuffleManagerAddress.getTargetAddress(), + shuffleManagerAddress.getLeaderUUID(), + getRpcMainThreadScheduledExecutor(), + new ShuffleManagerRegistrationListener(), + (gateway) -> + gateway.registerWorker( + new ShuffleWorkerRegistration( + getAddress(), + getRpcService().getAddress(), + shuffleWorkerLocation.getWorkerID(), + shuffleWorkerLocation.getDataPort(), + getProcessID()))); + + connectingConnection.start(); + } + + private void establishShuffleManagerConnection( + ShuffleManagerGateway shuffleManagerGateway, + ShuffleWorkerRegistrationSuccess response) { + + CompletableFuture shuffleDataStatusReportResponseFuture; + try { + shuffleDataStatusReportResponseFuture = + shuffleManagerGateway.reportDataPartitionStatus( + shuffleWorkerLocation.getWorkerID(), + response.getRegistrationID(), + metaStore.listDataPartitions()); + + shuffleDataStatusReportResponseFuture.whenCompleteAsync( + (acknowledge, throwable) -> { + if (throwable != null) { + reconnectToShuffleManager( + new Exception( + "Failed to send initial shuffle data status report to shuffle manager.", + throwable)); + } + }, + getRpcMainThreadScheduledExecutor()); + } catch (Exception e) { + LOG.warn("Initial shuffle data partition status report failed", e); + } + + // monitor the shuffle manager as heartbeat target + heartbeatManager.monitorTarget( + response.getInstanceID(), + new HeartbeatTarget() { + @Override + public void receiveHeartbeat( + InstanceID 
instanceID, + WorkerToManagerHeartbeatPayload heartbeatPayload) { + shuffleManagerGateway.heartbeatFromWorker(instanceID, heartbeatPayload); + } + + @Override + public void requestHeartbeat( + InstanceID instanceID, + WorkerToManagerHeartbeatPayload heartbeatPayload) { + // the ShuffleWorker won't send heartbeat requests to the ShuffleManager + } + }); + + establishedConnection = new EstablishedConnection<>(shuffleManagerGateway, response); + + stopRegistrationTimeout(); + } + + private void closeShuffleManagerConnection(Exception cause) { + if (establishedConnection != null) { + final InstanceID shuffleManagerInstanceID = + establishedConnection.getResponse().getInstanceID(); + + if (LOG.isDebugEnabled()) { + LOG.debug("Close ShuffleManager connection {}.", shuffleManagerInstanceID, cause); + } else { + LOG.error("Close ShuffleManager connection " + shuffleManagerInstanceID, cause); + } + heartbeatManager.unmonitorTarget(shuffleManagerInstanceID); + + ShuffleManagerGateway shuffleManagerGateway = establishedConnection.getGateway(); + shuffleManagerGateway.disconnectWorker(shuffleWorkerLocation.getWorkerID(), cause); + + establishedConnection = null; + } + + if (connectingConnection != null) { + if (!connectingConnection.isConnected()) { + if (LOG.isDebugEnabled()) { + LOG.debug( + "Terminating registration attempts towards ShuffleManager {}.", + connectingConnection.getTargetAddress(), + cause); + } else { + LOG.info( + "Terminating registration attempts towards ShuffleManager {}.", + connectingConnection.getTargetAddress()); + } + } + + connectingConnection.close(); + connectingConnection = null; + } + } + + private void startRegistrationTimeout() { + final long maxRegistrationDuration = + shuffleWorkerConfiguration.getMaxRegistrationDuration(); + + final UUID newRegistrationTimeoutId = UUID.randomUUID(); + currentRegistrationTimeoutId = newRegistrationTimeoutId; + scheduleRunAsync( + () -> registrationTimeout(newRegistrationTimeoutId), + maxRegistrationDuration, + TimeUnit.MILLISECONDS); + } + + private void stopRegistrationTimeout() { + currentRegistrationTimeoutId = null; + } + + private void registrationTimeout(@Nonnull UUID registrationTimeoutId) { + if (registrationTimeoutId.equals(currentRegistrationTimeoutId)) { + onFatalError( + new Exception( + String.format( + "Could not register at the ShuffleManager within the specified maximum " + + "registration duration %s. This indicates a problem with this instance. Terminating now.", + shuffleWorkerConfiguration.getMaxRegistrationDuration()))); + } + } + + // ---------------------------------------------------------------------- + // Disconnection RPCs + // ---------------------------------------------------------------------- + + @Override + public void disconnectManager(Exception cause) { + if (isRunning()) { + reconnectToShuffleManager(cause); + } + } + + // ------------------------------------------------------------------------ + // Error Handling + // ------------------------------------------------------------------------ + + /** + * Notifies the ShuffleWorker that a fatal error has occurred and it cannot proceed. 
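+     *
+     * <p>Illustrative call (hypothetical): {@code onFatalError(new Exception("example fatal
+     * error"));} the configured {@link FatalErrorHandler} is expected to be non-blocking.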
+ * + * @param t The exception describing the fatal error + */ + void onFatalError(final Throwable t) { + try { + LOG.error("Fatal error occurred in ShuffleWorker {}.", getAddress(), t); + } catch (Throwable ignored) { + } + + // The fatal error handler implementation should make sure that this call is non-blocking + fatalErrorHandler.onFatalError(t); + } + + // ------------------------------------------------------------------------ + // RPC methods + // ------------------------------------------------------------------------ + + private void onPartitionRemoved(JobID jobID, DataPartitionCoordinate coordinate) { + runAsync( + () -> { + if (establishedConnection != null) { + establishedConnection + .getGateway() + .workerReportDataPartitionReleased( + shuffleWorkerLocation.getWorkerID(), + establishedConnection.getResponse().getRegistrationID(), + jobID, + coordinate.getDataSetId(), + coordinate.getDataPartitionId()); + } else { + LOG.warn( + "No connection to the shuffle manager and cannot notify of the partition {}-{} removed", + jobID, + coordinate); + } + }); + } + + @Override + public CompletableFuture releaseDataPartition( + JobID jobID, DataSetID dataSetID, DataPartitionID dataPartitionID) { + // DataStore will notify metastore after removing the data files + dataStore.releaseDataPartition(dataSetID, dataPartitionID, null); + return CompletableFuture.completedFuture(Acknowledge.get()); + } + + @Override + public CompletableFuture removeReleasedDataPartitionMeta( + JobID jobID, DataSetID dataSetID, DataPartitionID dataPartitionID) { + metaStore.removeReleasingDataPartition( + new DataPartitionCoordinate(dataSetID, dataPartitionID)); + return CompletableFuture.completedFuture(Acknowledge.get()); + } + + @Override + public CompletableFuture getWorkerMetrics() { + ShuffleWorkerMetrics workerMetrics = new ShuffleWorkerMetrics(); + workerMetrics.setMetric( + ShuffleWorkerMetricKeys.AVAILABLE_READING_BUFFERS_KEY, + dataStore.getReadingBufferDispatcher().numAvailableBuffers()); + workerMetrics.setMetric( + ShuffleWorkerMetricKeys.AVAILABLE_WRITING_BUFFERS_KEY, + dataStore.getWritingBufferDispatcher().numAvailableBuffers()); + workerMetrics.setMetric( + ShuffleWorkerMetricKeys.DATA_PARTITION_NUMBERS_KEY, metaStore.getSize()); + return CompletableFuture.completedFuture(workerMetrics); + } + + // ------------------------------------------------------------------------ + // Static utility classes + // ------------------------------------------------------------------------ + + private class ShuffleManagerLeaderListener implements LeaderRetrievalListener { + + @Override + public void notifyLeaderAddress(LeaderInformation leaderInfo) { + runAsync(() -> notifyOfNewShuffleManagerLeader(leaderInfo)); + } + + @Override + public void handleError(Exception exception) { + fatalErrorHandler.onFatalError( + new Exception("Failed to retrieve shuffle manager address", exception)); + } + } + + private final class ShuffleManagerRegistrationListener + implements RegistrationConnectionListener< + ConnectingConnection, + ShuffleWorkerRegistrationSuccess> { + + @Override + public void onRegistrationSuccess( + ConnectingConnection + connection, + ShuffleWorkerRegistrationSuccess success) { + final ShuffleManagerGateway shuffleManagerGateway = connection.getTargetGateway(); + + runAsync( + () -> { + // filter out outdated connections + //noinspection ObjectEquality + if (connectingConnection == connection) { + try { + establishShuffleManagerConnection(shuffleManagerGateway, success); + } catch (Throwable t) { + 
LOG.error( + "Establishing ShuffleManager connection in ShuffleWorker failed", + t); + } + } + }); + } + + @Override + public void onRegistrationFailure(Throwable failure) { + onFatalError(failure); + } + } + + private class ManagerHeartbeatListener + implements HeartbeatListener { + + @Override + public void notifyHeartbeatTimeout(InstanceID instanceID) { + validateRunsInMainThread(); + + // first check whether the timeout is still valid + if (establishedConnection != null + && establishedConnection.getResponse().getInstanceID().equals(instanceID)) { + LOG.info("The heartbeat of ShuffleManager with id {} timed out.", instanceID); + + reconnectToShuffleManager( + new Exception( + String.format( + "The heartbeat of ShuffleManager with id %s timed out.", + instanceID))); + } else { + LOG.debug( + "Received heartbeat timeout for outdated ShuffleManager id {}. Ignoring the timeout.", + instanceID); + } + } + + @Override + public void reportPayload(InstanceID instanceID, Void payload) { + // nothing to do since the payload is of type Void + } + + @Override + public WorkerToManagerHeartbeatPayload retrievePayload(InstanceID instanceID) { + validateRunsInMainThread(); + try { + return new WorkerToManagerHeartbeatPayload(metaStore.listDataPartitions()); + } catch (Exception e) { + return new WorkerToManagerHeartbeatPayload(new ArrayList<>()); + } + } + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorkerConfiguration.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorkerConfiguration.java new file mode 100644 index 00000000..912ee7a6 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorkerConfiguration.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.worker; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.coordinator.registration.RetryingRegistrationConfiguration; +import com.alibaba.flink.shuffle.core.config.ClusterOptions; + +/** The parsed configuration values for {@link ShuffleWorker}. 
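+ *
+ * <p>Typically created through the factory method, sketched with a hypothetical {@code conf}:
+ * {@code ShuffleWorkerConfiguration.fromConfiguration(conf)}.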
*/ +public class ShuffleWorkerConfiguration { + + private final Configuration configuration; + + private final long maxRegistrationDuration; + + private final RetryingRegistrationConfiguration retryingRegistrationConfiguration; + + public ShuffleWorkerConfiguration( + Configuration configuration, + long maxRegistrationDuration, + RetryingRegistrationConfiguration retryingRegistrationConfiguration) { + CommonUtils.checkArgument(configuration != null, "Must be not null."); + CommonUtils.checkArgument(retryingRegistrationConfiguration != null, "Must be not null."); + + this.configuration = configuration; + this.maxRegistrationDuration = maxRegistrationDuration; + this.retryingRegistrationConfiguration = retryingRegistrationConfiguration; + } + + public static ShuffleWorkerConfiguration fromConfiguration(Configuration configuration) { + long maxRegistrationDuration; + try { + maxRegistrationDuration = + configuration.getDuration(ClusterOptions.REGISTRATION_TIMEOUT).toMillis(); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException( + String.format( + "Invalid format for parameter %s. Set the timeout to be infinite.", + ClusterOptions.REGISTRATION_TIMEOUT.key()), + e); + } + + RetryingRegistrationConfiguration retryingRegistrationConfiguration = + RetryingRegistrationConfiguration.fromConfiguration(configuration); + + return new ShuffleWorkerConfiguration( + configuration, maxRegistrationDuration, retryingRegistrationConfiguration); + } + + public Configuration getConfiguration() { + return configuration; + } + + public long getMaxRegistrationDuration() { + return maxRegistrationDuration; + } + + // -------------------------------------------------------------------------------------------- + // Static factory methods + // -------------------------------------------------------------------------------------------- + + public RetryingRegistrationConfiguration getRetryingRegistrationConfiguration() { + return retryingRegistrationConfiguration; + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorkerGateway.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorkerGateway.java new file mode 100644 index 00000000..a2277ee4 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorkerGateway.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.coordinator.worker; + +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcGateway; +import com.alibaba.flink.shuffle.rpc.message.Acknowledge; + +import java.util.concurrent.CompletableFuture; + +/** The rpc gateway for shuffle workers. */ +public interface ShuffleWorkerGateway extends RemoteShuffleRpcGateway { + + /** + * Receives heartbeat from the shuffle manager. + * + * @param managerID The InstanceID of the shuffle manager. + */ + void heartbeatFromManager(InstanceID managerID); + + /** + * Disconnects the worker from the shuffle manager. + * + * @param cause The reason for disconnecting. + */ + void disconnectManager(Exception cause); + + /** + * Releases the shuffle resource for one data partition. + * + * @param jobID The id of the job produces the data partition. + * @param dataSetID The id of the dataset that contains the data partition.. + * @param dataPartitionID The id of the data partition. + * @return The future indicating whether the request is submitted. + */ + CompletableFuture releaseDataPartition( + JobID jobID, DataSetID dataSetID, DataPartitionID dataPartitionID); + + /** + * Removes the meta for the already released data partition after the manager has also marked it + * as releasing. + * + * @param jobID The id of the job produces the data partition. + * @param dataSetID The id of the dataset that contains the data partition.. + * @param dataPartitionID The id of the data partition. + * @return The future indicating whether the request is submitted. + */ + CompletableFuture removeReleasedDataPartitionMeta( + JobID jobID, DataSetID dataSetID, DataPartitionID dataPartitionID); + + /** + * Get shuffle worker metrics. + * + * @return the shuffle worker metrics currently. + */ + CompletableFuture getWorkerMetrics(); +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorkerLocation.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorkerLocation.java new file mode 100644 index 00000000..fb60fc66 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorkerLocation.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.worker; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.ids.InstanceID; + +import java.io.Serializable; + +/** Location information of shuffle workers. 
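+ *
+ * <p>Illustrative construction (all values are examples only):
+ * {@code new ShuffleWorkerLocation("10.0.0.3", 10086, workerID)}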
*/ +public class ShuffleWorkerLocation implements Serializable { + + private static final long serialVersionUID = 5625612426180704010L; + + private final String externalAddress; + + private final int dataPort; + + private final InstanceID workerID; + + public ShuffleWorkerLocation(String externalAddress, int dataPort, InstanceID workerID) { + CommonUtils.checkArgument(externalAddress != null, "Must be not null."); + CommonUtils.checkArgument(CommonUtils.isValidHostPort(dataPort), "Illegal data port."); + CommonUtils.checkArgument(workerID != null, "Must be not null."); + + this.externalAddress = externalAddress; + this.dataPort = dataPort; + this.workerID = workerID; + } + + public InstanceID getWorkerID() { + return workerID; + } + + public int getDataPort() { + return dataPort; + } + + public String getExternalAddress() { + return externalAddress; + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorkerMetricKeys.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorkerMetricKeys.java new file mode 100644 index 00000000..7f1719fe --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorkerMetricKeys.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.worker; + +/** Keys for shuffle worker metrics. */ +public class ShuffleWorkerMetricKeys { + + public static final String AVAILABLE_READING_BUFFERS_KEY = "available_reading_buffers"; + + public static final String AVAILABLE_WRITING_BUFFERS_KEY = "available_writing_buffers"; + + public static final String DATA_PARTITION_NUMBERS_KEY = "data_partition_numbers"; +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorkerMetrics.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorkerMetrics.java new file mode 100644 index 00000000..10a37715 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorkerMetrics.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.worker; + +import java.io.Serializable; +import java.util.HashMap; + +/** Shuffle worker metrics. */ +public class ShuffleWorkerMetrics implements Serializable { + + private static final long serialVersionUID = -1460652050573532972L; + + private final HashMap metrics = new HashMap<>(); + + public void setMetric(String key, Serializable metric) { + metrics.put(key, metric); + } + + public Serializable getMetric(String key) { + return metrics.get(key); + } + + public Integer getIntegerMetric(String key) { + return (Integer) getMetric(key); + } + + public Long getLongMetric(String key) { + return (Long) getMetric(key); + } + + public Double getDoubleMetric(String key) { + return (Double) getMetric(key); + } + + public Float getFloatMetric(String key) { + return (Float) getMetric(key); + } + + public String getStringMetric(String key) { + return (String) getMetric(key); + } + + @Override + public String toString() { + return metrics.toString(); + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorkerRunner.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorkerRunner.java new file mode 100644 index 00000000..5ab8b3a7 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorkerRunner.java @@ -0,0 +1,390 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.coordinator.worker; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.exception.ConfigurationException; +import com.alibaba.flink.shuffle.common.handler.FatalErrorHandler; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; +import com.alibaba.flink.shuffle.common.utils.FatalErrorExitUtils; +import com.alibaba.flink.shuffle.common.utils.FutureUtils; +import com.alibaba.flink.shuffle.common.utils.JvmShutdownSafeguard; +import com.alibaba.flink.shuffle.common.utils.SignalHandler; +import com.alibaba.flink.shuffle.common.utils.StringUtils; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatServices; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatServicesUtils; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServiceUtils; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices; +import com.alibaba.flink.shuffle.coordinator.utils.ClusterEntrypointUtils; +import com.alibaba.flink.shuffle.coordinator.utils.EnvironmentInformation; +import com.alibaba.flink.shuffle.coordinator.utils.LeaderRetrievalUtils; +import com.alibaba.flink.shuffle.coordinator.worker.metastore.LocalShuffleMetaStore; +import com.alibaba.flink.shuffle.core.config.StorageOptions; +import com.alibaba.flink.shuffle.core.config.WorkerOptions; +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.core.storage.DataPartitionMeta; +import com.alibaba.flink.shuffle.metrics.entry.MetricUtils; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcService; +import com.alibaba.flink.shuffle.rpc.utils.AkkaRpcServiceUtils; +import com.alibaba.flink.shuffle.storage.datastore.PartitionedDataStoreImpl; +import com.alibaba.flink.shuffle.storage.utils.StorageConfigParseUtils; +import com.alibaba.flink.shuffle.transfer.NettyConfig; +import com.alibaba.flink.shuffle.transfer.NettyServer; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.InetAddress; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument; +import static com.alibaba.flink.shuffle.coordinator.utils.ClusterEntrypointUtils.STARTUP_FAILURE_RETURN_CODE; + +/** The entrypoint of the ShuffleWorker. 
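+ *
+ * <p>Normally launched via {@code main(String[])}; embedded callers (tests, for example) can
+ * invoke {@code ShuffleWorkerRunner.runShuffleWorker(configuration)} directly.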
*/ +public class ShuffleWorkerRunner implements FatalErrorHandler { + + private static final Logger LOG = LoggerFactory.getLogger(ShuffleWorkerRunner.class); + + private static final Duration LOOKUP_TIMEOUT_DURATION = Duration.ofSeconds(30); + + private static final int SUCCESS_EXIT_CODE = 0; + + private static final int FAILURE_EXIT_CODE = 1; + + private final Object lock = new Object(); + + private final RemoteShuffleRpcService rpcService; + + private final HaServices haServices; + + private final ShuffleWorker shuffleWorker; + + private final CompletableFuture terminationFuture = new CompletableFuture<>(); + + private boolean shutdown; + + public ShuffleWorkerRunner(Configuration configuration) throws Exception { + checkArgument(configuration != null, "Must be not null."); + + haServices = HaServiceUtils.createHAServices(configuration); + + HeartbeatServices heartbeatServices = + HeartbeatServicesUtils.createManagerWorkerHeartbeatServices(configuration); + + MetricUtils.startWorkerMetricSystem(configuration); + + AkkaRpcServiceUtils.loadRpcSystem(configuration); + this.rpcService = + AkkaRpcServiceUtils.createRemoteRpcService( + configuration, + determineShuffleWorkerBindAddress(configuration, haServices), + configuration.getString(WorkerOptions.RPC_PORT), + configuration.getString(WorkerOptions.BIND_HOST), + Optional.ofNullable(configuration.getInteger(WorkerOptions.RPC_BIND_PORT))); + + this.shuffleWorker = + createShuffleWorker(configuration, rpcService, haServices, heartbeatServices, this); + shuffleWorker + .getTerminationFuture() + .whenComplete( + ((unused, throwable) -> { + synchronized (lock) { + if (!shutdown) { + onFatalError( + new Exception( + "Unexpected termination of the ShuffleWorker.", + throwable)); + } + } + })); + } + + private static String determineShuffleWorkerBindAddress( + Configuration configuration, HaServices haServices) throws Exception { + + final String configuredShuffleWorkerHostname = configuration.getString(WorkerOptions.HOST); + + if (configuredShuffleWorkerHostname != null) { + LOG.info( + "Using configured hostname/address for ShuffleWorker: {}.", + configuredShuffleWorkerHostname); + return configuredShuffleWorkerHostname; + } else { + return determineShuffleWorkerBindAddressByConnectingToShuffleManager( + configuration, haServices); + } + } + + private static String determineShuffleWorkerBindAddressByConnectingToShuffleManager( + Configuration configuration, HaServices haServices) throws Exception { + final InetAddress shuffleWorkerAddress = + LeaderRetrievalUtils.findConnectingAddress( + haServices.createLeaderRetrievalService( + HaServices.LeaderReceptor.SHUFFLE_WORKER), + LOOKUP_TIMEOUT_DURATION); + + LOG.info( + "ShuffleWorker will use hostname/address '{}' ({}) for communication.", + shuffleWorkerAddress.getHostName(), + shuffleWorkerAddress.getHostAddress()); + + HostBindPolicy bindPolicy = + HostBindPolicy.fromString(configuration.getString(WorkerOptions.HOST_BIND_POLICY)); + return bindPolicy == HostBindPolicy.IP + ? shuffleWorkerAddress.getHostAddress() + : shuffleWorkerAddress.getHostName(); + } + + static InstanceID getShuffleWorkerID(String rpcAddress, int rpcPort) throws Exception { + String randomString = CommonUtils.randomHexString(16); + return new InstanceID( + StringUtils.isNullOrWhitespaceOnly(rpcAddress) + ? 
InetAddress.getLocalHost().getHostName() + "-" + randomString + : rpcAddress + ":" + rpcPort + "-" + randomString); + } + + // export the termination future for caller to know it is terminated + public CompletableFuture getTerminationFuture() { + return terminationFuture; + } + + // -------------------------------------------------------------------------------------------- + // Lifecycle management + // -------------------------------------------------------------------------------------------- + + public static void main(String[] args) throws Exception { + // startup checks and logging + EnvironmentInformation.logEnvironmentInfo(LOG, "Shuffle Worker", args); + SignalHandler.register(LOG); + JvmShutdownSafeguard.installAsShutdownHook(LOG); + + long maxOpenFileHandles = EnvironmentInformation.getOpenFileHandlesLimit(); + + if (maxOpenFileHandles != -1L) { + LOG.info("Maximum number of open file descriptors is {}.", maxOpenFileHandles); + } else { + LOG.info("Cannot determine the maximum number of open file descriptors"); + } + + try { + Configuration configuration = ClusterEntrypointUtils.parseParametersOrExit(args); + runShuffleWorker(configuration); + } catch (Throwable t) { + LOG.error("ShuffleWorker initialization failed.", t); + FatalErrorExitUtils.exitProcessIfNeeded(STARTUP_FAILURE_RETURN_CODE); + } + } + + public static ShuffleWorkerRunner runShuffleWorker(Configuration configuration) + throws Exception { + ShuffleWorkerRunner shuffleWorkerRunner = new ShuffleWorkerRunner(configuration); + shuffleWorkerRunner.start(); + LOG.info("Shuffle worker runner is started"); + return shuffleWorkerRunner; + } + + public static ShuffleWorker createShuffleWorker( + Configuration configuration, + RemoteShuffleRpcService rpcService, + HaServices haServices, + HeartbeatServices heartbeatServices, + FatalErrorHandler fatalErrorHandler) + throws Exception { + + InstanceID workerID = getShuffleWorkerID(rpcService.getAddress(), rpcService.getPort()); + + ShuffleWorkerConfiguration shuffleWorkerConfiguration = + ShuffleWorkerConfiguration.fromConfiguration(configuration); + + String directories = configuration.getString(StorageOptions.STORAGE_LOCAL_DATA_DIRS); + if (directories == null || directories.trim().isEmpty()) { + throw new ConfigurationException( + String.format( + "The data dir '%s' configured by '%s' is not valid.", + directories, StorageOptions.STORAGE_LOCAL_DATA_DIRS.key())); + } + + List allPaths = + StorageConfigParseUtils.parseStoragePaths(directories).getAllPaths(); + LocalShuffleMetaStore fileSystemShuffleMetaStore = + new LocalShuffleMetaStore(new HashSet<>(allPaths)); + + PartitionedDataStoreImpl dataStore = + new PartitionedDataStoreImpl(configuration, fileSystemShuffleMetaStore); + + // Initialize the data partitions on startup + int recoveredCount = 0; + int failedCount = 0; + for (DataPartitionMeta dataPartitionMeta : + fileSystemShuffleMetaStore.getAllDataPartitionMetas()) { + try { + dataStore.addDataPartition(dataPartitionMeta); + recoveredCount++; + } catch (Throwable t) { + LOG.warn( + "Failed to initialize the data partition {}-{}-{}", + dataPartitionMeta.getJobID(), + dataPartitionMeta.getDataSetID(), + dataPartitionMeta.getDataPartitionID()); + failedCount++; + } + } + LOG.info( + "Recovered {} partitions successfully and {} partitions with failure", + recoveredCount, + failedCount); + + NettyConfig nettyConfig = new NettyConfig(configuration); + NettyServer nettyServer = new NettyServer(dataStore, nettyConfig); + nettyServer.start(); + + ShuffleWorkerLocation 
+    public static ShuffleWorker createShuffleWorker(
+            Configuration configuration,
+            RemoteShuffleRpcService rpcService,
+            HaServices haServices,
+            HeartbeatServices heartbeatServices,
+            FatalErrorHandler fatalErrorHandler)
+            throws Exception {
+
+        InstanceID workerID = getShuffleWorkerID(rpcService.getAddress(), rpcService.getPort());
+
+        ShuffleWorkerConfiguration shuffleWorkerConfiguration =
+                ShuffleWorkerConfiguration.fromConfiguration(configuration);
+
+        String directories = configuration.getString(StorageOptions.STORAGE_LOCAL_DATA_DIRS);
+        if (directories == null || directories.trim().isEmpty()) {
+            throw new ConfigurationException(
+                    String.format(
+                            "The data dir '%s' configured by '%s' is not valid.",
+                            directories, StorageOptions.STORAGE_LOCAL_DATA_DIRS.key()));
+        }
+
+        List<String> allPaths =
+                StorageConfigParseUtils.parseStoragePaths(directories).getAllPaths();
+        LocalShuffleMetaStore fileSystemShuffleMetaStore =
+                new LocalShuffleMetaStore(new HashSet<>(allPaths));
+
+        PartitionedDataStoreImpl dataStore =
+                new PartitionedDataStoreImpl(configuration, fileSystemShuffleMetaStore);
+
+        // Initialize the data partitions on startup
+        int recoveredCount = 0;
+        int failedCount = 0;
+        for (DataPartitionMeta dataPartitionMeta :
+                fileSystemShuffleMetaStore.getAllDataPartitionMetas()) {
+            try {
+                dataStore.addDataPartition(dataPartitionMeta);
+                recoveredCount++;
+            } catch (Throwable t) {
+                LOG.warn(
+                        "Failed to initialize the data partition {}-{}-{}",
+                        dataPartitionMeta.getJobID(),
+                        dataPartitionMeta.getDataSetID(),
+                        dataPartitionMeta.getDataPartitionID(),
+                        t);
+                failedCount++;
+            }
+        }
+        LOG.info(
+                "Recovered {} partitions successfully and {} partitions with failure.",
+                recoveredCount,
+                failedCount);
+
+        NettyConfig nettyConfig = new NettyConfig(configuration);
+        NettyServer nettyServer = new NettyServer(dataStore, nettyConfig);
+        nettyServer.start();
+
+        ShuffleWorkerLocation shuffleWorkerLocation =
+                new ShuffleWorkerLocation(
+                        rpcService.getAddress(), nettyConfig.getServerPort(), workerID);
+
+        return new ShuffleWorker(
+                rpcService,
+                shuffleWorkerConfiguration,
+                haServices,
+                heartbeatServices,
+                fatalErrorHandler,
+                shuffleWorkerLocation,
+                fileSystemShuffleMetaStore,
+                dataStore,
+                nettyServer);
+    }
+
+    public void start() {
+        shuffleWorker.start();
+    }
+
+    // --------------------------------------------------------------------------------------------
+    // Fatal error handling
+    // --------------------------------------------------------------------------------------------
+
+    @Override
+    public void onFatalError(Throwable exception) {
+        LOG.error(
+                "Fatal error occurred while executing the ShuffleWorker. Shutting it down...",
+                exception);
+
+        // In case of a Metaspace OutOfMemoryError, a graceful shutdown is expected to be
+        // possible, as it does not usually require more class loading to fail again with the
+        // Metaspace OutOfMemoryError.
+        if (ExceptionUtils.isJvmFatalOrOutOfMemoryError(exception)
+                && !ExceptionUtils.isMetaspaceOutOfMemoryError(exception)) {
+            terminateJVM();
+        } else {
+            closeAsync(Result.FAILURE);
+        }
+    }
+
+    private void terminateJVM() {
+        FatalErrorExitUtils.exitProcessIfNeeded(FAILURE_EXIT_CODE);
+    }
+
+    public void close() throws Exception {
+        try {
+            closeAsync(Result.SUCCESS).get();
+        } catch (ExecutionException e) {
+            ExceptionUtils.rethrowException(
+                    ExceptionUtils.stripException(e, ExecutionException.class));
+        }
+    }
+
+    private CompletableFuture<Result> closeAsync(Result terminationResult) {
+        synchronized (lock) {
+            if (!shutdown) {
+                shutdown = true;
+
+                final CompletableFuture<Void> shuffleWorkerTerminationFuture =
+                        shuffleWorker.closeAsync();
+
+                final CompletableFuture<Void> serviceTerminationFuture =
+                        FutureUtils.composeAfterwards(
+                                shuffleWorkerTerminationFuture, this::shutDownServices);
+
+                serviceTerminationFuture.whenComplete(
+                        (Void ignored, Throwable throwable) -> {
+                            if (throwable != null) {
+                                terminationFuture.completeExceptionally(throwable);
+                            } else {
+                                terminationFuture.complete(terminationResult);
+                            }
+                        });
+            }
+        }
+
+        return terminationFuture;
+    }
+
+    // --------------------------------------------------------------------------------------------
+    // Internal service shutdown
+    // --------------------------------------------------------------------------------------------
+
+    private CompletableFuture<Void> shutDownServices() {
+        synchronized (lock) {
+            Collection<CompletableFuture<?>> terminationFutures = new ArrayList<>(3);
+            Throwable exception = null;
+
+            try {
+                haServices.close();
+            } catch (Throwable throwable) {
+                exception = throwable;
+                LOG.error("Failed to close HA service.", throwable);
+            }
+
+            terminationFutures.add(rpcService.stopService());
+
+            if (exception != null) {
+                terminationFutures.add(FutureUtils.completedExceptionally(exception));
+            }
+
+            MetricUtils.stopMetricSystem();
+
+            return FutureUtils.completeAll(terminationFutures);
+        }
+    }
+
+    /** The result of running the Shuffle Worker.
*/ + public enum Result { + SUCCESS(SUCCESS_EXIT_CODE), + + FAILURE(FAILURE_EXIT_CODE); + + private final int exitCode; + + Result(int exitCode) { + this.exitCode = exitCode; + } + + public int getExitCode() { + return exitCode; + } + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/metastore/LocalShuffleMetaStore.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/metastore/LocalShuffleMetaStore.java new file mode 100644 index 00000000..d58dafc2 --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/metastore/LocalShuffleMetaStore.java @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.worker.metastore; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionCoordinate; +import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionStatus; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.storage.DataPartitionMeta; +import com.alibaba.flink.shuffle.storage.utils.DataPartitionUtils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.BiConsumer; +import java.util.stream.Collectors; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; + +/** The meta store based on storing the meta files on the disk containing the data partitions. 
+ */
+public class LocalShuffleMetaStore implements Metastore {
+
+    private static final Logger LOG = LoggerFactory.getLogger(LocalShuffleMetaStore.class);
+
+    public static final String META_DIR_NAME = "_meta";
+
+    private final Set<String> storagePaths;
+
+    private final Map<DataPartitionCoordinate, DataPartitionMetaStatus> dataPartitions =
+            new HashMap<>();
+
+    @Nullable private volatile BiConsumer<JobID, DataPartitionCoordinate> partitionRemovedConsumer;
+
+    public LocalShuffleMetaStore(Set<String> storagePaths) throws Exception {
+        this.storagePaths = checkNotNull(storagePaths);
+        initialize();
+    }
+
+    private void initialize() {
+        for (String base : storagePaths) {
+            File metaBaseDir = new File(base, META_DIR_NAME);
+
+            if (!metaBaseDir.exists()) {
+                continue;
+            }
+
+            List<File> spoiledMetaFiles = new ArrayList<>();
+            for (File subFile :
+                    Optional.ofNullable(metaBaseDir.listFiles()).orElseGet(() -> new File[0])) {
+                try (FileInputStream fileInputStream = new FileInputStream(subFile);
+                        DataInputStream dataInput = new DataInputStream(fileInputStream)) {
+                    DataPartitionMeta dataPartitionMeta =
+                            DataPartitionUtils.deserializePartitionMeta(dataInput);
+                    dataPartitions.put(
+                            new DataPartitionCoordinate(
+                                    dataPartitionMeta.getDataSetID(),
+                                    dataPartitionMeta.getDataPartitionID()),
+                            new DataPartitionMetaStatus(dataPartitionMeta, false));
+                } catch (Exception e) {
+                    LOG.warn("Failed to parse " + subFile.getAbsolutePath(), e);
+                    spoiledMetaFiles.add(subFile);
+                }
+            }
+
+            for (File spoiledMetaFile : spoiledMetaFiles) {
+                try {
+                    boolean deleted = spoiledMetaFile.delete();
+                    if (!deleted) {
+                        LOG.warn(
+                                "Failed to remove the spoiled meta file "
+                                        + spoiledMetaFile.getPath());
+                    }
+                } catch (Exception e) {
+                    LOG.warn(
+                            "Failed to remove the spoiled meta file " + spoiledMetaFile.getPath(),
+                            e);
+                }
+            }
+        }
+    }
+
+    @Override
+    public void setPartitionRemovedConsumer(
+            BiConsumer<JobID, DataPartitionCoordinate> partitionRemovedConsumer) {
+        this.partitionRemovedConsumer = checkNotNull(partitionRemovedConsumer);
+    }
+
+    @Override
+    public List<DataPartitionStatus> listDataPartitions() {
+        List<DataPartitionStatus> dataPartitionStatuses = new ArrayList<>();
+
+        synchronized (dataPartitions) {
+            dataPartitions.forEach(
+                    ((coordinate, dataPartitionWorkerStatus) -> {
+                        dataPartitionStatuses.add(
+                                new DataPartitionStatus(
+                                        dataPartitionWorkerStatus.getMeta().getJobID(),
+                                        new DataPartitionCoordinate(
+                                                dataPartitionWorkerStatus.getMeta().getDataSetID(),
+                                                dataPartitionWorkerStatus
+                                                        .getMeta()
+                                                        .getDataPartitionID()),
+                                        dataPartitionWorkerStatus.isReleasing()));
+                    }));
+        }
+
+        return dataPartitionStatuses;
+    }
+
+    public List<DataPartitionMeta> getAllDataPartitionMetas() {
+        return dataPartitions.values().stream()
+                .map(DataPartitionMetaStatus::getMeta)
+                .collect(Collectors.toList());
+    }
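+    // On-disk layout in brief (an illustrative note; the hex names below are made up): every
+    // configured storage path gets a "_meta" subdirectory, and each partition's meta is one
+    // file named "<dataSetId-hex>-<dataPartitionId-hex>", e.g.
+    //
+    //   /data/shuffle1/_meta/1f2e3d4c-9a8b7c6d
+    //
+    // onPartitionCreated() below writes such a file with
+    // DataPartitionUtils.serializePartitionMeta, and initialize() above reads it back with
+    // DataPartitionUtils.deserializePartitionMeta.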
+    @Override
+    public void onPartitionCreated(DataPartitionMeta partitionMeta) throws Exception {
+        checkState(
+                storagePaths.contains(partitionMeta.getStorageMeta().getStoragePath()),
+                String.format(
+                        "The base path %s is not configured",
+                        partitionMeta.getStorageMeta().getStoragePath()));
+
+        DataPartitionCoordinate coordinate =
+                new DataPartitionCoordinate(
+                        partitionMeta.getDataSetID(), partitionMeta.getDataPartitionID());
+        checkState(!dataPartitions.containsKey(coordinate), "The data partition already exists.");
+
+        synchronized (dataPartitions) {
+            dataPartitions.put(coordinate, new DataPartitionMetaStatus(partitionMeta, false));
+        }
+
+        File storageFile = new File(partitionMeta.getStorageMeta().getStoragePath());
+        File metaFile = getDataPartitionPath(storageFile, partitionMeta);
+
+        if (!metaFile.getParentFile().exists()) {
+            metaFile.getParentFile().mkdir();
+        }
+
+        try (FileOutputStream outputStream = new FileOutputStream(metaFile);
+                DataOutputStream dataOutput = new DataOutputStream(outputStream)) {
+            DataPartitionUtils.serializePartitionMeta(partitionMeta, dataOutput);
+        }
+    }
+
+    @Override
+    public void onPartitionRemoved(DataPartitionMeta partitionMeta) {
+        checkState(
+                storagePaths.contains(partitionMeta.getStorageMeta().getStoragePath()),
+                String.format(
+                        "The base path %s is not configured",
+                        partitionMeta.getStorageMeta().getStoragePath()));
+
+        File storageFile = new File(partitionMeta.getStorageMeta().getStoragePath());
+        DataPartitionCoordinate coordinate =
+                new DataPartitionCoordinate(
+                        partitionMeta.getDataSetID(), partitionMeta.getDataPartitionID());
+        synchronized (dataPartitions) {
+            DataPartitionMetaStatus status = dataPartitions.get(coordinate);
+
+            if (status == null) {
+                LOG.warn("Data partition {} not found", coordinate);
+                return;
+            }
+
+            // Mark the data partition as releasing; the entry itself is removed only after the
+            // master has also marked it as released, via removeReleasingDataPartition().
+            dataPartitions.get(coordinate).setReleasing(true);
+        }
+
+        File metaFile = getDataPartitionPath(storageFile, partitionMeta);
+        try {
+            boolean deleted = metaFile.delete();
+            if (!deleted) {
+                LOG.warn("Unable to remove the meta file " + metaFile.getAbsolutePath());
+            }
+        } catch (Exception e) {
+            LOG.warn("Unable to remove the meta file " + metaFile.getAbsolutePath(), e);
+        }
+
+        BiConsumer<JobID, DataPartitionCoordinate> currentRemoveConsumer = partitionRemovedConsumer;
+        if (currentRemoveConsumer != null) {
+            currentRemoveConsumer.accept(partitionMeta.getJobID(), coordinate);
+        }
+    }
+
+    @Override
+    public void removeReleasingDataPartition(DataPartitionCoordinate coordinate) {
+        synchronized (dataPartitions) {
+            dataPartitions.remove(coordinate);
+        }
+    }
+
+    @Override
+    public int getSize() {
+        synchronized (dataPartitions) {
+            return dataPartitions.size();
+        }
+    }
+
+    private File getDataPartitionPath(File baseDir, DataPartitionMeta partitionMeta) {
+        String name =
+                CommonUtils.bytesToHexString(partitionMeta.getDataSetID().getId())
+                        + "-"
+                        + CommonUtils.bytesToHexString(partitionMeta.getDataPartitionID().getId());
+
+        return new File(new File(baseDir, META_DIR_NAME), name);
+    }
+
+    @Override
+    public void close() throws Exception {
+        // TODO: to be implemented later
+    }
+
+    /** The status of one data partition on the worker side. */
+    private static class DataPartitionMetaStatus {
+
+        private final DataPartitionMeta meta;
+
+        private boolean isReleasing;
+
+        public DataPartitionMetaStatus(DataPartitionMeta meta, boolean isReleasing) {
+            this.meta = meta;
+            this.isReleasing = isReleasing;
+        }
+
+        public DataPartitionMeta getMeta() {
+            return meta;
+        }
+
+        public boolean isReleasing() {
+            return isReleasing;
+        }
+
+        public void setReleasing(boolean releasing) {
+            isReleasing = releasing;
+        }
+    }
+}
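The two-phase removal above can be traced end to end. Below is a minimal sketch of the handshake, assuming an initialized LocalShuffleMetaStore named metaStore and a partitionMeta for a stored partition (both hypothetical):

    // Worker side: the partition is first only marked as releasing and stays listed.
    metaStore.onPartitionRemoved(partitionMeta);
    List<DataPartitionStatus> statuses = metaStore.listDataPartitions(); // isReleasing() == true
    // Only after the manager confirms the release is the entry dropped for good.
    metaStore.removeReleasingDataPartition(
            new DataPartitionCoordinate(
                    partitionMeta.getDataSetID(), partitionMeta.getDataPartitionID()));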
diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/metastore/Metastore.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/metastore/Metastore.java
new file mode 100644
index 00000000..426a36b9
--- /dev/null
+++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/coordinator/worker/metastore/Metastore.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.worker.metastore;
+
+import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionCoordinate;
+import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionStatus;
+import com.alibaba.flink.shuffle.core.ids.JobID;
+import com.alibaba.flink.shuffle.core.listener.PartitionStateListener;
+
+import java.util.List;
+import java.util.function.BiConsumer;
+
+/** The interface for the worker-side metastore. */
+public interface Metastore extends PartitionStateListener, AutoCloseable {
+
+    void setPartitionRemovedConsumer(
+            BiConsumer<JobID, DataPartitionCoordinate> partitionRemovedConsumer);
+
+    /**
+     * Lists the data partitions stored in the meta store.
+     *
+     * @return The list of data partitions according to the meta store.
+     */
+    List<DataPartitionStatus> listDataPartitions() throws Exception;
+
+    /**
+     * Removes the releasing data partition after the master has marked it as released.
+     *
+     * @param dataPartitionCoordinate The coordinate of the data partition to remove.
+     */
+    void removeReleasingDataPartition(DataPartitionCoordinate dataPartitionCoordinate);
+
+    int getSize();
+}
diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/minicluster/ShuffleMiniCluster.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/minicluster/ShuffleMiniCluster.java
new file mode 100644
index 00000000..019b0f40
--- /dev/null
+++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/minicluster/ShuffleMiniCluster.java
@@ -0,0 +1,589 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package com.alibaba.flink.shuffle.minicluster; + +import com.alibaba.flink.shuffle.client.ShuffleManagerClient; +import com.alibaba.flink.shuffle.client.ShuffleManagerClientConfiguration; +import com.alibaba.flink.shuffle.client.ShuffleManagerClientImpl; +import com.alibaba.flink.shuffle.client.ShuffleWorkerStatusListener; +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.functions.AutoCloseableAsync; +import com.alibaba.flink.shuffle.common.handler.FatalErrorHandler; +import com.alibaba.flink.shuffle.common.utils.FutureUtils; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatServices; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatServicesUtils; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServiceUtils; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices; +import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionCoordinate; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleManager; +import com.alibaba.flink.shuffle.coordinator.manager.assignmenttracker.AssignmentTrackerImpl; +import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorker; +import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerRunner; +import com.alibaba.flink.shuffle.core.config.MemoryOptions; +import com.alibaba.flink.shuffle.core.config.TransferOptions; +import com.alibaba.flink.shuffle.core.executor.ExecutorThreadFactory; +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcService; +import com.alibaba.flink.shuffle.rpc.utils.AkkaRpcServiceUtils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import javax.annotation.concurrent.GuardedBy; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; + +/** MiniCluster to execute shuffle locally. */ +public class ShuffleMiniCluster implements AutoCloseableAsync { + + private static final Logger LOG = LoggerFactory.getLogger(ShuffleMiniCluster.class); + + /** The lock to guard startup / shutdown / manipulation methods. */ + private final Object lock = new Object(); + + /** The configuration for this mini cluster. 
+     */
+    private final ShuffleMiniClusterConfiguration miniClusterConfiguration;
+
+    private final Random random = new Random(System.currentTimeMillis());
+
+    private static final int RETRY_TIMES_GET_WORKER_NUMBER = 5;
+
+    private ShuffleManager shuffleManager;
+
+    @GuardedBy("lock")
+    private final List<ShuffleWorker> shuffleWorkers;
+
+    private final TerminatingFatalErrorHandlerFactory
+            shuffleWorkerTerminatingFatalErrorHandlerFactory =
+                    new TerminatingFatalErrorHandlerFactory();
+
+    private CompletableFuture<Void> terminationFuture;
+
+    @GuardedBy("lock")
+    private final Collection<RemoteShuffleRpcService> rpcServices;
+
+    @GuardedBy("lock")
+    private ExecutorService ioExecutor;
+
+    @GuardedBy("lock")
+    private HeartbeatServices workerHeartbeatServices;
+
+    @GuardedBy("lock")
+    private HeartbeatServices jobHeartbeatServices;
+
+    @GuardedBy("lock")
+    private RpcServiceFactory shuffleWorkerRpcServiceFactory;
+
+    @GuardedBy("lock")
+    private RpcServiceFactory shuffleManagerRpcServiceFactory;
+
+    @GuardedBy("lock")
+    private RpcServiceFactory shuffleManagerClientRpcServiceFactory;
+
+    /** Flag marking the mini cluster as started/running. */
+    private volatile boolean running;
+
+    // ------------------------------------------------------------------------
+
+    /**
+     * Creates a new remote shuffle mini cluster based on the given configuration.
+     *
+     * @param miniClusterConfiguration The configuration for the mini cluster
+     */
+    public ShuffleMiniCluster(ShuffleMiniClusterConfiguration miniClusterConfiguration) {
+        this.miniClusterConfiguration = checkNotNull(miniClusterConfiguration);
+        // one common RPC service, one for the manager, and one per worker
+        this.rpcServices =
+                new ArrayList<>(1 + 1 + miniClusterConfiguration.getNumShuffleWorkers());
+
+        this.terminationFuture = CompletableFuture.completedFuture(null);
+        running = false;
+
+        this.shuffleWorkers = new ArrayList<>(miniClusterConfiguration.getNumShuffleWorkers());
+    }
+
+    // ------------------------------------------------------------------------
+    //  life cycle
+    // ------------------------------------------------------------------------
+
+    /** Checks if the mini cluster was started and is running. */
+    public boolean isRunning() {
+        return running;
+    }
+
+    /**
+     * Starts the mini cluster, based on the configured properties.
+     *
+     * @throws Exception This method passes on any exception that occurs during the startup of the
+     *     mini cluster.
+     */
+    public void start() throws Exception {
+        synchronized (lock) {
+            checkState(!running, "MiniCluster is already running");
+
+            LOG.info("Starting the shuffle mini cluster.");
+            LOG.debug("Using configuration {}", miniClusterConfiguration);
+
+            Configuration configuration = miniClusterConfiguration.getConfiguration();
+
+            try {
+                // bring up all the RPC services
+                LOG.info("Starting RPC Service(s)");
+
+                // start a new service per component, possibly with custom bind addresses
+                final String shuffleManagerExternalAddress =
+                        miniClusterConfiguration.getShuffleManagerExternalAddress();
+                final String shuffleWorkerExternalAddress =
+                        miniClusterConfiguration.getShuffleWorkerExternalAddress();
+                final String shuffleManagerExternalPortRange =
+                        miniClusterConfiguration.getShuffleManagerExternalPortRange();
+                final String shuffleWorkerExternalPortRange =
+                        miniClusterConfiguration.getShuffleWorkerExternalPortRange();
+                final String shuffleManagerBindAddress =
+                        miniClusterConfiguration.getShuffleManagerBindAddress();
+                final String shuffleWorkerBindAddress =
+                        miniClusterConfiguration.getShuffleWorkerBindAddress();
+
+                shuffleWorkerRpcServiceFactory =
+                        new DedicatedRpcServiceFactory(
+                                configuration,
+                                shuffleWorkerExternalAddress,
+                                shuffleWorkerExternalPortRange,
+                                shuffleWorkerBindAddress);
+
+                shuffleManagerRpcServiceFactory =
+                        new DedicatedRpcServiceFactory(
+                                configuration,
+                                shuffleManagerExternalAddress,
+                                shuffleManagerExternalPortRange,
+                                shuffleManagerBindAddress);
+
+                shuffleManagerClientRpcServiceFactory =
+                        new DedicatedRpcServiceFactory(
+                                configuration, null, "0", shuffleManagerBindAddress);
+
+                ioExecutor =
+                        Executors.newFixedThreadPool(
+                                1, new ExecutorThreadFactory("mini-cluster-io"));
+
+                workerHeartbeatServices =
+                        HeartbeatServicesUtils.createManagerWorkerHeartbeatServices(configuration);
+                jobHeartbeatServices =
+                        HeartbeatServicesUtils.createManagerJobHeartbeatServices(configuration);
+
+                AkkaRpcServiceUtils.loadRpcSystem(configuration);
+                startShuffleManager();
+
+                startShuffleWorkers();
+
+                LOG.info("Waiting for all the workers to register.");
+                int retryCount = 0;
+                while (retryCount < RETRY_TIMES_GET_WORKER_NUMBER) {
+                    boolean getWorkerNumSuccess = false;
+                    try {
+                        LOG.info(
+                                "Trying to get the number of registered workers, attempt: "
+                                        + retryCount);
+                        while (shuffleManager
+                                        .getNumberOfRegisteredWorkers()
+                                        .get(30, TimeUnit.SECONDS)
+                                < miniClusterConfiguration.getNumShuffleWorkers()) {
+                            Thread.sleep(500);
+                        }
+                        getWorkerNumSuccess = true;
+                    } catch (Exception e) {
+                        LOG.error("Failed to get the number of registered workers.", e);
+                    }
+                    if (getWorkerNumSuccess) {
+                        break;
+                    }
+                    retryCount++;
+                }
+                LOG.info("All the workers have registered.");
+            } catch (Exception e) {
+                // cleanup everything
+                try {
+                    close();
+                } catch (Exception ee) {
+                    e.addSuppressed(ee);
+                }
+                throw e;
+            }
+
+            // create a new termination future
+            terminationFuture = new CompletableFuture<>();
+
+            // now officially mark this as running
+            running = true;
+
+            LOG.info("Shuffle mini cluster started successfully");
+        }
+    }
+
+    private HaServices createHaService(Configuration configuration) throws Exception {
+        LOG.info("Starting high-availability services");
+        return HaServiceUtils.createHAServices(configuration);
+    }
+
+    /**
+     * Shuts down the mini cluster, failing all currently executing jobs. The mini cluster can be
+     * started again by calling the {@link #start()} method.
+     *
+     * <p>This method shuts down all started services and components, even if an exception occurs
+     * in the process of shutting down some component.
+     *
+     * @return Future which is completed once the MiniCluster has been completely shut down
+     */
+    @Override
+    public CompletableFuture<Void> closeAsync() {
+        synchronized (lock) {
+            if (running) {
+                LOG.info("Shutting down the shuffle mini cluster.");
+                try {
+                    final int numComponents = 1 + miniClusterConfiguration.getNumShuffleWorkers();
+                    final Collection<CompletableFuture<Void>> componentTerminationFutures =
+                            new ArrayList<>(numComponents);
+
+                    componentTerminationFutures.addAll(terminateShuffleWorkers());
+
+                    componentTerminationFutures.add(terminateShuffleManager());
+
+                    final FutureUtils.ConjunctFuture<Void> componentsTerminationFuture =
+                            FutureUtils.completeAll(componentTerminationFutures);
+
+                    final CompletableFuture<Void> rpcServicesTerminationFuture =
+                            FutureUtils.composeAfterwards(
+                                    componentsTerminationFuture, this::terminateRpcServices);
+
+                    final CompletableFuture<Void> executorsTerminationFuture =
+                            FutureUtils.composeAfterwards(
+                                    rpcServicesTerminationFuture, this::terminateExecutors);
+
+                    executorsTerminationFuture.whenComplete(
+                            (Void ignored, Throwable throwable) -> {
+                                if (throwable != null) {
+                                    terminationFuture.completeExceptionally(throwable);
+                                } else {
+                                    terminationFuture.complete(null);
+                                }
+                            });
+                } finally {
+                    running = false;
+                }
+            }
+
+            return terminationFuture;
+        }
+    }
+
+    @GuardedBy("lock")
+    private void startShuffleWorkers() throws Exception {
+        final int numShuffleWorkers = miniClusterConfiguration.getNumShuffleWorkers();
+
+        LOG.info("Starting {} ShuffleWorker(s)", numShuffleWorkers);
+
+        for (int i = 0; i < numShuffleWorkers; i++) {
+            startShuffleWorker();
+        }
+    }
+
+    private void startShuffleWorker() throws Exception {
+        synchronized (lock) {
+            Configuration configuration = miniClusterConfiguration.getConfiguration();
+
+            // Choose a random data server port for this worker's configuration
+            Configuration workerConfiguration = new Configuration(configuration);
+            // TODO: only choose available port
+            int randomPort = random.nextInt(30000) + 20000;
+            workerConfiguration.setInteger(TransferOptions.SERVER_DATA_PORT, randomPort);
+            workerConfiguration.setMemorySize(
+                    MemoryOptions.MEMORY_SIZE_FOR_DATA_READING,
+                    MemoryOptions.MIN_VALID_MEMORY_SIZE);
+            workerConfiguration.setMemorySize(
+                    MemoryOptions.MEMORY_SIZE_FOR_DATA_WRITING,
+                    MemoryOptions.MIN_VALID_MEMORY_SIZE);
+
+            RemoteShuffleRpcService rpcService = shuffleWorkerRpcServiceFactory.createRpcService();
+            HaServices haServices = createHaService(miniClusterConfiguration.getConfiguration());
+
+            ShuffleWorker shuffleWorker =
+                    ShuffleWorkerRunner.createShuffleWorker(
+                            workerConfiguration,
+                            rpcService,
+                            haServices,
+                            workerHeartbeatServices,
+                            shuffleWorkerTerminatingFatalErrorHandlerFactory.create(
+                                    shuffleWorkers.size()));
+
+            shuffleWorker.start();
+            shuffleWorkers.add(shuffleWorker);
+        }
+    }
+
+    @GuardedBy("lock")
+    private Collection<CompletableFuture<Void>> terminateShuffleWorkers() {
+        final Collection<CompletableFuture<Void>> terminationFutures =
+                new ArrayList<>(shuffleWorkers.size());
+        for (int i = 0; i < shuffleWorkers.size(); i++) {
+            terminationFutures.add(terminateShuffleWorker(i));
+        }
+
+        return terminationFutures;
+    }
+
+    @Nonnull
+    private CompletableFuture<Void> terminateShuffleWorker(int index) {
+        synchronized (lock) {
+            final ShuffleWorker shuffleWorker = shuffleWorkers.get(index);
+            return shuffleWorker.closeAsync();
+        }
+    }
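+    // Typical test-scoped lifecycle of this class (an illustrative sketch; jobId is assumed to
+    // exist, and default configuration values are assumed to be sufficient):
+    //
+    //   ShuffleMiniClusterConfiguration conf =
+    //           new ShuffleMiniClusterConfiguration.Builder().setNumShuffleWorkers(2).build();
+    //   ShuffleMiniCluster cluster = new ShuffleMiniCluster(conf);
+    //   cluster.start();
+    //   ShuffleManagerClient client = cluster.createClient(jobId);
+    //   ...
+    //   cluster.closeAsync().get();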
+    private void startShuffleManager() throws Exception {
+        RemoteShuffleRpcService rpcService = shuffleManagerRpcServiceFactory.createRpcService();
+        HaServices haServices = createHaService(miniClusterConfiguration.getConfiguration());
+
+        shuffleManager =
+                new ShuffleManager(
+                        rpcService,
+                        new InstanceID("Shuffle manager"),
+                        haServices,
+                        new ShutDownFatalErrorHandler(),
+                        ioExecutor,
+                        jobHeartbeatServices,
+                        workerHeartbeatServices,
+                        new AssignmentTrackerImpl());
+
+        shuffleManager.start();
+    }
+
+    public ShuffleManager getShuffleManager() {
+        return shuffleManager;
+    }
+
+    protected CompletableFuture<Void> terminateShuffleManager() {
+        synchronized (lock) {
+            return shuffleManager.closeAsync();
+        }
+    }
+
+    // ------------------------------------------------------------------------
+    //  factories - can be overridden by subclasses to alter behavior
+    // ------------------------------------------------------------------------
+
+    /**
+     * Factory method to instantiate the remote RPC service.
+     *
+     * @param configuration shuffle configuration.
+     * @param externalAddress The external address to access the RPC service.
+     * @param externalPortRange The external port range to access the RPC service.
+     * @param bindAddress The address to bind the RPC service to.
+     * @return The instantiated RPC service
+     */
+    protected RemoteShuffleRpcService createRemoteRpcService(
+            Configuration configuration,
+            String externalAddress,
+            String externalPortRange,
+            String bindAddress)
+            throws Exception {
+        return AkkaRpcServiceUtils.remoteServiceBuilder(
+                        configuration, externalAddress, externalPortRange)
+                .withBindAddress(bindAddress)
+                .createAndStart();
+    }
+
+    // ------------------------------------------------------------------------
+    //  data client
+    // ------------------------------------------------------------------------
+
+    public ShuffleManagerClient createClient(JobID jobID) throws Exception {
+        return createClient(jobID, new InstanceID());
+    }
+
+    public ShuffleManagerClient createClient(JobID jobId, InstanceID clientID) throws Exception {
+        ShuffleManagerClientConfiguration clientConfiguration =
+                ShuffleManagerClientConfiguration.fromConfiguration(
+                        miniClusterConfiguration.getConfiguration());
+
+        RemoteShuffleRpcService rpcService =
+                shuffleManagerClientRpcServiceFactory.createRpcService();
+        HaServices haServices = createHaService(miniClusterConfiguration.getConfiguration());
+
+        ShuffleManagerClient client =
+                new ShuffleManagerClientImpl(
+                        jobId,
+                        new ShuffleWorkerStatusListener() {
+                            @Override
+                            public void notifyIrrelevantWorker(InstanceID workerID) {}
+
+                            @Override
+                            public void notifyRelevantWorker(
+                                    InstanceID workerID,
+                                    Set<DataPartitionCoordinate> recoveredDataPartitions) {}
+                        },
+                        rpcService,
+                        new ShutDownFatalErrorHandler(),
+                        clientConfiguration,
+                        haServices,
+                        jobHeartbeatServices,
+                        clientID);
+        client.start();
+        return client;
+    }
+
+    // ------------------------------------------------------------------------
+    //  Internal methods
+    // ------------------------------------------------------------------------
+
+    @Nonnull
+    private CompletableFuture<Void> terminateRpcServices() {
+        synchronized (lock) {
+            final int numRpcServices = 1 + rpcServices.size();
+
+            final Collection<CompletableFuture<?>> rpcTerminationFutures =
+                    new ArrayList<>(numRpcServices);
+
+            for (RemoteShuffleRpcService rpcService : rpcServices) {
+                rpcTerminationFutures.add(rpcService.stopService());
+            }
+
+            rpcServices.clear();
+
+            return FutureUtils.completeAll(rpcTerminationFutures);
+        }
+    }
+
+    private CompletableFuture<Void> terminateExecutors() {
+        synchronized (lock) {
+            try {
+                if (ioExecutor != null) {
ioExecutor.shutdown(); + } + } catch (Throwable throwable) { + LOG.error("Failed to close the executor service.", throwable); + return FutureUtils.completedExceptionally(throwable); + } + return CompletableFuture.completedFuture(null); + } + } + + /** Internal factory for {@link RemoteShuffleRpcService}. */ + protected interface RpcServiceFactory { + RemoteShuffleRpcService createRpcService() throws Exception; + } + + /** Factory which creates and registers new {@link RemoteShuffleRpcService}. */ + protected class DedicatedRpcServiceFactory implements RpcServiceFactory { + + private final Configuration configuration; + private final String externalAddress; + private final String externalPortRange; + private final String bindAddress; + + DedicatedRpcServiceFactory( + Configuration configuration, + String externalAddress, + String externalPortRange, + String bindAddress) { + this.configuration = configuration; + this.externalAddress = externalAddress; + this.externalPortRange = externalPortRange; + this.bindAddress = bindAddress; + } + + @Override + public RemoteShuffleRpcService createRpcService() throws Exception { + RemoteShuffleRpcService rpcService = + ShuffleMiniCluster.this.createRemoteRpcService( + configuration, externalAddress, externalPortRange, bindAddress); + + synchronized (lock) { + rpcServices.add(rpcService); + } + + return rpcService; + } + } + + // ------------------------------------------------------------------------ + // miscellaneous utilities + // ------------------------------------------------------------------------ + + private class TerminatingFatalErrorHandler implements FatalErrorHandler { + + private final int index; + + private TerminatingFatalErrorHandler(int index) { + this.index = index; + } + + @Override + public void onFatalError(Throwable exception) { + // first check if we are still running + if (running) { + LOG.error("ShuffleWorker #{} failed.", index, exception); + + synchronized (lock) { + shuffleWorkers.get(index).closeAsync(); + } + } + } + } + + private class ShutDownFatalErrorHandler implements FatalErrorHandler { + + @Override + public void onFatalError(Throwable exception) { + LOG.warn("Error in MiniCluster. Shutting the MiniCluster down.", exception); + closeAsync(); + } + } + + private class TerminatingFatalErrorHandlerFactory { + + /** + * Create a new {@link TerminatingFatalErrorHandler} for the {@link ShuffleWorker} with the + * given index. + * + * @param index into the {@link #shuffleWorkers} collection to identify the correct {@link + * ShuffleWorker}. + * @return {@link TerminatingFatalErrorHandler} for the given index + */ + @GuardedBy("lock") + private TerminatingFatalErrorHandler create(int index) { + return new TerminatingFatalErrorHandler(index); + } + } +} diff --git a/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/minicluster/ShuffleMiniClusterConfiguration.java b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/minicluster/ShuffleMiniClusterConfiguration.java new file mode 100644 index 00000000..9dad24aa --- /dev/null +++ b/shuffle-coordinator/src/main/java/com/alibaba/flink/shuffle/minicluster/ShuffleMiniClusterConfiguration.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.minicluster; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.core.config.ManagerOptions; +import com.alibaba.flink.shuffle.core.config.WorkerOptions; + +import javax.annotation.Nullable; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** Configuration object for the {@link ShuffleMiniCluster}. */ +public class ShuffleMiniClusterConfiguration { + + private final Configuration configuration; + + private final int numShuffleWorkers; + + @Nullable private final String commonBindAddress; + + // ------------------------------------------------------------------------ + // Construction + // ------------------------------------------------------------------------ + + public ShuffleMiniClusterConfiguration( + Configuration configuration, + int numShuffleWorkers, + @Nullable String commonBindAddress) { + + this.numShuffleWorkers = numShuffleWorkers; + this.configuration = checkNotNull(configuration); + this.configuration.setInteger( + ManagerOptions.RPC_BIND_PORT, ManagerOptions.RPC_PORT.defaultValue()); + this.commonBindAddress = commonBindAddress; + } + + // ------------------------------------------------------------------------ + // getters + // ------------------------------------------------------------------------ + + public int getNumShuffleWorkers() { + return numShuffleWorkers; + } + + public String getShuffleManagerExternalAddress() { + return commonBindAddress != null + ? commonBindAddress + : configuration.getString(ManagerOptions.RPC_ADDRESS, "localhost"); + } + + public String getShuffleWorkerExternalAddress() { + return commonBindAddress != null + ? commonBindAddress + : configuration.getString(WorkerOptions.HOST, "localhost"); + } + + public String getShuffleManagerExternalPortRange() { + return String.valueOf(configuration.getInteger(ManagerOptions.RPC_PORT)); + } + + public String getShuffleWorkerExternalPortRange() { + return configuration.getString(WorkerOptions.RPC_PORT); + } + + public String getShuffleManagerBindAddress() { + return commonBindAddress != null + ? commonBindAddress + : configuration.getString(ManagerOptions.RPC_BIND_ADDRESS, "localhost"); + } + + public String getShuffleWorkerBindAddress() { + return commonBindAddress != null + ? 
commonBindAddress + : configuration.getString(WorkerOptions.BIND_HOST, "localhost"); + } + + public Configuration getConfiguration() { + return configuration; + } + + @Override + public String toString() { + return "MiniClusterConfiguration {" + + "numShuffleWorker=" + + numShuffleWorkers + + ", commonBindAddress='" + + commonBindAddress + + '\'' + + ", config=" + + configuration + + '}'; + } + + // ---------------------------------------------------------------------------------- + // Enums + // ---------------------------------------------------------------------------------- + + // ---------------------------------------------------------------------------------- + // Builder + // ---------------------------------------------------------------------------------- + + /** Builder for the MiniClusterConfiguration. */ + public static class Builder { + private Configuration configuration = new Configuration(); + private int numShuffleWorkers = 1; + @Nullable private String commonBindAddress = null; + + public Builder setConfiguration(Configuration configuration) { + this.configuration = checkNotNull(configuration); + return this; + } + + public Builder setNumShuffleWorkers(int numShuffleWorkers) { + this.numShuffleWorkers = numShuffleWorkers; + return this; + } + + public Builder setCommonBindAddress(String commonBindAddress) { + this.commonBindAddress = commonBindAddress; + return this; + } + + public ShuffleMiniClusterConfiguration build() { + return new ShuffleMiniClusterConfiguration( + configuration, numShuffleWorkers, commonBindAddress); + } + } +} diff --git a/shuffle-coordinator/src/main/resources-filtered/.coordinator.version.properties b/shuffle-coordinator/src/main/resources-filtered/.coordinator.version.properties new file mode 100644 index 00000000..e6b5d977 --- /dev/null +++ b/shuffle-coordinator/src/main/resources-filtered/.coordinator.version.properties @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +project.version=${project.version} + +git.commit.id=${git.commit.id} +git.commit.id.abbrev=${git.commit.id.abbrev} +git.commit.time=${git.commit.time} +git.build.time=${git.build.time} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/client/ShuffleManagerClientTest.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/client/ShuffleManagerClientTest.java new file mode 100644 index 00000000..a1a829ed --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/client/ShuffleManagerClientTest.java @@ -0,0 +1,488 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.client; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.highavailability.TestingHaServices; +import com.alibaba.flink.shuffle.coordinator.leaderretrieval.SettableLeaderRetrievalService; +import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionCoordinate; +import com.alibaba.flink.shuffle.coordinator.manager.DefaultShuffleResource; +import com.alibaba.flink.shuffle.coordinator.manager.ManagerToJobHeartbeatPayload; +import com.alibaba.flink.shuffle.coordinator.manager.RegistrationSuccess; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleResource; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleWorkerDescriptor; +import com.alibaba.flink.shuffle.coordinator.manager.assignmenttracker.ChangedWorkerStatus; +import com.alibaba.flink.shuffle.coordinator.utils.RandomIDUtils; +import com.alibaba.flink.shuffle.coordinator.utils.RecordingHeartbeatServices; +import com.alibaba.flink.shuffle.coordinator.utils.TestingFatalErrorHandler; +import com.alibaba.flink.shuffle.coordinator.utils.TestingShuffleManagerGateway; +import com.alibaba.flink.shuffle.core.config.RpcOptions; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.storage.DataPartition; +import com.alibaba.flink.shuffle.rpc.message.Acknowledge; +import com.alibaba.flink.shuffle.rpc.test.TestingRpcService; +import com.alibaba.flink.shuffle.rpc.utils.RpcUtils; +import com.alibaba.flink.shuffle.utils.Tuple4; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.lang3.tuple.Triple; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** Tests the behaviour of {@link ShuffleManagerClientImpl}. 
*/ +public class ShuffleManagerClientTest { + + private static final String partitionFactoryName = + "com.alibaba.flink.shuffle.storage.partition.LocalFileMapPartitionFactory"; + + private static final long timeout = 60000L; + + private ExecutorService jobMainThreadExecutor; + + private TestingRpcService rpcService; + + private Configuration configuration; + + private SettableLeaderRetrievalService shuffleManagerLeaderRetrieveService; + + private TestingHaServices haServices; + + private TestingFatalErrorHandler testingFatalErrorHandler; + + @Before + public void setup() throws IOException { + jobMainThreadExecutor = Executors.newSingleThreadExecutor(); + + rpcService = new TestingRpcService(); + + configuration = new Configuration(); + + shuffleManagerLeaderRetrieveService = new SettableLeaderRetrievalService(); + haServices = new TestingHaServices(); + haServices.setShuffleManagerLeaderRetrieveService(shuffleManagerLeaderRetrieveService); + + testingFatalErrorHandler = new TestingFatalErrorHandler(); + } + + @After + public void teardown() throws Exception { + if (rpcService != null) { + RpcUtils.terminateRpcService(rpcService, timeout); + rpcService = null; + } + + if (jobMainThreadExecutor != null) { + jobMainThreadExecutor.shutdownNow(); + } + } + + @Test + public void testClientRegisterAndHeartbeat() throws Exception { + TestingShuffleManagerGateway smGateway = new TestingShuffleManagerGateway(); + + // Registration future + CompletableFuture clientRegistrationFuture = new CompletableFuture<>(); + RegistrationSuccess registrationSuccess = + new RegistrationSuccess(smGateway.getInstanceID()); + smGateway.setRegisterClientConsumer( + jobID -> { + clientRegistrationFuture.complete(jobID); + return CompletableFuture.completedFuture(registrationSuccess); + }); + + // heartbeat future + CompletableFuture heartbeatFuture = new CompletableFuture<>(); + List unrelatedShuffleWorkers = + Collections.singletonList(new InstanceID("worker1")); + Map> relatedShuffleWorkers = + Collections.singletonMap( + new InstanceID("worker2"), + Collections.singleton( + new DataPartitionCoordinate( + RandomIDUtils.randomDataSetId(), + RandomIDUtils.randomMapPartitionId()))); + + smGateway.setHeartbeatFromClientConsumer( + (jobId, cachedWorkerList) -> { + heartbeatFuture.complete(jobId); + return CompletableFuture.completedFuture( + new ManagerToJobHeartbeatPayload( + smGateway.getInstanceID(), + new ChangedWorkerStatus( + unrelatedShuffleWorkers, relatedShuffleWorkers))); + }); + rpcService.registerGateway(smGateway.getAddress(), smGateway); + + RecordShuffleWorkerStatusListener shuffleWorkerStatusListener = + new RecordShuffleWorkerStatusListener(); + + JobID jobId = RandomIDUtils.randomJobId(); + try (ShuffleManagerClientImpl shuffleManagerClient = + new ShuffleManagerClientImpl( + jobId, + shuffleWorkerStatusListener, + rpcService, + testingFatalErrorHandler, + ShuffleManagerClientConfiguration.fromConfiguration(configuration), + haServices, + new HeartbeatServices(1000L, 2000L))) { + shuffleManagerLeaderRetrieveService.notifyListener( + new LeaderInformation(smGateway.getFencingToken(), smGateway.getAddress())); + shuffleManagerClient.start(); + assertThat( + clientRegistrationFuture.get(timeout, TimeUnit.MILLISECONDS), equalTo(jobId)); + + assertThat(heartbeatFuture.get(timeout, TimeUnit.MILLISECONDS), equalTo(jobId)); + + assertEquals( + unrelatedShuffleWorkers.get(0), + shuffleWorkerStatusListener.pollUnrelatedWorkers(timeout)); + assertEquals( + relatedShuffleWorkers.entrySet().stream() + .map(entry -> 
Pair.of(entry.getKey(), entry.getValue())) + .findFirst() + .get(), + shuffleWorkerStatusListener.pollRelatedWorkers(timeout)); + } + } + + @Test + public void testSynchronizeWorkerStatusWithManager() throws Exception { + TestingShuffleManagerGateway smGateway = new TestingShuffleManagerGateway(); + + // Registration future + CompletableFuture clientRegistrationFuture = new CompletableFuture<>(); + RegistrationSuccess registrationSuccess = + new RegistrationSuccess(smGateway.getInstanceID()); + smGateway.setRegisterClientConsumer( + jobID -> { + clientRegistrationFuture.complete(jobID); + return CompletableFuture.completedFuture(registrationSuccess); + }); + + InstanceID[] workerIds = new InstanceID[3]; + for (int i = 0; i < workerIds.length; ++i) { + workerIds[i] = new InstanceID("worker" + i); + } + + Set initWorkers = new HashSet<>(Arrays.asList(workerIds[0], workerIds[1])); + smGateway.setHeartbeatFromClientConsumer( + (jobID, instanceIDS) -> { + // Dedicated delay to check if the call is synchronous. + try { + Thread.sleep(2000); + } catch (InterruptedException e) { + // ignored. + } + + assertEquals(initWorkers, instanceIDS); + return CompletableFuture.completedFuture( + new ManagerToJobHeartbeatPayload( + smGateway.getInstanceID(), + new ChangedWorkerStatus( + Collections.singletonList(workerIds[1]), + Collections.singletonMap( + workerIds[2], Collections.emptySet())))); + }); + rpcService.registerGateway(smGateway.getAddress(), smGateway); + + RecordShuffleWorkerStatusListener shuffleWorkerStatusListener = + new RecordShuffleWorkerStatusListener(); + JobID jobId = RandomIDUtils.randomJobId(); + configuration.setDuration(RpcOptions.RPC_TIMEOUT, Duration.ofSeconds(100)); + try (ShuffleManagerClientImpl shuffleManagerClient = + new ShuffleManagerClientImpl( + jobId, + shuffleWorkerStatusListener, + rpcService, + testingFatalErrorHandler, + ShuffleManagerClientConfiguration.fromConfiguration(configuration), + haServices, + new HeartbeatServices(Long.MAX_VALUE, Long.MAX_VALUE))) { + shuffleManagerLeaderRetrieveService.notifyListener( + new LeaderInformation(smGateway.getFencingToken(), smGateway.getAddress())); + + // Simulates running in the JM main executor + jobMainThreadExecutor + .submit( + () -> { + try { + shuffleManagerClient.start(); + assertThat( + clientRegistrationFuture.get( + timeout, TimeUnit.MILLISECONDS), + equalTo(jobId)); + + shuffleManagerClient.synchronizeWorkerStatus(initWorkers); + + assertEquals( + workerIds[1], + shuffleWorkerStatusListener.pollUnrelatedWorkers(100)); + assertEquals( + workerIds[2], + shuffleWorkerStatusListener + .pollRelatedWorkers(100) + .getLeft()); + } catch (Exception e) { + throw new RuntimeException(e); + } + }) + .get(); + } + } + + @Test + public void testHeartbeatTimeoutWithShuffleManager() throws Exception { + TestingShuffleManagerGateway smGateway = new TestingShuffleManagerGateway(); + + // Registration future + CompletableFuture clientRegistrationFuture = new CompletableFuture<>(); + CountDownLatch registrationAttempts = new CountDownLatch(2); + RegistrationSuccess registrationSuccess = + new RegistrationSuccess(smGateway.getInstanceID()); + smGateway.setRegisterClientConsumer( + jobID -> { + clientRegistrationFuture.complete(jobID); + registrationAttempts.countDown(); + return CompletableFuture.completedFuture(registrationSuccess); + }); + + // heartbeat future which never terminate to trigger timeout + smGateway.setHeartbeatFromClientConsumer( + (jobIds, relatedWorkIds) -> new CompletableFuture<>()); + + // Disconnect Future + 
CompletableFuture clientDisconnectFuture = new CompletableFuture<>(); + smGateway.setUnregisterClientConsumer( + jobID -> { + clientDisconnectFuture.complete(jobID); + return CompletableFuture.completedFuture(Acknowledge.get()); + }); + + rpcService.registerGateway(smGateway.getAddress(), smGateway); + + JobID jobId = RandomIDUtils.randomJobId(); + HeartbeatServices heartbeatServices = new HeartbeatServices(1L, 3L); + try (ShuffleManagerClientImpl shuffleManagerClient = + new ShuffleManagerClientImpl( + jobId, + new TestingShuffleWorkerStatusListener(), + rpcService, + testingFatalErrorHandler, + ShuffleManagerClientConfiguration.fromConfiguration(configuration), + haServices, + heartbeatServices)) { + + shuffleManagerLeaderRetrieveService.notifyListener( + new LeaderInformation(smGateway.getFencingToken(), smGateway.getAddress())); + shuffleManagerClient.start(); + assertThat( + clientRegistrationFuture.get(timeout, TimeUnit.MILLISECONDS), + Matchers.equalTo(jobId)); + + assertThat( + clientDisconnectFuture.get(timeout, TimeUnit.MILLISECONDS), + Matchers.equalTo(jobId)); + + assertTrue( + "The Shuffle Worker should try to reconnect to the RM", + registrationAttempts.await(timeout, TimeUnit.SECONDS)); + } + } + + @Test + public void testUnMonitorShuffleManagerOnLeadershipRevoked() throws Exception { + TestingShuffleManagerGateway smGateway = new TestingShuffleManagerGateway(); + + rpcService.registerGateway(smGateway.getAddress(), smGateway); + + JobID jobId = RandomIDUtils.randomJobId(); + RecordingHeartbeatServices heartbeatServices = new RecordingHeartbeatServices(1L, 100000L); + try (ShuffleManagerClientImpl shuffleManagerClient = + new ShuffleManagerClientImpl( + jobId, + new TestingShuffleWorkerStatusListener(), + rpcService, + testingFatalErrorHandler, + ShuffleManagerClientConfiguration.fromConfiguration(configuration), + haServices, + heartbeatServices)) { + BlockingQueue monitoredTargets = heartbeatServices.getMonitoredTargets(); + BlockingQueue unmonitoredTargets = + heartbeatServices.getUnmonitoredTargets(); + + shuffleManagerLeaderRetrieveService.notifyListener( + new LeaderInformation(smGateway.getFencingToken(), smGateway.getAddress())); + shuffleManagerClient.start(); + + assertThat( + monitoredTargets.poll(timeout, TimeUnit.MILLISECONDS), + Matchers.equalTo(smGateway.getInstanceID())); + + shuffleManagerLeaderRetrieveService.notifyListener(LeaderInformation.empty()); + assertThat( + unmonitoredTargets.poll(timeout, TimeUnit.MILLISECONDS), + Matchers.equalTo(smGateway.getInstanceID())); + } + } + + @Test + public void testRequestAndReleaseShuffleResource() throws Exception { + TestingShuffleManagerGateway smGateway = new TestingShuffleManagerGateway(); + + // Registration future + CompletableFuture clientRegistrationFuture = new CompletableFuture<>(); + CountDownLatch registrationAttempts = new CountDownLatch(2); + RegistrationSuccess registrationSuccess = + new RegistrationSuccess(smGateway.getInstanceID()); + smGateway.setRegisterClientConsumer( + jobID -> { + clientRegistrationFuture.complete(jobID); + registrationAttempts.countDown(); + return CompletableFuture.completedFuture(registrationSuccess); + }); + + // Resource request + CompletableFuture> resourceRequestFuture = + new CompletableFuture<>(); + ShuffleResource shuffleResource = + new DefaultShuffleResource( + new ShuffleWorkerDescriptor[] { + new ShuffleWorkerDescriptor(new InstanceID("worker1"), "worker1", 20480) + }, + DataPartition.DataPartitionType.MAP_PARTITION); + 
smGateway.setAllocateShuffleResourceConsumer( + (jobID, dataSetID, mapPartitionID, numberOfSubpartitions) -> { + resourceRequestFuture.complete( + new Tuple4<>(jobID, dataSetID, mapPartitionID, numberOfSubpartitions)); + return CompletableFuture.completedFuture(shuffleResource); + }); + + // Resource release + CompletableFuture> resourceReleaseFuture = + new CompletableFuture<>(); + smGateway.setReleaseShuffleResourceConsumer( + (jobID, dataSetID, mapPartitionID) -> { + resourceReleaseFuture.complete(Triple.of(jobID, dataSetID, mapPartitionID)); + return CompletableFuture.completedFuture(Acknowledge.get()); + }); + + rpcService.registerGateway(smGateway.getAddress(), smGateway); + + JobID jobId = RandomIDUtils.randomJobId(); + DataSetID dataSetId = RandomIDUtils.randomDataSetId(); + MapPartitionID dataPartitionId = RandomIDUtils.randomMapPartitionId(); + int numberOfSubpartitions = 10; + + try (ShuffleManagerClientImpl shuffleManagerClient = + new ShuffleManagerClientImpl( + jobId, + new TestingShuffleWorkerStatusListener(), + rpcService, + testingFatalErrorHandler, + ShuffleManagerClientConfiguration.fromConfiguration(configuration), + haServices, + new HeartbeatServices(1000L, 3000L))) { + + shuffleManagerLeaderRetrieveService.notifyListener( + new LeaderInformation(smGateway.getFencingToken(), smGateway.getAddress())); + shuffleManagerClient.start(); + + CompletableFuture result1 = + shuffleManagerClient.requestShuffleResource( + dataSetId, + dataPartitionId, + numberOfSubpartitions, + partitionFactoryName); + result1.join(); + + assertThat( + clientRegistrationFuture.get(timeout, TimeUnit.MILLISECONDS), + Matchers.equalTo(jobId)); + assertThat( + resourceRequestFuture.get(timeout, TimeUnit.MILLISECONDS), + Matchers.equalTo( + new Tuple4<>( + jobId, dataSetId, dataPartitionId, numberOfSubpartitions))); + assertTrue(result1.isDone()); + assertEquals(shuffleResource, result1.get()); + + shuffleManagerClient.releaseShuffleResource(dataSetId, dataPartitionId); + assertThat( + resourceReleaseFuture.get(timeout, TimeUnit.MILLISECONDS), + Matchers.equalTo(Triple.of(jobId, dataSetId, dataPartitionId))); + } + } + + private static class RecordShuffleWorkerStatusListener implements ShuffleWorkerStatusListener { + + private final BlockingQueue unrelatedWorkers = new LinkedBlockingQueue<>(); + private final BlockingQueue>> + relatedWorkers = new LinkedBlockingQueue<>(); + + @Override + public void notifyIrrelevantWorker(InstanceID workerID) { + unrelatedWorkers.add(workerID); + } + + @Override + public void notifyRelevantWorker( + InstanceID workerID, Set dataPartitions) { + relatedWorkers.add(Pair.of(workerID, dataPartitions)); + } + + public InstanceID pollUnrelatedWorkers(long timeout) throws InterruptedException { + return unrelatedWorkers.poll(timeout, TimeUnit.MILLISECONDS); + } + + public Pair> pollRelatedWorkers( + long timeout) throws InterruptedException { + return relatedWorkers.poll(timeout, TimeUnit.MILLISECONDS); + } + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/client/TestingShuffleWorkerStatusListener.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/client/TestingShuffleWorkerStatusListener.java new file mode 100644 index 00000000..d07132c0 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/client/TestingShuffleWorkerStatusListener.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.client;
+
+import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionCoordinate;
+import com.alibaba.flink.shuffle.core.ids.InstanceID;
+
+import java.util.Set;
+
+/** A testing {@link ShuffleWorkerStatusListener} that takes no action. */
+public class TestingShuffleWorkerStatusListener implements ShuffleWorkerStatusListener {
+
+    @Override
+    public void notifyIrrelevantWorker(InstanceID workerID) {}
+
+    @Override
+    public void notifyRelevantWorker(
+            InstanceID workerID, Set<DataPartitionCoordinate> dataPartitions) {}
+}
diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatManagerTest.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatManagerTest.java
new file mode 100644
index 00000000..18a7e68d
--- /dev/null
+++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/heartbeat/HeartbeatManagerTest.java
@@ -0,0 +1,541 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package com.alibaba.flink.shuffle.coordinator.heartbeat; + +import com.alibaba.flink.shuffle.coordinator.utils.TestingUtils; +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.core.utils.OneShotLatch; +import com.alibaba.flink.shuffle.core.utils.TestLogger; +import com.alibaba.flink.shuffle.rpc.executor.ScheduledExecutor; +import com.alibaba.flink.shuffle.rpc.executor.ScheduledExecutorServiceAdapter; + +import org.hamcrest.Matcher; +import org.junit.Test; +import org.mockito.ArgumentMatchers; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** Tests for the {@link HeartbeatManager}. */ +public class HeartbeatManagerTest extends TestLogger { + private static final Logger LOG = LoggerFactory.getLogger(HeartbeatManagerTest.class); + public static final long HEARTBEAT_INTERVAL = 50L; + public static final long HEARTBEAT_TIMEOUT = 200L; + + /** + * Tests that regular heartbeat signal triggers the right callback functions in the {@link + * HeartbeatListener}. 
+ */ + @Test + public void testRegularHeartbeat() throws InterruptedException { + final long heartbeatTimeout = 1000L; + InstanceID ownInstanceID = new InstanceID("foobar"); + InstanceID targetInstanceID = new InstanceID("barfoo"); + final int outputPayload = 42; + final ArrayBlockingQueue reportedPayloads = new ArrayBlockingQueue<>(2); + final TestingHeartbeatListener heartbeatListener = + new TestingHeartbeatListenerBuilder() + .setReportPayloadConsumer( + (ignored, payload) -> reportedPayloads.offer(payload)) + .setRetrievePayloadFunction((ignored) -> outputPayload) + .createNewTestingHeartbeatListener(); + + HeartbeatManagerImpl heartbeatManager = + new HeartbeatManagerImpl<>( + heartbeatTimeout, + ownInstanceID, + heartbeatListener, + TestingUtils.defaultScheduledExecutor(), + LOG); + + final ArrayBlockingQueue reportedPayloadsHeartbeatTarget = + new ArrayBlockingQueue<>(2); + final TestingHeartbeatTarget heartbeatTarget = + new TestingHeartbeatTargetBuilder() + .setReceiveHeartbeatConsumer( + (ignoredA, payload) -> + reportedPayloadsHeartbeatTarget.offer(payload)) + .createTestingHeartbeatTarget(); + + heartbeatManager.monitorTarget(targetInstanceID, heartbeatTarget); + + final String inputPayload1 = "foobar"; + heartbeatManager.requestHeartbeat(targetInstanceID, inputPayload1); + + assertThat(reportedPayloads.take(), is(inputPayload1)); + assertThat(reportedPayloadsHeartbeatTarget.take(), is(outputPayload)); + + final String inputPayload2 = "barfoo"; + heartbeatManager.receiveHeartbeat(targetInstanceID, inputPayload2); + assertThat(reportedPayloads.take(), is(inputPayload2)); + } + + /** Tests that the heartbeat monitors are updated when receiving a new heartbeat signal. */ + @Test + public void testHeartbeatMonitorUpdate() { + long heartbeatTimeout = 1000L; + InstanceID ownInstanceID = new InstanceID(); + InstanceID targetInstanceID = new InstanceID("barfoo"); + @SuppressWarnings("unchecked") + HeartbeatListener heartbeatListener = mock(HeartbeatListener.class); + ScheduledExecutor scheduledExecutor = mock(ScheduledExecutor.class); + ScheduledFuture scheduledFuture = mock(ScheduledFuture.class); + + doReturn(scheduledFuture) + .when(scheduledExecutor) + .schedule( + ArgumentMatchers.any(Runnable.class), + ArgumentMatchers.anyLong(), + ArgumentMatchers.any(TimeUnit.class)); + + Object expectedObject = new Object(); + + when(heartbeatListener.retrievePayload(ArgumentMatchers.any(InstanceID.class))) + .thenReturn(CompletableFuture.completedFuture(expectedObject)); + + HeartbeatManagerImpl heartbeatManager = + new HeartbeatManagerImpl<>( + heartbeatTimeout, ownInstanceID, heartbeatListener, scheduledExecutor, LOG); + + @SuppressWarnings("unchecked") + HeartbeatTarget heartbeatTarget = mock(HeartbeatTarget.class); + + heartbeatManager.monitorTarget(targetInstanceID, heartbeatTarget); + + heartbeatManager.receiveHeartbeat(targetInstanceID, expectedObject); + + verify(scheduledFuture, times(1)).cancel(true); + verify(scheduledExecutor, times(2)) + .schedule( + ArgumentMatchers.any(Runnable.class), + ArgumentMatchers.eq(heartbeatTimeout), + ArgumentMatchers.eq(TimeUnit.MILLISECONDS)); + } + + /** Tests that a heartbeat timeout is signaled if the heartbeat is not reported in time. 
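testHeartbeatMonitorUpdate above verifies exactly one `cancel(true)` and two `schedule(...)` calls: receiving a heartbeat must cancel the pending timeout and re-arm it. A minimal sketch of that mechanism using only the JDK; `SimpleHeartbeatMonitor` is illustrative, not the PR's `HeartbeatMonitorImpl`:

```java
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;

// Minimal monitor: every reported heartbeat cancels the pending timeout and re-arms it.
class SimpleHeartbeatMonitor {
    private final ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
    private final long timeoutMs;
    private final Runnable onTimeout;
    private ScheduledFuture<?> pendingTimeout;

    SimpleHeartbeatMonitor(long timeoutMs, Runnable onTimeout) {
        this.timeoutMs = timeoutMs;
        this.onTimeout = onTimeout;
        resetTimeout(); // first schedule(...) happens when monitoring starts
    }

    synchronized void reportHeartbeat() {
        pendingTimeout.cancel(true); // the cancel(true) the test verifies
        resetTimeout();              // followed by a fresh schedule(...)
    }

    private synchronized void resetTimeout() {
        pendingTimeout = executor.schedule(onTimeout, timeoutMs, TimeUnit.MILLISECONDS);
    }
}
```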
*/ + @Test + public void testHeartbeatTimeout() throws Exception { + int numHeartbeats = 6; + final int payload = 42; + + InstanceID ownInstanceID = new InstanceID("foobar"); + InstanceID targetInstanceID = new InstanceID("barfoo"); + + final CompletableFuture timeoutFuture = new CompletableFuture<>(); + final TestingHeartbeatListener heartbeatListener = + new TestingHeartbeatListenerBuilder() + .setRetrievePayloadFunction(ignored -> payload) + .setNotifyHeartbeatTimeoutConsumer(timeoutFuture::complete) + .createNewTestingHeartbeatListener(); + + HeartbeatManagerImpl heartbeatManager = + new HeartbeatManagerImpl<>( + HEARTBEAT_TIMEOUT, + ownInstanceID, + heartbeatListener, + TestingUtils.defaultScheduledExecutor(), + LOG); + + final HeartbeatTarget heartbeatTarget = + new TestingHeartbeatTargetBuilder().createTestingHeartbeatTarget(); + + heartbeatManager.monitorTarget(targetInstanceID, heartbeatTarget); + + for (int i = 0; i < numHeartbeats; i++) { + heartbeatManager.receiveHeartbeat(targetInstanceID, payload); + Thread.sleep(HEARTBEAT_INTERVAL); + } + + assertFalse(timeoutFuture.isDone()); + + InstanceID timeoutInstanceID = + timeoutFuture.get(2 * HEARTBEAT_TIMEOUT, TimeUnit.MILLISECONDS); + + assertEquals(targetInstanceID, timeoutInstanceID); + } + + /** + * Tests the heartbeat interplay between the {@link HeartbeatManagerImpl} and the {@link + * HeartbeatManagerSenderImpl}. The sender should regularly trigger heartbeat requests which are + * fulfilled by the receiver. Upon stopping the receiver, the sender should notify the heartbeat + * listener about the heartbeat timeout. + * + * @throws Exception when error happen. + */ + @Test + public void testHeartbeatCluster() throws Exception { + InstanceID instanceIDTarget = new InstanceID("foobar"); + InstanceID instanceIDSender = new InstanceID("barfoo"); + final int targetPayload = 42; + final AtomicInteger numReportPayloadCallsTarget = new AtomicInteger(0); + final TestingHeartbeatListener heartbeatListenerTarget = + new TestingHeartbeatListenerBuilder() + .setRetrievePayloadFunction(ignored -> targetPayload) + .setReportPayloadConsumer( + (ignoredA, ignoredB) -> + numReportPayloadCallsTarget.incrementAndGet()) + .createNewTestingHeartbeatListener(); + + final String senderPayload = "1337"; + final CompletableFuture targetHeartbeatTimeoutFuture = + new CompletableFuture<>(); + final AtomicInteger numReportPayloadCallsSender = new AtomicInteger(0); + final TestingHeartbeatListener heartbeatListenerSender = + new TestingHeartbeatListenerBuilder() + .setRetrievePayloadFunction(ignored -> senderPayload) + .setNotifyHeartbeatTimeoutConsumer(targetHeartbeatTimeoutFuture::complete) + .setReportPayloadConsumer( + (ignoredA, ignoredB) -> + numReportPayloadCallsSender.incrementAndGet()) + .createNewTestingHeartbeatListener(); + + HeartbeatManagerImpl heartbeatManagerTarget = + new HeartbeatManagerImpl<>( + HEARTBEAT_TIMEOUT, + instanceIDTarget, + heartbeatListenerTarget, + TestingUtils.defaultScheduledExecutor(), + LOG); + + HeartbeatManagerSenderImpl heartbeatManagerSender = + new HeartbeatManagerSenderImpl<>( + HEARTBEAT_INTERVAL, + HEARTBEAT_TIMEOUT, + instanceIDSender, + heartbeatListenerSender, + TestingUtils.defaultScheduledExecutor(), + LOG); + + heartbeatManagerTarget.monitorTarget(instanceIDSender, heartbeatManagerSender); + heartbeatManagerSender.monitorTarget(instanceIDTarget, heartbeatManagerTarget); + + Thread.sleep(2 * HEARTBEAT_TIMEOUT); + + assertFalse(targetHeartbeatTimeoutFuture.isDone()); + + heartbeatManagerTarget.stop(); + + 
InstanceID timeoutInstanceID = + targetHeartbeatTimeoutFuture.get(2 * HEARTBEAT_TIMEOUT, TimeUnit.MILLISECONDS); + + assertThat(timeoutInstanceID, is(instanceIDTarget)); + + int numberHeartbeats = (int) (2 * HEARTBEAT_TIMEOUT / HEARTBEAT_INTERVAL); + + final Matcher numberHeartbeatsMatcher = greaterThanOrEqualTo(numberHeartbeats / 2); + assertThat(numReportPayloadCallsTarget.get(), is(numberHeartbeatsMatcher)); + assertThat(numReportPayloadCallsSender.get(), is(numberHeartbeatsMatcher)); + } + + /** Tests that after unmonitoring a target, there won't be a timeout triggered. */ + @Test + public void testTargetUnmonitoring() throws Exception { + // this might be too aggressive for Travis, let's see... + long heartbeatTimeout = 50L; + InstanceID instanceID = new InstanceID("foobar"); + InstanceID targetID = new InstanceID("target"); + final int payload = 42; + + final CompletableFuture timeoutFuture = new CompletableFuture<>(); + final TestingHeartbeatListener heartbeatListener = + new TestingHeartbeatListenerBuilder() + .setRetrievePayloadFunction(ignored -> payload) + .setNotifyHeartbeatTimeoutConsumer(timeoutFuture::complete) + .createNewTestingHeartbeatListener(); + + HeartbeatManager heartbeatManager = + new HeartbeatManagerImpl<>( + heartbeatTimeout, + instanceID, + heartbeatListener, + TestingUtils.defaultScheduledExecutor(), + LOG); + + final HeartbeatTarget heartbeatTarget = + new TestingHeartbeatTargetBuilder().createTestingHeartbeatTarget(); + heartbeatManager.monitorTarget(targetID, heartbeatTarget); + + heartbeatManager.unmonitorTarget(targetID); + + try { + timeoutFuture.get(2 * heartbeatTimeout, TimeUnit.MILLISECONDS); + fail("Timeout should time out."); + } catch (TimeoutException ignored) { + // the timeout should not be completed since we unmonitored the target + } + } + + /** Tests that the last heartbeat from an unregistered target equals -1. */ + @Test + public void testLastHeartbeatFromUnregisteredTarget() { + final long heartbeatTimeout = 100L; + final InstanceID instanceID = new InstanceID(); + @SuppressWarnings("unchecked") + final HeartbeatListener heartbeatListener = mock(HeartbeatListener.class); + + HeartbeatManager heartbeatManager = + new HeartbeatManagerImpl<>( + heartbeatTimeout, + instanceID, + heartbeatListener, + mock(ScheduledExecutor.class), + LOG); + + try { + assertEquals(-1L, heartbeatManager.getLastHeartbeatFrom(new InstanceID())); + } finally { + heartbeatManager.stop(); + } + } + + /** Tests that we can correctly retrieve the last heartbeat for registered targets. 
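The two tests above pin down the `getLastHeartbeatFrom` contract: -1 for a target that was never monitored, 0 right after `monitorTarget`, and a wall-clock timestamp once a heartbeat arrives. A sketch of bookkeeping consistent with that contract (illustrative only, using `String` ids for self-containment):

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Bookkeeping sketch: 0L = monitored but no heartbeat yet, -1L = never monitored.
class LastHeartbeatTracker {
    private final Map<String, Long> lastHeartbeats = new ConcurrentHashMap<>();

    void monitorTarget(String targetId) {
        lastHeartbeats.put(targetId, 0L);
    }

    void receiveHeartbeat(String targetId) {
        lastHeartbeats.computeIfPresent(targetId, (id, old) -> System.currentTimeMillis());
    }

    long getLastHeartbeatFrom(String targetId) {
        return lastHeartbeats.getOrDefault(targetId, -1L);
    }
}
```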
*/ + @Test + public void testLastHeartbeatFrom() { + final long heartbeatTimeout = 100L; + final InstanceID instanceID = new InstanceID(); + @SuppressWarnings("unchecked") + final HeartbeatListener heartbeatListener = mock(HeartbeatListener.class); + @SuppressWarnings("unchecked") + final HeartbeatTarget heartbeatTarget = mock(HeartbeatTarget.class); + final InstanceID target = new InstanceID(); + + HeartbeatManager heartbeatManager = + new HeartbeatManagerImpl<>( + heartbeatTimeout, + instanceID, + heartbeatListener, + mock(ScheduledExecutor.class), + LOG); + + try { + heartbeatManager.monitorTarget(target, heartbeatTarget); + + assertEquals(0L, heartbeatManager.getLastHeartbeatFrom(target)); + + final long currentTime = System.currentTimeMillis(); + + heartbeatManager.receiveHeartbeat(target, null); + + assertTrue(heartbeatManager.getLastHeartbeatFrom(target) >= currentTime); + } finally { + heartbeatManager.stop(); + } + } + + /** + * Tests that the heartbeat target {@link InstanceID} is properly passed to the {@link + * HeartbeatListener} by the {@link HeartbeatManagerImpl}. + */ + @Test + public void testHeartbeatManagerTargetPayload() throws Exception { + final long heartbeatTimeout = 100L; + + final InstanceID someTargetId = new InstanceID(); + final InstanceID specialTargetId = new InstanceID(); + + final Map payloads = new HashMap<>(2); + payloads.put(someTargetId, 0); + payloads.put(specialTargetId, 1); + + final CompletableFuture someHeartbeatPayloadFuture = new CompletableFuture<>(); + final TestingHeartbeatTarget someHeartbeatTarget = + new TestingHeartbeatTargetBuilder() + .setReceiveHeartbeatConsumer( + (ignored, payload) -> someHeartbeatPayloadFuture.complete(payload)) + .createTestingHeartbeatTarget(); + + final CompletableFuture specialHeartbeatPayloadFuture = new CompletableFuture<>(); + final TestingHeartbeatTarget specialHeartbeatTarget = + new TestingHeartbeatTargetBuilder() + .setReceiveHeartbeatConsumer( + (ignored, payload) -> + specialHeartbeatPayloadFuture.complete(payload)) + .createTestingHeartbeatTarget(); + + final TestingHeartbeatListener testingHeartbeatListener = + new TestingHeartbeatListenerBuilder() + .setRetrievePayloadFunction(payloads::get) + .createNewTestingHeartbeatListener(); + + HeartbeatManager heartbeatManager = + new HeartbeatManagerImpl<>( + heartbeatTimeout, + new InstanceID(), + testingHeartbeatListener, + TestingUtils.defaultScheduledExecutor(), + LOG); + + try { + heartbeatManager.monitorTarget(someTargetId, someHeartbeatTarget); + heartbeatManager.monitorTarget(specialTargetId, specialHeartbeatTarget); + + heartbeatManager.requestHeartbeat(someTargetId, null); + assertThat(someHeartbeatPayloadFuture.get(), is(payloads.get(someTargetId))); + + heartbeatManager.requestHeartbeat(specialTargetId, null); + assertThat(specialHeartbeatPayloadFuture.get(), is(payloads.get(specialTargetId))); + } finally { + heartbeatManager.stop(); + } + } + + /** + * Tests that the heartbeat target {@link InstanceID} is properly passed to the {@link + * HeartbeatListener} by the {@link HeartbeatManagerSenderImpl}. 
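The `payloads::get` function above is the whole routing trick: because `retrievePayload` receives the target's `InstanceID`, a single listener can answer each monitored target with a different payload. The same idea reduced to plain JDK types (illustrative):

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

class PayloadRouting {
    public static void main(String[] args) {
        Map<String, Integer> payloads = new HashMap<>();
        payloads.put("some-target", 0);
        payloads.put("special-target", 1);

        // The listener's retrievePayload is just a function of the target id.
        Function<String, Integer> retrievePayload = payloads::get;

        System.out.println(retrievePayload.apply("some-target"));    // 0
        System.out.println(retrievePayload.apply("special-target")); // 1
    }
}
```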
+ */ + @Test + public void testHeartbeatManagerSenderTargetPayload() throws Exception { + final long heartbeatTimeout = 100L; + final long heartbeatPeriod = 2000L; + + final ScheduledThreadPoolExecutor scheduledThreadPoolExecutor = + new ScheduledThreadPoolExecutor(1); + + final InstanceID someTargetId = new InstanceID(); + final InstanceID specialTargetId = new InstanceID(); + + final OneShotLatch someTargetReceivedLatch = new OneShotLatch(); + final OneShotLatch specialTargetReceivedLatch = new OneShotLatch(); + + final TargetDependentHeartbeatReceiver someHeartbeatTarget = + new TargetDependentHeartbeatReceiver(someTargetReceivedLatch); + final TargetDependentHeartbeatReceiver specialHeartbeatTarget = + new TargetDependentHeartbeatReceiver(specialTargetReceivedLatch); + + final int defaultResponse = 0; + final int specialResponse = 1; + + HeartbeatManager heartbeatManager = + new HeartbeatManagerSenderImpl<>( + heartbeatPeriod, + heartbeatTimeout, + new InstanceID(), + new TargetDependentHeartbeatSender( + specialTargetId, specialResponse, defaultResponse), + new ScheduledExecutorServiceAdapter(scheduledThreadPoolExecutor), + LOG); + + try { + heartbeatManager.monitorTarget(someTargetId, someHeartbeatTarget); + heartbeatManager.monitorTarget(specialTargetId, specialHeartbeatTarget); + + someTargetReceivedLatch.await(5, TimeUnit.SECONDS); + specialTargetReceivedLatch.await(5, TimeUnit.SECONDS); + + assertEquals(defaultResponse, someHeartbeatTarget.getLastRequestedHeartbeatPayload()); + assertEquals( + specialResponse, specialHeartbeatTarget.getLastRequestedHeartbeatPayload()); + } finally { + heartbeatManager.stop(); + scheduledThreadPoolExecutor.shutdown(); + } + } + + /** Test {@link HeartbeatTarget} that exposes the last received payload. */ + private static class TargetDependentHeartbeatReceiver implements HeartbeatTarget { + + private volatile int lastReceivedHeartbeatPayload = -1; + private volatile int lastRequestedHeartbeatPayload = -1; + + private final OneShotLatch latch; + + public TargetDependentHeartbeatReceiver() { + this(new OneShotLatch()); + } + + public TargetDependentHeartbeatReceiver(OneShotLatch latch) { + this.latch = latch; + } + + @Override + public void receiveHeartbeat(InstanceID heartbeatOrigin, Integer heartbeatPayload) { + this.lastReceivedHeartbeatPayload = heartbeatPayload; + latch.trigger(); + } + + @Override + public void requestHeartbeat(InstanceID requestOrigin, Integer heartbeatPayload) { + this.lastRequestedHeartbeatPayload = heartbeatPayload; + latch.trigger(); + } + + public int getLastReceivedHeartbeatPayload() { + return lastReceivedHeartbeatPayload; + } + + public int getLastRequestedHeartbeatPayload() { + return lastRequestedHeartbeatPayload; + } + } + + /** + * Test {@link HeartbeatListener} that returns different payloads based on the target {@link + * InstanceID}. 
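`TargetDependentHeartbeatReceiver` below pairs a volatile field with a latch so the test thread can block until the asynchronous write has actually happened before asserting on it. Assuming `OneShotLatch` is essentially a `CountDownLatch(1)`, the idiom looks like this (a sketch, not the PR's implementation):

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

class AsyncCapture {
    private volatile int lastPayload = -1;          // written by the heartbeat thread
    private final CountDownLatch received = new CountDownLatch(1);

    void receive(int payload) {                     // called asynchronously
        lastPayload = payload;                      // write happens-before countDown()
        received.countDown();
    }

    int awaitPayload(long timeout, TimeUnit unit) throws InterruptedException {
        if (!received.await(timeout, unit)) {
            throw new AssertionError("no heartbeat received in time");
        }
        return lastPayload;                         // safe to read after await()
    }
}
```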
+     */
+    private static class TargetDependentHeartbeatSender
+            implements HeartbeatListener<Object, Integer> {
+        private final InstanceID specialId;
+        private final int specialResponse;
+        private final int defaultResponse;
+
+        TargetDependentHeartbeatSender(
+                InstanceID specialId, int specialResponse, int defaultResponse) {
+            this.specialId = specialId;
+            this.specialResponse = specialResponse;
+            this.defaultResponse = defaultResponse;
+        }
+
+        @Override
+        public void notifyHeartbeatTimeout(InstanceID instanceID) {}
+
+        @Override
+        public void reportPayload(InstanceID instanceID, Object payload) {}
+
+        @Override
+        public Integer retrievePayload(InstanceID instanceID) {
+            if (instanceID.equals(specialId)) {
+                return specialResponse;
+            } else {
+                return defaultResponse;
+            }
+        }
+    }
+}
diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/heartbeat/TestingHeartbeatListener.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/heartbeat/TestingHeartbeatListener.java
new file mode 100644
index 00000000..2d1e3499
--- /dev/null
+++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/heartbeat/TestingHeartbeatListener.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.heartbeat;
+
+import com.alibaba.flink.shuffle.core.ids.InstanceID;
+
+import java.util.function.BiConsumer;
+import java.util.function.Consumer;
+import java.util.function.Function;
+
+/** {@link HeartbeatListener} implementation for tests. */
+final class TestingHeartbeatListener<I, O> implements HeartbeatListener<I, O> {
+
+    private final Consumer<InstanceID> notifyHeartbeatTimeoutConsumer;
+
+    private final BiConsumer<InstanceID, I> reportPayloadConsumer;
+
+    private final Function<InstanceID, O> retrievePayloadFunction;
+
+    TestingHeartbeatListener(
+            Consumer<InstanceID> notifyHeartbeatTimeoutConsumer,
+            BiConsumer<InstanceID, I> reportPayloadConsumer,
+            Function<InstanceID, O> retrievePayloadFunction) {
+        this.notifyHeartbeatTimeoutConsumer = notifyHeartbeatTimeoutConsumer;
+        this.reportPayloadConsumer = reportPayloadConsumer;
+        this.retrievePayloadFunction = retrievePayloadFunction;
+    }
+
+    @Override
+    public void notifyHeartbeatTimeout(InstanceID instanceID) {
+        notifyHeartbeatTimeoutConsumer.accept(instanceID);
+    }
+
+    @Override
+    public void reportPayload(InstanceID instanceID, I payload) {
+        reportPayloadConsumer.accept(instanceID, payload);
+    }
+
+    @Override
+    public O retrievePayload(InstanceID instanceID) {
+        return retrievePayloadFunction.apply(instanceID);
+    }
+}
diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/heartbeat/TestingHeartbeatListenerBuilder.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/heartbeat/TestingHeartbeatListenerBuilder.java
new file mode 100644
index 00000000..7569601e
--- /dev/null
+++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/heartbeat/TestingHeartbeatListenerBuilder.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.heartbeat;
+
+import com.alibaba.flink.shuffle.core.ids.InstanceID;
+
+import java.util.function.BiConsumer;
+import java.util.function.Consumer;
+import java.util.function.Function;
+
+class TestingHeartbeatListenerBuilder<I, O> {
+    private Consumer<InstanceID> notifyHeartbeatTimeoutConsumer = ignored -> {};
+    private BiConsumer<InstanceID, I> reportPayloadConsumer = (ignoredA, ignoredB) -> {};
+    private Function<InstanceID, O> retrievePayloadFunction = ignored -> null;
+
+    public TestingHeartbeatListenerBuilder<I, O> setNotifyHeartbeatTimeoutConsumer(
+            Consumer<InstanceID> notifyHeartbeatTimeoutConsumer) {
+        this.notifyHeartbeatTimeoutConsumer = notifyHeartbeatTimeoutConsumer;
+        return this;
+    }
+
+    public TestingHeartbeatListenerBuilder<I, O> setReportPayloadConsumer(
+            BiConsumer<InstanceID, I> reportPayloadConsumer) {
+        this.reportPayloadConsumer = reportPayloadConsumer;
+        return this;
+    }
+
+    public TestingHeartbeatListenerBuilder<I, O> setRetrievePayloadFunction(
+            Function<InstanceID, O> retrievePayloadFunction) {
+        this.retrievePayloadFunction = retrievePayloadFunction;
+        return this;
+    }
+
+    public TestingHeartbeatListener<I, O> createNewTestingHeartbeatListener() {
+        return new TestingHeartbeatListener<>(
+                notifyHeartbeatTimeoutConsumer, reportPayloadConsumer, retrievePayloadFunction);
+    }
+}
diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/heartbeat/TestingHeartbeatServices.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/heartbeat/TestingHeartbeatServices.java
new file mode 100644
index 00000000..57c05e77
--- /dev/null
+++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/heartbeat/TestingHeartbeatServices.java
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.heartbeat;
+
+import com.alibaba.flink.shuffle.core.ids.InstanceID;
+import com.alibaba.flink.shuffle.rpc.executor.ScheduledExecutor;
+
+import org.slf4j.Logger;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState;
+
+/**
+ * A {@link HeartbeatServices} implementation for testing purposes. This implementation is able to
+ * trigger a timeout of a specific component manually.
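With the builder above, test call sites stay declarative: every callback defaults to a no-op and only the behavior under test is overridden. A hypothetical composition, assuming the generic signatures shown above:

```java
import com.alibaba.flink.shuffle.core.ids.InstanceID;

class ListenerBuilderUsageSketch {
    // Only the behavior under test is overridden; everything else stays a no-op default.
    static TestingHeartbeatListener<String, Integer> timeoutProbe() {
        return new TestingHeartbeatListenerBuilder<String, Integer>()
                .setNotifyHeartbeatTimeoutConsumer(
                        (InstanceID id) -> System.err.println("timed out: " + id))
                .setRetrievePayloadFunction(ignored -> 42) // payload sent with each heartbeat
                .createNewTestingHeartbeatListener();
    }
}
```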
+ */ +public class TestingHeartbeatServices extends HeartbeatServices { + + private static final long DEFAULT_HEARTBEAT_TIMEOUT = 10000L; + + private static final long DEFAULT_HEARTBEAT_INTERVAL = 1000L; + + @SuppressWarnings("rawtypes") + private final Map> heartbeatManagers = + new ConcurrentHashMap<>(); + + @SuppressWarnings("rawtypes") + private final Map> heartbeatManagerSenders = + new ConcurrentHashMap<>(); + + public TestingHeartbeatServices() { + super(DEFAULT_HEARTBEAT_INTERVAL, DEFAULT_HEARTBEAT_TIMEOUT); + } + + public TestingHeartbeatServices(long heartbeatInterval) { + super(heartbeatInterval, DEFAULT_HEARTBEAT_TIMEOUT); + } + + public TestingHeartbeatServices(long heartbeatInterval, long heartbeatTimeout) { + super(heartbeatInterval, heartbeatTimeout); + } + + @Override + public HeartbeatManager createHeartbeatManager( + InstanceID instanceID, + HeartbeatListener heartbeatListener, + ScheduledExecutor mainThreadExecutor, + Logger log) { + + HeartbeatManagerImpl heartbeatManager = + new HeartbeatManagerImpl<>( + heartbeatTimeout, + instanceID, + heartbeatListener, + mainThreadExecutor, + log, + new TestingHeartbeatMonitorFactory<>()); + + heartbeatManagers.compute( + instanceID, + (ignored, heartbeatManagers) -> { + @SuppressWarnings("rawtypes") + final Collection result; + + if (heartbeatManagers != null) { + result = heartbeatManagers; + } else { + result = new ArrayList<>(); + } + + result.add(heartbeatManager); + return result; + }); + + return heartbeatManager; + } + + @Override + public HeartbeatManager createHeartbeatManagerSender( + InstanceID instanceID, + HeartbeatListener heartbeatListener, + ScheduledExecutor mainThreadExecutor, + Logger log) { + + HeartbeatManagerSenderImpl heartbeatManager = + new HeartbeatManagerSenderImpl<>( + heartbeatInterval, + heartbeatTimeout, + instanceID, + heartbeatListener, + mainThreadExecutor, + log, + new TestingHeartbeatMonitorFactory<>()); + + heartbeatManagerSenders.compute( + instanceID, + (ignored, heartbeatManagers) -> { + final Collection result; + + if (heartbeatManagers != null) { + result = heartbeatManagers; + } else { + result = new ArrayList<>(); + } + + result.add(heartbeatManager); + return result; + }); + + return heartbeatManager; + } + + public void triggerHeartbeatTimeout(InstanceID managerInstanceID, InstanceID targetInstanceID) { + + boolean triggered = false; + Collection heartbeatManagerList = + heartbeatManagers.get(managerInstanceID); + if (heartbeatManagerList != null) { + for (HeartbeatManagerImpl heartbeatManager : heartbeatManagerList) { + final TestingHeartbeatMonitor monitor = + (TestingHeartbeatMonitor) + heartbeatManager.getHeartbeatTargets().get(targetInstanceID); + if (monitor != null) { + monitor.triggerHeartbeatTimeout(); + triggered = true; + } + } + } + + final Collection heartbeatManagerSenderList = + this.heartbeatManagerSenders.get(managerInstanceID); + if (heartbeatManagerSenderList != null) { + for (HeartbeatManagerSenderImpl heartbeatManagerSender : heartbeatManagerSenderList) { + final TestingHeartbeatMonitor monitor = + (TestingHeartbeatMonitor) + heartbeatManagerSender.getHeartbeatTargets().get(targetInstanceID); + if (monitor != null) { + monitor.triggerHeartbeatTimeout(); + triggered = true; + } + } + } + + checkState( + triggered, + "There is no target " + + targetInstanceID + + " monitored under Heartbeat manager " + + managerInstanceID); + } + + /** + * Factory instantiates testing monitor instance. 
+ * + * @param Type of the outgoing heartbeat payload + */ + static class TestingHeartbeatMonitorFactory implements HeartbeatMonitor.Factory { + + @Override + public HeartbeatMonitor createHeartbeatMonitor( + InstanceID instanceID, + HeartbeatTarget heartbeatTarget, + ScheduledExecutor mainThreadExecutor, + HeartbeatListener heartbeatListener, + long heartbeatTimeoutIntervalMs) { + + return new TestingHeartbeatMonitor<>( + instanceID, + heartbeatTarget, + mainThreadExecutor, + heartbeatListener, + heartbeatTimeoutIntervalMs); + } + } + + /** + * A heartbeat monitor for testing which supports triggering timeout manually. + * + * @param Type of the outgoing heartbeat payload + */ + static class TestingHeartbeatMonitor extends HeartbeatMonitorImpl { + + private volatile boolean timeoutTriggered = false; + + TestingHeartbeatMonitor( + InstanceID instanceID, + HeartbeatTarget heartbeatTarget, + ScheduledExecutor scheduledExecutor, + HeartbeatListener heartbeatListener, + long heartbeatTimeoutIntervalMs) { + + super( + instanceID, + heartbeatTarget, + scheduledExecutor, + heartbeatListener, + heartbeatTimeoutIntervalMs); + } + + @Override + public void reportHeartbeat() { + if (!timeoutTriggered) { + super.reportHeartbeat(); + } + // just swallow the heartbeat report + } + + @Override + void resetHeartbeatTimeout(long heartbeatTimeout) { + synchronized (this) { + if (timeoutTriggered) { + super.resetHeartbeatTimeout(0); + } else { + super.resetHeartbeatTimeout(heartbeatTimeout); + } + } + } + + void triggerHeartbeatTimeout() { + synchronized (this) { + timeoutTriggered = true; + resetHeartbeatTimeout(0); + } + } + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/heartbeat/TestingHeartbeatTarget.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/heartbeat/TestingHeartbeatTarget.java new file mode 100644 index 00000000..ef5ca9c4 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/heartbeat/TestingHeartbeatTarget.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.heartbeat; + +import com.alibaba.flink.shuffle.core.ids.InstanceID; + +import java.util.function.BiConsumer; + +/** {@link HeartbeatTarget} implementation for tests. 
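The point of `TestingHeartbeatServices` is determinism: instead of sleeping past a real deadline, a test can force the timeout path explicitly. A hypothetical test fragment wrapping the API shown above (the surrounding component setup is a placeholder):

```java
import com.alibaba.flink.shuffle.core.ids.InstanceID;

class ManualTimeoutSketch {
    // Force the timeout path of whatever monitors 'target' under 'manager', no sleeping.
    static void forceTimeout(
            TestingHeartbeatServices heartbeatServices, InstanceID manager, InstanceID target) {
        // The testing monitor then swallows further heartbeats and fires
        // notifyHeartbeatTimeout almost immediately via resetHeartbeatTimeout(0).
        heartbeatServices.triggerHeartbeatTimeout(manager, target);
    }
}
```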
 */
+class TestingHeartbeatTarget<T> implements HeartbeatTarget<T> {
+    private final BiConsumer<InstanceID, T> receiveHeartbeatConsumer;
+
+    private final BiConsumer<InstanceID, T> requestHeartbeatConsumer;
+
+    TestingHeartbeatTarget(
+            BiConsumer<InstanceID, T> receiveHeartbeatConsumer,
+            BiConsumer<InstanceID, T> requestHeartbeatConsumer) {
+        this.receiveHeartbeatConsumer = receiveHeartbeatConsumer;
+        this.requestHeartbeatConsumer = requestHeartbeatConsumer;
+    }
+
+    @Override
+    public void receiveHeartbeat(InstanceID heartbeatOrigin, T heartbeatPayload) {
+        receiveHeartbeatConsumer.accept(heartbeatOrigin, heartbeatPayload);
+    }
+
+    @Override
+    public void requestHeartbeat(InstanceID requestOrigin, T heartbeatPayload) {
+        requestHeartbeatConsumer.accept(requestOrigin, heartbeatPayload);
+    }
+}
diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/heartbeat/TestingHeartbeatTargetBuilder.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/heartbeat/TestingHeartbeatTargetBuilder.java
new file mode 100644
index 00000000..bd0ff595
--- /dev/null
+++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/heartbeat/TestingHeartbeatTargetBuilder.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.heartbeat;
+
+import com.alibaba.flink.shuffle.core.ids.InstanceID;
+
+import java.util.function.BiConsumer;
+
+class TestingHeartbeatTargetBuilder<T> {
+    private BiConsumer<InstanceID, T> receiveHeartbeatConsumer = (ignoredA, ignoredB) -> {};
+    private BiConsumer<InstanceID, T> requestHeartbeatConsumer = (ignoredA, ignoredB) -> {};
+
+    public TestingHeartbeatTargetBuilder<T> setReceiveHeartbeatConsumer(
+            BiConsumer<InstanceID, T> receiveHeartbeatConsumer) {
+        this.receiveHeartbeatConsumer = receiveHeartbeatConsumer;
+        return this;
+    }
+
+    public TestingHeartbeatTargetBuilder<T> setRequestHeartbeatConsumer(
+            BiConsumer<InstanceID, T> requestHeartbeatConsumer) {
+        this.requestHeartbeatConsumer = requestHeartbeatConsumer;
+        return this;
+    }
+
+    public TestingHeartbeatTarget<T> createTestingHeartbeatTarget() {
+        return new TestingHeartbeatTarget<>(receiveHeartbeatConsumer, requestHeartbeatConsumer);
+    }
+}
diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/highavailability/DirectlyFailingFatalErrorHandler.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/highavailability/DirectlyFailingFatalErrorHandler.java
new file mode 100644
index 00000000..d0cb1011
--- /dev/null
+++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/highavailability/DirectlyFailingFatalErrorHandler.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.highavailability;
+
+import com.alibaba.flink.shuffle.common.handler.FatalErrorHandler;
+
+/** Testing {@link FatalErrorHandler} implementation that fails directly by rethrowing. */
+public enum DirectlyFailingFatalErrorHandler implements FatalErrorHandler {
+    INSTANCE;
+
+    @Override
+    public void onFatalError(Throwable exception) {
+        throw new RuntimeException("Could not handle the fatal error, failing", exception);
+    }
+}
diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/highavailability/HighAvailabilityServicesUtilsTest.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/highavailability/HighAvailabilityServicesUtilsTest.java
new file mode 100644
index 00000000..a74df7c2
--- /dev/null
+++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/highavailability/HighAvailabilityServicesUtilsTest.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.highavailability;
+
+import com.alibaba.flink.shuffle.common.config.Configuration;
+import com.alibaba.flink.shuffle.core.config.HighAvailabilityOptions;
+import com.alibaba.flink.shuffle.core.utils.TestLogger;
+
+import org.junit.ClassRule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.concurrent.Executor;
+
+import static org.junit.Assert.assertSame;
+
+/** Tests for the {@link HaServiceUtils} class.
*/ +public class HighAvailabilityServicesUtilsTest extends TestLogger { + + @ClassRule public static TemporaryFolder temporaryFolder = new TemporaryFolder(); + + @Test + public void testCreateCustomHAServices() throws Exception { + Configuration config = new Configuration(); + + HaServices haServices = new TestingHighAvailabilityServices(); + TestHAFactory.haServices = haServices; + + Executor executor = Runnable::run; + + config.setString(HighAvailabilityOptions.HA_MODE, TestHAFactory.class.getName()); + + // when + HaServices actualHaServices = + HaServiceUtils.createAvailableOrEmbeddedServices(config, executor); + + // then + assertSame(haServices, actualHaServices); + + // when + actualHaServices = HaServiceUtils.createHAServices(config); + // then + assertSame(haServices, actualHaServices); + } + + @Test(expected = Exception.class) + public void testCustomHAServicesFactoryNotDefined() throws Exception { + Configuration config = new Configuration(); + + Executor executor = Runnable::run; + + config.setString( + HighAvailabilityOptions.HA_MODE, HaMode.FACTORY_CLASS.name().toLowerCase()); + + // expect + HaServiceUtils.createAvailableOrEmbeddedServices(config, executor); + } + + /** Testing class which needs to be public in order to be instantiatable. */ + public static class TestHAFactory implements HaServicesFactory { + + static HaServices haServices; + + @Override + public HaServices createHAServices(Configuration configuration) { + return haServices; + } + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/highavailability/TestingHaServices.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/highavailability/TestingHaServices.java new file mode 100644 index 00000000..5b882bb0 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/highavailability/TestingHaServices.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** Testing {@link HaServices} implementation. 
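testCreateCustomHAServices above relies on HA_MODE carrying a factory class name that gets instantiated reflectively. A sketch of that dispatch, under the assumption that HaServiceUtils does roughly the following (the real implementation, including its `getString` accessor, may differ):

```java
import com.alibaba.flink.shuffle.common.config.Configuration;
import com.alibaba.flink.shuffle.core.config.HighAvailabilityOptions;

class FactoryDispatchSketch {
    // Hypothetical: load an HaServicesFactory by the class name stored in HA_MODE.
    static HaServices createFromFactoryName(Configuration config) throws Exception {
        String haMode = config.getString(HighAvailabilityOptions.HA_MODE);
        Class<?> factoryClass = Class.forName(haMode); // e.g. TestHAFactory.class.getName()
        HaServicesFactory factory =
                (HaServicesFactory) factoryClass.getDeclaredConstructor().newInstance();
        return factory.createHAServices(config);
    }
}
```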
*/ +public class TestingHaServices implements HaServices { + + private LeaderRetrievalService shuffleManagerLeaderRetrieveService; + + private LeaderElectionService shuffleManagerLeaderElectionService; + + // ------------------------------------------------------------------------ + // Setters for mock / testing implementations + // ------------------------------------------------------------------------ + + public void setShuffleManagerLeaderRetrieveService( + LeaderRetrievalService shuffleManagerLeaderRetrieveService) { + this.shuffleManagerLeaderRetrieveService = shuffleManagerLeaderRetrieveService; + } + + public void setShuffleManagerLeaderElectionService( + LeaderElectionService shuffleManagerLeaderElectionService) { + this.shuffleManagerLeaderElectionService = shuffleManagerLeaderElectionService; + } + + // ------------------------------------------------------------------------ + // HA Services Methods + // ------------------------------------------------------------------------ + + @Override + public LeaderRetrievalService createLeaderRetrievalService(LeaderReceptor receptor) { + checkNotNull(shuffleManagerLeaderRetrieveService); + + return shuffleManagerLeaderRetrieveService; + } + + @Override + public LeaderElectionService createLeaderElectionService() { + checkNotNull(shuffleManagerLeaderElectionService); + + return shuffleManagerLeaderElectionService; + } + + @Override + public void closeAndCleanupAllData() throws Exception {} + + @Override + public void close() throws Exception {} +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/highavailability/TestingHighAvailabilityServices.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/highavailability/TestingHighAvailabilityServices.java new file mode 100644 index 00000000..2b6f8156 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/highavailability/TestingHighAvailabilityServices.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability; + +/** + * A variant of the HighAvailabilityServices for testing. Each individual service can be set to an + * arbitrary implementation, such as a mock or default service. 
+ */ +public class TestingHighAvailabilityServices implements HaServices { + + private LeaderRetrievalService shuffleManagerLeaderRetrievalService; + + private LeaderElectionService shuffleManagerLeaderElectionService; + + public void setShuffleManagerLeaderRetrievalService( + LeaderRetrievalService shuffleManagerLeaderRetrievalService) { + this.shuffleManagerLeaderRetrievalService = shuffleManagerLeaderRetrievalService; + } + + public void setShuffleManagerLeaderElectionService( + LeaderElectionService shuffleManagerLeaderElectionService) { + this.shuffleManagerLeaderElectionService = shuffleManagerLeaderElectionService; + } + + // ------------------------------------------------------------------------ + // Shutdown + // ------------------------------------------------------------------------ + + @Override + public void close() throws Exception { + // nothing to do + } + + @Override + public LeaderRetrievalService createLeaderRetrievalService(LeaderReceptor receptor) { + return shuffleManagerLeaderRetrievalService; + } + + @Override + public LeaderElectionService createLeaderElectionService() { + return shuffleManagerLeaderElectionService; + } + + @Override + public void closeAndCleanupAllData() throws Exception { + // nothing to do + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/highavailability/ZooKeeperUtilsTest.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/highavailability/ZooKeeperUtilsTest.java new file mode 100644 index 00000000..a91eb050 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/highavailability/ZooKeeperUtilsTest.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.highavailability; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper.ZooKeeperUtils; +import com.alibaba.flink.shuffle.core.config.HighAvailabilityOptions; +import com.alibaba.flink.shuffle.core.utils.TestLogger; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** Tests for {@link ZooKeeperUtils}. */ +public class ZooKeeperUtilsTest extends TestLogger { + + @Test + public void testZooKeeperEnsembleConnectStringConfiguration() throws Exception { + // ZooKeeper does not like whitespace in the quorum connect String. 
+ String actual, expected; + Configuration conf = new Configuration(); + + { + expected = "localhost:2891"; + + setQuorum(conf, expected); + actual = ZooKeeperUtils.getZooKeeperEnsemble(conf); + assertEquals(expected, actual); + + setQuorum(conf, " localhost:2891 "); // with leading and trailing whitespace + actual = ZooKeeperUtils.getZooKeeperEnsemble(conf); + assertEquals(expected, actual); + + setQuorum(conf, "localhost :2891"); // whitespace after port + actual = ZooKeeperUtils.getZooKeeperEnsemble(conf); + assertEquals(expected, actual); + } + + { + expected = "localhost:2891,localhost:2891"; + + setQuorum(conf, "localhost:2891,localhost:2891"); + actual = ZooKeeperUtils.getZooKeeperEnsemble(conf); + assertEquals(expected, actual); + + setQuorum(conf, "localhost:2891, localhost:2891"); + actual = ZooKeeperUtils.getZooKeeperEnsemble(conf); + assertEquals(expected, actual); + + setQuorum(conf, "localhost :2891, localhost:2891"); + actual = ZooKeeperUtils.getZooKeeperEnsemble(conf); + assertEquals(expected, actual); + + setQuorum(conf, " localhost:2891, localhost:2891 "); + actual = ZooKeeperUtils.getZooKeeperEnsemble(conf); + assertEquals(expected, actual); + } + } + + private Configuration setQuorum(Configuration conf, String quorum) { + conf.setString(HighAvailabilityOptions.HA_ZOOKEEPER_QUORUM, quorum); + return conf; + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/LeaderElectionTest.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/LeaderElectionTest.java new file mode 100644 index 00000000..bd9f412e --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/LeaderElectionTest.java @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
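Every input accepted by the test above normalizes to the same string once all whitespace is removed, so an implementation consistent with those expectations can be as small as one `replaceAll` (a sketch; the actual ZooKeeperUtils code may differ):

```java
class EnsembleNormalizer {
    // Strip all whitespace so "localhost :2891, localhost:2891" becomes
    // "localhost:2891,localhost:2891", matching every expectation in the test above.
    static String normalize(String quorum) {
        return quorum.replaceAll("\\s+", "");
    }

    public static void main(String[] args) {
        System.out.println(normalize(" localhost:2891, localhost:2891 "));
        // -> localhost:2891,localhost:2891
    }
}
```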
+ */ + +package com.alibaba.flink.shuffle.coordinator.leaderelection; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderContender; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionService; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.highavailability.embeded.EmbeddedLeaderService; +import com.alibaba.flink.shuffle.coordinator.highavailability.standalone.StandaloneLeaderElectionService; +import com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper.ZooKeeperHaServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper.ZooKeeperUtils; +import com.alibaba.flink.shuffle.core.config.HighAvailabilityOptions; +import com.alibaba.flink.shuffle.core.utils.TestLogger; + +import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFramework; + +import org.apache.curator.test.TestingServer; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.Collection; +import java.util.UUID; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; + +/** Tests for leader election. */ +@RunWith(Parameterized.class) +public class LeaderElectionTest extends TestLogger { + + enum LeaderElectionType { + ZooKeeper, + Embedded, + Standalone + } + + @Parameterized.Parameters(name = "Leader election: {0}") + public static Collection parameters() { + return Arrays.asList(LeaderElectionType.values()); + } + + private final ServiceClass serviceClass; + + public LeaderElectionTest(LeaderElectionType leaderElectionType) { + switch (leaderElectionType) { + case ZooKeeper: + serviceClass = new ZooKeeperServiceClass(); + break; + case Embedded: + serviceClass = new EmbeddedServiceClass(); + break; + case Standalone: + serviceClass = new StandaloneServiceClass(); + break; + default: + throw new IllegalArgumentException( + String.format("Unknown leader election type: %s.", leaderElectionType)); + } + } + + @Before + public void setup() throws Exception { + serviceClass.setup(); + } + + @After + public void teardown() throws Exception { + serviceClass.teardown(); + } + + @Test + public void testHasLeadership() throws Exception { + final LeaderElectionService leaderElectionService = + serviceClass.createLeaderElectionService(); + final ManualLeaderContender manualLeaderContender = new ManualLeaderContender(); + + try { + assertThat(leaderElectionService.hasLeadership(UUID.randomUUID()), is(false)); + + leaderElectionService.start(manualLeaderContender); + + final UUID leaderSessionId = manualLeaderContender.waitForLeaderSessionId(); + + assertThat(leaderElectionService.hasLeadership(leaderSessionId), is(true)); + assertThat(leaderElectionService.hasLeadership(UUID.randomUUID()), is(false)); + + leaderElectionService.confirmLeadership( + new LeaderInformation(leaderSessionId, "foobar")); + + assertThat(leaderElectionService.hasLeadership(leaderSessionId), is(true)); + + leaderElectionService.stop(); + + assertThat(leaderElectionService.hasLeadership(leaderSessionId), is(false)); + } finally { + manualLeaderContender.rethrowError(); + 
} + } + + private static final class ManualLeaderContender implements LeaderContender { + + private static final UUID NULL_LEADER_SESSION_ID = new UUID(0L, 0L); + + private final ArrayBlockingQueue leaderSessionIds = new ArrayBlockingQueue<>(10); + + private volatile Throwable throwable; + + @Override + public void grantLeadership(UUID leaderSessionID) { + leaderSessionIds.offer(leaderSessionID); + } + + @Override + public void revokeLeadership() { + leaderSessionIds.offer(NULL_LEADER_SESSION_ID); + } + + @Override + public String getDescription() { + return "foobar"; + } + + @Override + public void handleError(Throwable throwable) { + this.throwable = throwable; + } + + void rethrowError() throws Exception { + if (throwable != null) { + ExceptionUtils.rethrowException(throwable); + } + } + + UUID waitForLeaderSessionId() throws InterruptedException { + return leaderSessionIds.take(); + } + } + + private interface ServiceClass { + void setup() throws Exception; + + void teardown() throws Exception; + + LeaderElectionService createLeaderElectionService() throws Exception; + } + + private static final class ZooKeeperServiceClass implements ServiceClass { + + private TestingServer testingServer; + + private CuratorFramework client; + + private Configuration configuration; + + @Override + public void setup() throws Exception { + try { + testingServer = new TestingServer(); + } catch (Exception e) { + throw new RuntimeException("Could not start ZooKeeper testing cluster.", e); + } + + configuration = new Configuration(); + + configuration.setString( + HighAvailabilityOptions.HA_ZOOKEEPER_QUORUM, testingServer.getConnectString()); + configuration.setString(HighAvailabilityOptions.HA_MODE, "zookeeper"); + + client = ZooKeeperUtils.startCuratorFramework(configuration); + } + + @Override + public void teardown() throws Exception { + if (client != null) { + client.close(); + client = null; + } + + if (testingServer != null) { + testingServer.stop(); + testingServer = null; + } + } + + @Override + public LeaderElectionService createLeaderElectionService() throws Exception { + return ZooKeeperUtils.createLeaderElectionService( + client, + configuration, + ZooKeeperHaServices.SHUFFLE_MANAGER_LEADER_LATCH_PATH, + ZooKeeperHaServices.SHUFFLE_MANAGER_LEADER_RETRIEVAL_PATH); + } + } + + private static final class EmbeddedServiceClass implements ServiceClass { + + private EmbeddedLeaderService embeddedLeaderService; + + private ExecutorService executor; + + @Override + public void setup() { + executor = Executors.newSingleThreadExecutor(); + embeddedLeaderService = new EmbeddedLeaderService(executor); + } + + @Override + public void teardown() { + if (embeddedLeaderService != null) { + embeddedLeaderService.shutdown(); + embeddedLeaderService = null; + } + + if (executor != null) { + executor.shutdown(); + executor = null; + } + } + + @Override + public LeaderElectionService createLeaderElectionService() throws Exception { + return embeddedLeaderService.createLeaderElectionService(); + } + } + + private static final class StandaloneServiceClass implements ServiceClass { + + @Override + public void setup() throws Exception { + // noop + } + + @Override + public void teardown() throws Exception { + // noop + } + + @Override + public LeaderElectionService createLeaderElectionService() throws Exception { + return new StandaloneLeaderElectionService(); + } + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/StandaloneLeaderElectionTest.java 
b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/StandaloneLeaderElectionTest.java new file mode 100644 index 00000000..9ceed074 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/StandaloneLeaderElectionTest.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.leaderelection; + +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.highavailability.standalone.StandaloneLeaderElectionService; +import com.alibaba.flink.shuffle.coordinator.highavailability.standalone.StandaloneLeaderRetrievalService; +import com.alibaba.flink.shuffle.core.utils.TestLogger; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** Tests for {@link StandaloneLeaderElectionService}. */ +public class StandaloneLeaderElectionTest extends TestLogger { + + private static final String TEST_URL = "akka://users/shufflemanager"; + + /** + * Tests that the standalone leader election and retrieval service return the same leader URL. 
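+     *
+     * <p>The standalone variant involves no real election: the contender passed to {@code start}
+     * is granted leadership with {@link HaServices#DEFAULT_LEADER_ID}, and the retrieval service
+     * simply republishes the fixed {@link LeaderInformation} it was constructed with, so both
+     * sides are expected to agree on {@code TEST_URL}.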
+ */ + @Test + public void testStandaloneLeaderElectionRetrieval() throws Exception { + StandaloneLeaderElectionService leaderElectionService = + new StandaloneLeaderElectionService(); + StandaloneLeaderRetrievalService leaderRetrievalService = + new StandaloneLeaderRetrievalService( + new LeaderInformation(HaServices.DEFAULT_LEADER_ID, TEST_URL)); + TestingContender contender = new TestingContender(TEST_URL, leaderElectionService); + TestingListener testingListener = new TestingListener(); + + try { + leaderElectionService.start(contender); + leaderRetrievalService.start(testingListener); + + contender.waitForLeader(1000L); + + assertTrue(contender.isLeader()); + assertEquals(HaServices.DEFAULT_LEADER_ID, contender.getLeaderSessionID()); + + testingListener.waitForNewLeader(1000L); + + assertEquals(TEST_URL, testingListener.getAddress()); + assertEquals(HaServices.DEFAULT_LEADER_ID, testingListener.getLeaderSessionID()); + } finally { + leaderElectionService.stop(); + leaderRetrievalService.stop(); + } + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/TestingContender.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/TestingContender.java new file mode 100644 index 00000000..177c8446 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/TestingContender.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.leaderelection; + +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderContender; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionService; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.UUID; + +/** + * {@link LeaderContender} implementation which provides some convenience functions for testing + * purposes. 
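+ *
+ * <p>On {@code grantLeadership} the contender confirms the leadership at the associated election
+ * service and offers the resulting {@link LeaderInformation} to the event queue inherited from
+ * {@link TestingLeaderBase}, so a test can simply block until the grant has been processed.
+ * A minimal usage sketch (simplified from the tests in this package):
+ *
+ * <pre>{@code
+ * TestingContender contender = new TestingContender(address, electionService);
+ * electionService.start(contender);
+ * contender.waitForLeader(1000L); // throws if leadership is not granted within 1s
+ * }</pre>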
+ */ +public class TestingContender extends TestingLeaderBase implements LeaderContender { + private static final Logger LOG = LoggerFactory.getLogger(TestingContender.class); + + private final String address; + private final LeaderElectionService leaderElectionService; + + private UUID leaderSessionID = null; + + public TestingContender( + final String address, final LeaderElectionService leaderElectionService) { + this.address = address; + this.leaderElectionService = leaderElectionService; + } + + @Override + public void grantLeadership(UUID leaderSessionID) { + LOG.debug("Was granted leadership with session ID {}.", leaderSessionID); + + this.leaderSessionID = leaderSessionID; + + leaderElectionService.confirmLeadership(new LeaderInformation(leaderSessionID, address)); + + leaderEventQueue.offer(new LeaderInformation(leaderSessionID, address)); + } + + @Override + public void revokeLeadership() { + LOG.debug("Was revoked leadership. Old session ID {}.", leaderSessionID); + + leaderSessionID = null; + leaderEventQueue.offer(LeaderInformation.empty()); + } + + @Override + public String getDescription() { + return address; + } + + @Override + public void handleError(Throwable throwable) { + super.handleError(throwable); + } + + public UUID getLeaderSessionID() { + return leaderSessionID; + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/TestingLeaderBase.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/TestingLeaderBase.java new file mode 100644 index 00000000..c2432d2a --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/TestingLeaderBase.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.leaderelection; + +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderContender; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionEventHandler; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.utils.CommonTestUtils; + +import javax.annotation.Nullable; + +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +/** + * Base class which provides some convenience functions for testing purposes of {@link + * LeaderContender} and {@link LeaderElectionEventHandler}. 
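+ *
+ * <p>Subclasses offer {@link LeaderInformation} events into {@code leaderEventQueue}; the wait
+ * methods below poll that queue until a matching event arrives (non-empty information for a
+ * grant, empty information for a revocation) or the given timeout expires.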
+ */
+public class TestingLeaderBase {
+    // The queues will be offered by subclasses
+    protected final BlockingQueue<LeaderInformation> leaderEventQueue =
+            new LinkedBlockingQueue<>();
+    private final BlockingQueue<Throwable> errorQueue = new LinkedBlockingQueue<>();
+
+    private boolean isLeader = false;
+    private Throwable error;
+
+    public void waitForLeader(long timeout) throws Exception {
+        throwExceptionIfNotNull();
+
+        final String errorMsg = "Contender was not elected as the leader within " + timeout + "ms";
+        CommonTestUtils.waitUntilCondition(
+                () -> {
+                    final LeaderInformation leader =
+                            leaderEventQueue.poll(timeout, TimeUnit.MILLISECONDS);
+                    return leader != null && !leader.isEmpty();
+                },
+                timeout,
+                errorMsg);
+
+        isLeader = true;
+    }
+
+    public void waitForRevokeLeader(long timeout) throws Exception {
+        throwExceptionIfNotNull();
+
+        final String errorMsg = "Contender was not revoked within " + timeout + "ms";
+        CommonTestUtils.waitUntilCondition(
+                () -> {
+                    final LeaderInformation leader =
+                            leaderEventQueue.poll(timeout, TimeUnit.MILLISECONDS);
+                    return leader != null && leader.isEmpty();
+                },
+                timeout,
+                errorMsg);
+
+        isLeader = false;
+    }
+
+    public void waitForError(long timeout) throws Exception {
+        final String errorMsg = "Contender did not see an exception within " + timeout + "ms";
+        CommonTestUtils.waitUntilCondition(
+                () -> {
+                    error = errorQueue.poll(timeout, TimeUnit.MILLISECONDS);
+                    return error != null;
+                },
+                timeout,
+                errorMsg);
+    }
+
+    public void handleError(Throwable ex) {
+        errorQueue.offer(ex);
+    }
+
+    /**
+     * Please use {@link #waitForError} before getting the error.
+     *
+     * @return the error that has been handled.
+     */
+    @Nullable
+    public Throwable getError() {
+        return this.error;
+    }
+
+    public boolean isLeader() {
+        return isLeader;
+    }
+
+    private void throwExceptionIfNotNull() throws Exception {
+        if (error != null) {
+            ExceptionUtils.rethrowException(error);
+        }
+    }
+}
diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/TestingLeaderElectionEventHandler.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/TestingLeaderElectionEventHandler.java
new file mode 100644
index 00000000..f15eb183
--- /dev/null
+++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/TestingLeaderElectionEventHandler.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.leaderelection;
+
+import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionDriver;
+import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionEventHandler;
+import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation;
+import com.alibaba.flink.shuffle.core.utils.OneShotLatch;
+
+import javax.annotation.Nullable;
+
+import java.util.function.Consumer;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState;
+
+/**
+ * {@link LeaderElectionEventHandler} implementation which provides some convenience functions for
+ * testing purposes.
+ */
+public class TestingLeaderElectionEventHandler extends TestingLeaderBase
+        implements LeaderElectionEventHandler {
+
+    private final LeaderInformation leaderInformation;
+
+    private final OneShotLatch initializationLatch;
+
+    @Nullable private LeaderElectionDriver initializedLeaderElectionDriver = null;
+
+    private LeaderInformation confirmedLeaderInformation = LeaderInformation.empty();
+
+    public TestingLeaderElectionEventHandler(LeaderInformation leaderInformation) {
+        this.leaderInformation = leaderInformation;
+        this.initializationLatch = new OneShotLatch();
+    }
+
+    public void init(LeaderElectionDriver leaderElectionDriver) {
+        checkState(initializedLeaderElectionDriver == null);
+        this.initializedLeaderElectionDriver = leaderElectionDriver;
+        initializationLatch.trigger();
+    }
+
+    @Override
+    public void onGrantLeadership() {
+        waitForInitialization(
+                leaderElectionDriver -> {
+                    confirmedLeaderInformation = leaderInformation;
+                    leaderElectionDriver.writeLeaderInformation(confirmedLeaderInformation);
+                    leaderEventQueue.offer(confirmedLeaderInformation);
+                });
+    }
+
+    @Override
+    public void onRevokeLeadership() {
+        waitForInitialization(
+                (leaderElectionDriver) -> {
+                    confirmedLeaderInformation = LeaderInformation.empty();
+                    leaderElectionDriver.writeLeaderInformation(confirmedLeaderInformation);
+                    leaderEventQueue.offer(confirmedLeaderInformation);
+                });
+    }
+
+    @Override
+    public void onLeaderInformationChange(LeaderInformation leaderInfo) {
+        waitForInitialization(
+                leaderElectionDriver -> {
+                    if (confirmedLeaderInformation.getLeaderSessionID() != null
+                            && !this.confirmedLeaderInformation.equals(leaderInfo)) {
+                        leaderElectionDriver.writeLeaderInformation(confirmedLeaderInformation);
+                    }
+                });
+    }
+
+    private void waitForInitialization(Consumer<LeaderElectionDriver> operation) {
+        try {
+            initializationLatch.await();
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+        }
+
+        checkState(initializedLeaderElectionDriver != null);
+        operation.accept(initializedLeaderElectionDriver);
+    }
+
+    public LeaderInformation getConfirmedLeaderInformation() {
+        return confirmedLeaderInformation;
+    }
+}
diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/TestingLeaderElectionService.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/TestingLeaderElectionService.java
new file mode 100644
index 00000000..fc59e7e0
--- /dev/null
+++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/TestingLeaderElectionService.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.leaderelection;
+
+import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderContender;
+import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionService;
+import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation;
+import com.alibaba.flink.shuffle.coordinator.utils.LeaderConnectionInfo;
+
+import javax.annotation.Nonnull;
+
+import java.util.UUID;
+import java.util.concurrent.CompletableFuture;
+
+/**
+ * Test {@link LeaderElectionService} implementation which directly forwards isLeader and notLeader
+ * calls to the contender.
+ */
+public class TestingLeaderElectionService implements LeaderElectionService {
+
+    private LeaderContender contender = null;
+    private boolean hasLeadership = false;
+    private CompletableFuture<LeaderConnectionInfo> confirmationFuture = null;
+    private CompletableFuture<Void> startFuture = new CompletableFuture<>();
+    private UUID issuedLeaderSessionId = null;
+
+    /**
+     * Gets a future that completes when leadership is confirmed.
+     *
+     * <p>Note: the future is created upon calling {@link #isLeader(UUID)}.
+     */
+    public synchronized CompletableFuture<LeaderConnectionInfo> getConfirmationFuture() {
+        return confirmationFuture;
+    }
+
+    @Override
+    public synchronized void start(LeaderContender contender) {
+        assert (!getStartFuture().isDone());
+
+        this.contender = contender;
+
+        if (hasLeadership) {
+            contender.grantLeadership(issuedLeaderSessionId);
+        }
+
+        startFuture.complete(null);
+    }
+
+    @Override
+    public synchronized void stop() throws Exception {
+        contender = null;
+        hasLeadership = false;
+        issuedLeaderSessionId = null;
+        startFuture.cancel(false);
+        startFuture = new CompletableFuture<>();
+    }
+
+    @Override
+    public synchronized void confirmLeadership(LeaderInformation leaderInfo) {
+        if (confirmationFuture != null) {
+            confirmationFuture.complete(
+                    new LeaderConnectionInfo(
+                            leaderInfo.getLeaderSessionID(), leaderInfo.getLeaderAddress()));
+        }
+    }
+
+    @Override
+    public synchronized boolean hasLeadership(@Nonnull UUID leaderSessionId) {
+        return hasLeadership && leaderSessionId.equals(issuedLeaderSessionId);
+    }
+
+    public synchronized CompletableFuture<UUID> isLeader(UUID leaderSessionID) {
+        if (confirmationFuture != null) {
+            confirmationFuture.cancel(false);
+        }
+        confirmationFuture = new CompletableFuture<>();
+        hasLeadership = true;
+        issuedLeaderSessionId = leaderSessionID;
+
+        if (contender != null) {
+            contender.grantLeadership(leaderSessionID);
+        }
+
+        return confirmationFuture.thenApply(LeaderConnectionInfo::getLeaderSessionId);
+    }
+
+    public synchronized void notLeader() {
+        hasLeadership = false;
+
+        if (contender != null) {
+            contender.revokeLeadership();
+        }
+    }
+
+    public synchronized String getAddress() {
+        if (confirmationFuture.isDone()) {
+            return confirmationFuture.join().getAddress();
+        } else {
+            throw new IllegalStateException("TestingLeaderElectionService has not been started.");
+        }
+    }
+
+    /**
+     * Returns the start future indicating whether this leader election service has been started or
+     * not.
+     *
+     * @return Future which is completed once this service has been started
+     */
+    public synchronized CompletableFuture<Void> getStartFuture() {
+        return startFuture;
+    }
+}
diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/TestingListener.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/TestingListener.java
new file mode 100644
index 00000000..ba19747a
--- /dev/null
+++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/TestingListener.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package com.alibaba.flink.shuffle.coordinator.leaderelection; + +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalListener; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Test {@link LeaderRetrievalListener} implementation which offers some convenience functions for + * testing purposes. + */ +public class TestingListener extends TestingRetrievalBase implements LeaderRetrievalListener { + + private static final Logger LOG = LoggerFactory.getLogger(TestingListener.class); + + @Override + public void notifyLeaderAddress(LeaderInformation leaderInfo) { + LOG.info("Notified about new leader {}.", leaderInfo); + offerToLeaderQueue(leaderInfo); + } + + @Override + public void handleError(Exception exception) { + super.handleError(exception); + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/TestingRetrievalBase.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/TestingRetrievalBase.java new file mode 100644 index 00000000..afa3fb43 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/TestingRetrievalBase.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.leaderelection; + +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalListener; +import com.alibaba.flink.shuffle.coordinator.utils.CommonTestUtils; + +import javax.annotation.Nullable; + +import java.util.UUID; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +/** + * Base class which provides some convenience functions for testing purposes of {@link + * LeaderRetrievalListener}. 
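+ *
+ * <p>Note that {@code waitForNewLeader} only completes once a non-empty {@link LeaderInformation}
+ * with an address different from the previously observed one is received, so re-announcing the
+ * same leader does not count as a new leader event.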
+ */
+public class TestingRetrievalBase {
+
+    private final BlockingQueue<LeaderInformation> leaderEventQueue =
+            new LinkedBlockingQueue<>();
+    private final BlockingQueue<Throwable> errorQueue = new LinkedBlockingQueue<>();
+
+    private LeaderInformation leader = LeaderInformation.empty();
+    private String oldAddress;
+    private Throwable error;
+
+    public String waitForNewLeader(long timeout) throws Exception {
+        throwExceptionIfNotNull();
+
+        final String errorMsg =
+                "Listener was not notified about a new leader within " + timeout + "ms";
+        CommonTestUtils.waitUntilCondition(
+                () -> {
+                    leader = leaderEventQueue.poll(timeout, TimeUnit.MILLISECONDS);
+                    return leader != null
+                            && !leader.isEmpty()
+                            && !leader.getLeaderAddress().equals(oldAddress);
+                },
+                timeout,
+                errorMsg);
+
+        oldAddress = leader.getLeaderAddress();
+
+        return leader.getLeaderAddress();
+    }
+
+    public void waitForEmptyLeaderInformation(long timeout) throws Exception {
+        throwExceptionIfNotNull();
+
+        final String errorMsg =
+                "Listener was not notified about an empty leader within " + timeout + "ms";
+        CommonTestUtils.waitUntilCondition(
+                () -> {
+                    leader = leaderEventQueue.poll(timeout, TimeUnit.MILLISECONDS);
+                    return leader != null && leader.isEmpty();
+                },
+                timeout,
+                errorMsg);
+
+        oldAddress = null;
+    }
+
+    public void waitForError(long timeout) throws Exception {
+        final String errorMsg = "Listener did not see an exception within " + timeout + "ms";
+        CommonTestUtils.waitUntilCondition(
+                () -> {
+                    error = errorQueue.poll(timeout, TimeUnit.MILLISECONDS);
+                    return error != null;
+                },
+                timeout,
+                errorMsg);
+    }
+
+    public void handleError(Throwable ex) {
+        errorQueue.offer(ex);
+    }
+
+    public LeaderInformation getLeader() {
+        return leader;
+    }
+
+    public String getAddress() {
+        return leader.getLeaderAddress();
+    }
+
+    public UUID getLeaderSessionID() {
+        return leader.getLeaderSessionID();
+    }
+
+    public void offerToLeaderQueue(LeaderInformation leaderInfo) {
+        leaderEventQueue.offer(leaderInfo);
+        this.leader = leaderInfo;
+    }
+
+    public int getLeaderEventQueueSize() {
+        return leaderEventQueue.size();
+    }
+
+    /**
+     * Please use {@link #waitForError} before getting the error.
+     *
+     * @return the error that has been handled.
+     */
+    @Nullable
+    public Throwable getError() {
+        return this.error;
+    }
+
+    private void throwExceptionIfNotNull() throws Exception {
+        if (error != null) {
+            ExceptionUtils.rethrowException(error);
+        }
+    }
+}
diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/ZooKeeperLeaderElectionConnectionHandlingTest.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/ZooKeeperLeaderElectionConnectionHandlingTest.java
new file mode 100644
index 00000000..ca576ef9
--- /dev/null
+++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/ZooKeeperLeaderElectionConnectionHandlingTest.java
@@ -0,0 +1,341 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.leaderelection; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.handler.FatalErrorHandler; +import com.alibaba.flink.shuffle.coordinator.highavailability.DirectlyFailingFatalErrorHandler; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalDriver; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalEventHandler; +import com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper.ZooKeeperHaServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper.ZooKeeperMultiLeaderRetrievalDriver; +import com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper.ZooKeeperSingleLeaderRetrievalDriver; +import com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper.ZooKeeperUtils; +import com.alibaba.flink.shuffle.coordinator.utils.CommonTestUtils; +import com.alibaba.flink.shuffle.core.config.ClusterOptions; +import com.alibaba.flink.shuffle.core.config.HighAvailabilityOptions; +import com.alibaba.flink.shuffle.core.utils.TestLogger; + +import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFramework; + +import org.apache.curator.test.TestingServer; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.time.Duration; +import java.util.UUID; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; + +/** Tests for the error handling in case of a suspended connection to the ZooKeeper instance. 
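+ * Connection suspension and loss are simulated by stopping and restarting the Curator
+ * {@link TestingServer}; listeners are expected to observe {@code LeaderInformation.empty()}
+ * while the connection is down and the (possibly unchanged) leader information again once the
+ * connection has been re-established.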
+ */
+@RunWith(Parameterized.class)
+public class ZooKeeperLeaderElectionConnectionHandlingTest extends TestLogger {

+    private TestingServer testingServer;
+
+    private CuratorFramework zooKeeperClient;
+
+    private final FatalErrorHandler fatalErrorHandler = DirectlyFailingFatalErrorHandler.INSTANCE;
+
+    public final HaServices.LeaderReceptor leaderReceptor;
+
+    private final String retrievalPath =
+            ClusterOptions.REMOTE_SHUFFLE_CLUSTER_ID.defaultValue()
+                    + ZooKeeperHaServices.SHUFFLE_MANAGER_LEADER_RETRIEVAL_PATH;
+
+    @Parameterized.Parameters(name = "leader receptor ={0}")
+    public static Object[] parameter() {
+        return HaServices.LeaderReceptor.values();
+    }
+
+    public ZooKeeperLeaderElectionConnectionHandlingTest(HaServices.LeaderReceptor leaderReceptor) {
+        this.leaderReceptor = leaderReceptor;
+    }
+
+    @Before
+    public void before() throws Exception {
+        testingServer = new TestingServer();
+
+        Configuration config = new Configuration();
+        config.setString(HighAvailabilityOptions.HA_MODE, "zookeeper");
+        config.setString(
+                HighAvailabilityOptions.HA_ZOOKEEPER_QUORUM, testingServer.getConnectString());
+        config.setDuration(
+                HighAvailabilityOptions.ZOOKEEPER_SESSION_TIMEOUT, Duration.ofSeconds(10));
+
+        zooKeeperClient = ZooKeeperUtils.startCuratorFramework(config);
+    }
+
+    @After
+    public void after() throws Exception {
+        closeTestServer();
+
+        if (zooKeeperClient != null) {
+            zooKeeperClient.close();
+            zooKeeperClient = null;
+        }
+    }
+
+    @Test
+    public void testConnectionSuspendedWhenMultipleLeaderSelection() throws Exception {
+        int numLeaders = 10;
+        int leaderIndex = 5;
+        for (int i = 0; i < numLeaders; ++i) {
+            LeaderInformation leaderInfo =
+                    new LeaderInformation(
+                            i,
+                            i <= leaderIndex ? 0 : leaderIndex,
+                            UUID.randomUUID(),
+                            "test address " + i);
+            writeLeaderInformationToZooKeeper(
+                    "/cluster-" + i + ZooKeeperHaServices.SHUFFLE_MANAGER_LEADER_RETRIEVAL_PATH,
+                    leaderInfo);
+        }
+
+        Thread.sleep(2000);
+
+        QueueLeaderListener leaderListener = new QueueLeaderListener(10);
+        try (LeaderRetrievalDriver ignored =
+                createLeaderRetrievalDriver(
+                        HaServices.LeaderReceptor.SHUFFLE_CLIENT,
+                        ZooKeeperHaServices.SHUFFLE_MANAGER_LEADER_RETRIEVAL_PATH,
+                        leaderListener)) {
+            assertEquals(
+                    "test address " + leaderIndex, leaderListener.next().get().getLeaderAddress());
+
+            testingServer.stop();
+            // connection suspended
+            assertEquals(LeaderInformation.empty(), leaderListener.next().get());
+
+            testingServer.restart();
+            LeaderInformation leaderInfo = leaderListener.next().get();
+            if (leaderInfo.isEmpty()) {
+                leaderInfo = leaderListener.next().get();
+            }
+            assertEquals("test address " + leaderIndex, leaderInfo.getLeaderAddress());
+
+            testingServer.stop();
+            // connection lost
+            assertEquals(LeaderInformation.empty(), leaderListener.next().get());
+            assertEquals(LeaderInformation.empty(), leaderListener.next().get());
+
+            testingServer.restart();
+            assertEquals(
+                    "test address " + leaderIndex, leaderListener.next().get().getLeaderAddress());
+        }
+    }
+
+    @Test
+    public void testConnectionSuspendedHandlingDuringInitialization() throws Exception {
+        final QueueLeaderListener leaderListener = new QueueLeaderListener(1);
+        try (LeaderRetrievalDriver ignored = createLeaderRetrievalDriver(leaderListener)) {
+
+            // do the testing
+            final CompletableFuture<LeaderInformation> firstAddress =
+                    leaderListener.next(Duration.ofMillis(50));
+            assertNull(firstAddress);
+
+            closeTestServer();
+
+            // QueueLeaderListener will be notified with an empty leader when the ZK connection
+            // is suspended
+            final CompletableFuture<LeaderInformation> secondAddress = leaderListener.next();
+            assertNotNull(secondAddress);
+            assertEquals(LeaderInformation.empty(), secondAddress.get());
+        }
+    }
+
+    @Test
+    public void testConnectionSuspendedHandling() throws Exception {
+        final String leaderAddress = "localhost";
+        final QueueLeaderListener leaderListener = new QueueLeaderListener(1);
+        try (LeaderRetrievalDriver ignored = createLeaderRetrievalDriver(leaderListener)) {
+
+            LeaderInformation leaderInfo = new LeaderInformation(UUID.randomUUID(), leaderAddress);
+            writeLeaderInformationToZooKeeper(leaderInfo);
+
+            // do the testing
+            CompletableFuture<LeaderInformation> firstAddress = leaderListener.next();
+            assertEquals(leaderInfo, firstAddress.get());
+
+            closeTestServer();
+
+            CompletableFuture<LeaderInformation> secondAddress = leaderListener.next();
+            assertNotNull(secondAddress);
+            assertEquals(LeaderInformation.empty(), secondAddress.get());
+        }
+    }
+
+    @Test
+    public void testSameLeaderAfterReconnectTriggersListenerNotification() throws Exception {
+        final QueueLeaderListener leaderListener = new QueueLeaderListener(1);
+        try (LeaderRetrievalDriver ignored = createLeaderRetrievalDriver(leaderListener)) {
+
+            String leaderAddress = "foobar";
+            UUID sessionId = UUID.randomUUID();
+            LeaderInformation leaderInfo = new LeaderInformation(sessionId, leaderAddress);
+            writeLeaderInformationToZooKeeper(leaderInfo);
+
+            // pop new leader
+            leaderListener.next();
+
+            testingServer.stop();
+
+            final CompletableFuture<LeaderInformation> connectionSuspension =
+                    leaderListener.next();
+
+            // wait until the ZK connection is suspended
+            connectionSuspension.join();
+
+            testingServer.restart();
+
+            // the old leader information should be announced again
+            final CompletableFuture<LeaderInformation> connectionReconnect = leaderListener.next();
+            assertEquals(leaderInfo, connectionReconnect.get());
+        }
+    }
+
+    private void writeLeaderInformationToZooKeeper(LeaderInformation leaderInfo) throws Exception {
+        writeLeaderInformationToZooKeeper(retrievalPath, leaderInfo);
+    }
+
+    private void writeLeaderInformationToZooKeeper(
+            String retrievalPath, LeaderInformation leaderInfo) throws Exception {
+        final byte[] data = leaderInfo.toByteArray();
+        if (zooKeeperClient.checkExists().forPath(retrievalPath) != null) {
+            zooKeeperClient.setData().forPath(retrievalPath, data);
+        } else {
+            zooKeeperClient.create().creatingParentsIfNeeded().forPath(retrievalPath, data);
+        }
+    }
+
+    @Test
+    public void testNewLeaderAfterReconnectTriggersListenerNotification() throws Exception {
+        final QueueLeaderListener leaderListener = new QueueLeaderListener(1);
+
+        try (LeaderRetrievalDriver ignored = createLeaderRetrievalDriver(leaderListener)) {
+
+            final String leaderAddress = "foobar";
+            final UUID sessionId = UUID.randomUUID();
+            writeLeaderInformationToZooKeeper(new LeaderInformation(sessionId, leaderAddress));
+
+            // pop new leader
+            leaderListener.next();
+
+            testingServer.stop();
+
+            final CompletableFuture<LeaderInformation> connectionSuspension =
+                    leaderListener.next();
+
+            // wait until the ZK connection is suspended
+            connectionSuspension.join();
+
+            testingServer.restart();
+
+            String newLeaderAddress = "barfoo";
+            UUID newSessionId = UUID.randomUUID();
+            LeaderInformation leaderInfo = new LeaderInformation(newSessionId, newLeaderAddress);
+            writeLeaderInformationToZooKeeper(leaderInfo);
+
+            // check that we find the new leader information eventually
+            CommonTestUtils.waitUntilCondition(
+                    () -> {
+                        final CompletableFuture<LeaderInformation> afterConnectionReconnect =
+                                leaderListener.next();
+                        return afterConnectionReconnect.get().equals(leaderInfo);
+                    },
+                    30L * 1000);
+        }
+    }
+
+    private LeaderRetrievalDriver createLeaderRetrievalDriver(
+            LeaderRetrievalEventHandler leaderListener) throws Exception {
+        return createLeaderRetrievalDriver(leaderReceptor, retrievalPath, leaderListener);
+    }
+
+    private LeaderRetrievalDriver createLeaderRetrievalDriver(
+            HaServices.LeaderReceptor leaderReceptor,
+            String retrievalPath,
+            LeaderRetrievalEventHandler leaderListener)
+            throws Exception {
+        switch (leaderReceptor) {
+            case SHUFFLE_WORKER:
+                return new ZooKeeperSingleLeaderRetrievalDriver(
+                        zooKeeperClient, retrievalPath, leaderListener, fatalErrorHandler);
+            case SHUFFLE_CLIENT:
+                return new ZooKeeperMultiLeaderRetrievalDriver(
+                        zooKeeperClient, retrievalPath, leaderListener, fatalErrorHandler);
+            default:
+                throw new Exception("Unknown leader receptor type: " + leaderReceptor);
+        }
+    }
+
+    private void closeTestServer() throws IOException {
+        if (testingServer != null) {
+            testingServer.close();
+            testingServer = null;
+        }
+    }
+
+    private static class QueueLeaderListener implements LeaderRetrievalEventHandler {
+
+        private final BlockingQueue<CompletableFuture<LeaderInformation>> queue;
+
+        public QueueLeaderListener(int expectedCalls) {
+            this.queue = new ArrayBlockingQueue<>(expectedCalls);
+        }
+
+        @Override
+        public void notifyLeaderAddress(LeaderInformation leaderInfo) {
+            try {
+                queue.put(CompletableFuture.completedFuture(leaderInfo));
+            } catch (InterruptedException e) {
+                throw new IllegalStateException(e);
+            }
+        }
+
+        public CompletableFuture<LeaderInformation> next() {
+            return next(null);
+        }
+
+        public CompletableFuture<LeaderInformation> next(@Nullable Duration timeout) {
+            try {
+                if (timeout == null) {
+                    return queue.take();
+                } else {
+                    return this.queue.poll(timeout.toMillis(), TimeUnit.MILLISECONDS);
+                }
+            } catch (InterruptedException e) {
+                throw new IllegalStateException(e);
+            }
+        }
+    }
+}
diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/ZooKeeperLeaderElectionTest.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/ZooKeeperLeaderElectionTest.java
new file mode 100644
index 00000000..91f9cc93
--- /dev/null
+++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/ZooKeeperLeaderElectionTest.java
@@ -0,0 +1,690 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package com.alibaba.flink.shuffle.coordinator.leaderelection; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.coordinator.highavailability.DefaultLeaderElectionService; +import com.alibaba.flink.shuffle.coordinator.highavailability.DefaultLeaderRetrievalService; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderContender; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionDriver; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionService; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalDriver; +import com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper.ZooKeeperHaServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper.ZooKeeperLeaderElectionDriver; +import com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper.ZooKeeperSingleLeaderRetrievalDriver; +import com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper.ZooKeeperUtils; +import com.alibaba.flink.shuffle.coordinator.leaderretrieval.TestingLeaderRetrievalEventHandler; +import com.alibaba.flink.shuffle.core.config.ClusterOptions; +import com.alibaba.flink.shuffle.core.config.HighAvailabilityOptions; +import com.alibaba.flink.shuffle.core.utils.TestLogger; + +import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFramework; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.api.CreateBuilder; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.recipes.cache.ChildData; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.recipes.cache.NodeCache; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.recipes.cache.NodeCacheListener; +import org.apache.flink.shaded.zookeeper3.org.apache.zookeeper.CreateMode; +import org.apache.flink.shaded.zookeeper3.org.apache.zookeeper.KeeperException; + +import org.apache.curator.test.TestingServer; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Matchers; +import org.mockito.Mockito; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; + +import java.io.IOException; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +/** + * Tests for the {@link ZooKeeperLeaderElectionDriver} and the {@link + * ZooKeeperSingleLeaderRetrievalDriver}. 
To directly test the {@link ZooKeeperLeaderElectionDriver}
+ * and {@link ZooKeeperSingleLeaderRetrievalDriver}, some simple tests use {@link
+ * TestingLeaderElectionEventHandler}, which does not write the leader information to ZooKeeper.
+ * For the more complicated tests (e.g., multiple leaders), we use {@link
+ * DefaultLeaderElectionService} with {@link TestingContender}.
+ */
+public class ZooKeeperLeaderElectionTest extends TestLogger {
+
+    private static final Logger LOG = LoggerFactory.getLogger(ZooKeeperLeaderElectionTest.class);
+
+    private TestingServer testingServer;
+
+    private Configuration configuration;
+
+    private CuratorFramework client;
+
+    private static final String TEST_URL = "akka//user/shufflemanager";
+
+    private static final LeaderInformation TEST_LEADER =
+            new LeaderInformation(UUID.randomUUID(), TEST_URL);
+
+    private static final long timeout = 200L * 1000L;
+
+    @Before
+    public void before() {
+        try {
+            testingServer = new TestingServer();
+        } catch (Exception e) {
+            throw new RuntimeException("Could not start ZooKeeper testing cluster.", e);
+        }
+
+        configuration = new Configuration();
+
+        configuration.setString(
+                HighAvailabilityOptions.HA_ZOOKEEPER_QUORUM, testingServer.getConnectString());
+        configuration.setString(HighAvailabilityOptions.HA_MODE, "zookeeper");
+
+        client = ZooKeeperUtils.startCuratorFramework(configuration);
+    }
+
+    @After
+    public void after() throws IOException {
+        if (client != null) {
+            client.close();
+            client = null;
+        }
+
+        if (testingServer != null) {
+            testingServer.stop();
+            testingServer = null;
+        }
+    }
+
+    /** Tests that the ZooKeeperLeaderElection/RetrievalService both return the correct URL. */
+    @Test
+    public void testZooKeeperLeaderElectionRetrieval() throws Exception {
+
+        final TestingLeaderElectionEventHandler electionEventHandler =
+                new TestingLeaderElectionEventHandler(TEST_LEADER);
+        final TestingLeaderRetrievalEventHandler retrievalEventHandler =
+                new TestingLeaderRetrievalEventHandler();
+        LeaderElectionDriver leaderElectionDriver = null;
+        LeaderRetrievalDriver leaderRetrievalDriver = null;
+        try {
+
+            leaderElectionDriver = createAndInitLeaderElectionDriver(client, electionEventHandler);
+            leaderRetrievalDriver =
+                    ZooKeeperUtils.createLeaderRetrievalDriverFactory(
+                                    client,
+                                    configuration,
+                                    ZooKeeperHaServices.SHUFFLE_MANAGER_LEADER_RETRIEVAL_PATH,
+                                    HaServices.LeaderReceptor.SHUFFLE_WORKER)
+                            .createLeaderRetrievalDriver(
+                                    retrievalEventHandler, retrievalEventHandler::handleError);
+
+            electionEventHandler.waitForLeader(timeout);
+            assertThat(electionEventHandler.getConfirmedLeaderInformation(), is(TEST_LEADER));
+
+            retrievalEventHandler.waitForNewLeader(timeout);
+
+            assertThat(
+                    retrievalEventHandler.getLeaderSessionID(),
+                    is(TEST_LEADER.getLeaderSessionID()));
+            assertThat(retrievalEventHandler.getAddress(), is(TEST_LEADER.getLeaderAddress()));
+        } finally {
+            if (leaderElectionDriver != null) {
+                leaderElectionDriver.close();
+            }
+            if (leaderRetrievalDriver != null) {
+                leaderRetrievalDriver.close();
+            }
+        }
+    }
+
+    /**
+     * Repeatedly tests the re-election of the remaining available LeaderContenders. After a
+     * contender has been elected as the leader, it is removed. This forces the
+     * DefaultLeaderElectionService to elect a new leader.
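+     *
+     * <p>The contender addresses are generated as {@code TEST_URL + "_" + index}, so the test can
+     * recover the index of the currently elected contender from the retrieved leader address with
+     * a regular expression and then stop exactly that contender's election service.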
+ */ + @Test + public void testZooKeeperReelection() throws Exception { + long deadlineTime = System.nanoTime() + 5L * 60 * 1000000000; + int num = 10; + + DefaultLeaderElectionService[] leaderElectionService = + new DefaultLeaderElectionService[num]; + TestingContender[] contenders = new TestingContender[num]; + DefaultLeaderRetrievalService leaderRetrievalService = null; + + TestingListener listener = new TestingListener(); + + try { + leaderRetrievalService = + ZooKeeperUtils.createLeaderRetrievalService( + client, + configuration, + ZooKeeperHaServices.SHUFFLE_MANAGER_LEADER_RETRIEVAL_PATH, + HaServices.LeaderReceptor.SHUFFLE_WORKER); + + LOG.debug("Start leader retrieval service for the TestingListener."); + + leaderRetrievalService.start(listener); + + for (int i = 0; i < num; i++) { + leaderElectionService[i] = + ZooKeeperUtils.createLeaderElectionService( + client, + configuration, + ZooKeeperHaServices.SHUFFLE_MANAGER_LEADER_LATCH_PATH, + ZooKeeperHaServices.SHUFFLE_MANAGER_LEADER_RETRIEVAL_PATH); + contenders[i] = new TestingContender(createAddress(i), leaderElectionService[i]); + + LOG.debug("Start leader election service for contender #{}.", i); + + leaderElectionService[i].start(contenders[i]); + } + + String pattern = TEST_URL + "_" + "(\\d+)"; + Pattern regex = Pattern.compile(pattern); + + int numberSeenLeaders = 0; + long timeLeft = deadlineTime - System.nanoTime(); + while (timeLeft > 0 && numberSeenLeaders < num) { + LOG.debug("Wait for new leader #{}.", numberSeenLeaders); + String address = listener.waitForNewLeader(timeLeft); + + Matcher m = regex.matcher(address); + + if (m.find()) { + int index = Integer.parseInt(m.group(1)); + + TestingContender contender = contenders[index]; + + // check that the retrieval service has retrieved the correct leader + if (address.equals(createAddress(index)) + && listener.getLeaderSessionID() + .equals(contender.getLeaderSessionID())) { + // kill the election service of the leader + LOG.debug( + "Stop leader election service of contender #{}.", + numberSeenLeaders); + leaderElectionService[index].stop(); + leaderElectionService[index] = null; + + numberSeenLeaders++; + } + } else { + fail("Did not find the leader's index."); + } + } + + assertFalse( + "Did not complete the leader reelection in time.", + System.nanoTime() >= deadlineTime); + assertEquals(num, numberSeenLeaders); + + } finally { + if (leaderRetrievalService != null) { + leaderRetrievalService.stop(); + } + + for (DefaultLeaderElectionService electionService : leaderElectionService) { + if (electionService != null) { + electionService.stop(); + } + } + } + } + + @Nonnull + private String createAddress(int i) { + return TEST_URL + "_" + i; + } + + /** + * Tests the repeated reelection of {@link LeaderContender} once the current leader dies. + * Furthermore, it tests that new LeaderElectionServices can be started later on and that they + * successfully register at ZooKeeper and take part in the leader election. 
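+     *
+     * <p>Here the contender addresses additionally carry a retry counter
+     * ({@code TEST_URL + "_" + index + "_" + attempt}), so a stopped contender can be replaced by
+     * a fresh service/contender pair whose later election is again observable via the retrieval
+     * listener.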
+     */
+    @Test
+    public void testZooKeeperReelectionWithReplacement() throws Exception {
+        int num = 3;
+        int numTries = 30;
+
+        DefaultLeaderElectionService[] leaderElectionService =
+                new DefaultLeaderElectionService[num];
+        TestingContender[] contenders = new TestingContender[num];
+        DefaultLeaderRetrievalService leaderRetrievalService = null;
+
+        TestingListener listener = new TestingListener();
+
+        try {
+            leaderRetrievalService =
+                    ZooKeeperUtils.createLeaderRetrievalService(
+                            client,
+                            configuration,
+                            ZooKeeperHaServices.SHUFFLE_MANAGER_LEADER_RETRIEVAL_PATH,
+                            HaServices.LeaderReceptor.SHUFFLE_WORKER);
+
+            leaderRetrievalService.start(listener);
+
+            for (int i = 0; i < num; i++) {
+                leaderElectionService[i] =
+                        ZooKeeperUtils.createLeaderElectionService(
+                                client,
+                                configuration,
+                                ZooKeeperHaServices.SHUFFLE_MANAGER_LEADER_LATCH_PATH,
+                                ZooKeeperHaServices.SHUFFLE_MANAGER_LEADER_RETRIEVAL_PATH);
+                contenders[i] =
+                        new TestingContender(TEST_URL + "_" + i + "_0", leaderElectionService[i]);
+
+                leaderElectionService[i].start(contenders[i]);
+            }
+
+            String pattern = TEST_URL + "_" + "(\\d+)" + "_" + "(\\d+)";
+            Pattern regex = Pattern.compile(pattern);
+
+            for (int i = 0; i < numTries; i++) {
+                listener.waitForNewLeader(timeout);
+
+                String address = listener.getAddress();
+
+                Matcher m = regex.matcher(address);
+
+                if (m.find()) {
+                    int index = Integer.parseInt(m.group(1));
+                    int lastTry = Integer.parseInt(m.group(2));
+
+                    assertEquals(
+                            listener.getLeaderSessionID(), contenders[index].getLeaderSessionID());
+
+                    // stop leader election service = revoke leadership
+                    leaderElectionService[index].stop();
+                    // create new leader election service which takes part in the leader election
+                    leaderElectionService[index] =
+                            ZooKeeperUtils.createLeaderElectionService(
+                                    client,
+                                    configuration,
+                                    ZooKeeperHaServices.SHUFFLE_MANAGER_LEADER_LATCH_PATH,
+                                    ZooKeeperHaServices.SHUFFLE_MANAGER_LEADER_RETRIEVAL_PATH);
+                    contenders[index] =
+                            new TestingContender(
+                                    TEST_URL + "_" + index + "_" + (lastTry + 1),
+                                    leaderElectionService[index]);
+
+                    leaderElectionService[index].start(contenders[index]);
+                } else {
+                    throw new Exception("Did not find the leader's index.");
+                }
+            }
+
+        } finally {
+            if (leaderRetrievalService != null) {
+                leaderRetrievalService.stop();
+            }
+
+            for (DefaultLeaderElectionService electionService : leaderElectionService) {
+                if (electionService != null) {
+                    electionService.stop();
+                }
+            }
+        }
+    }
+
+    /**
+     * Tests that the current leader is notified when its leader connection information in
+     * ZooKeeper is overwritten. The leader must re-establish the correct leader connection
+     * information in ZooKeeper.
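+     *
+     * <p>The overwrite is performed through a second Curator client, roughly as sketched below
+     * (variable names shortened); the retry loop in the test body exists because the real leader
+     * may re-create the node between the delete and the create:
+     *
+     * <pre>{@code
+     * // simplified from the test body below
+     * anotherClient.delete().forPath(leaderPath);
+     * anotherClient.create().forPath(leaderPath, faultyLeaderInfo.toByteArray());
+     * }</pre>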
+ */ + @Test + public void testLeaderShouldBeCorrectedWhenOverwritten() throws Exception { + final String faultyContenderUrl = "faultyContender"; + final String leaderPath = + configuration.getString(ClusterOptions.REMOTE_SHUFFLE_CLUSTER_ID) + + ZooKeeperHaServices.SHUFFLE_MANAGER_LEADER_RETRIEVAL_PATH; + + final TestingLeaderElectionEventHandler electionEventHandler = + new TestingLeaderElectionEventHandler(TEST_LEADER); + final TestingLeaderRetrievalEventHandler retrievalEventHandler = + new TestingLeaderRetrievalEventHandler(); + + LeaderElectionDriver leaderElectionDriver = null; + LeaderRetrievalDriver leaderRetrievalDriver = null; + + CuratorFramework anotherClient = null; + + try { + + leaderElectionDriver = createAndInitLeaderElectionDriver(client, electionEventHandler); + + electionEventHandler.waitForLeader(timeout); + assertThat(electionEventHandler.getConfirmedLeaderInformation(), is(TEST_LEADER)); + + anotherClient = ZooKeeperUtils.startCuratorFramework(configuration); + + LeaderInformation leaderInfo = + new LeaderInformation(UUID.randomUUID(), faultyContenderUrl); + // overwrite the current leader address, the leader should notice that + boolean dataWritten = false; + + while (!dataWritten) { + anotherClient.delete().forPath(leaderPath); + + try { + anotherClient.create().forPath(leaderPath, leaderInfo.toByteArray()); + + dataWritten = true; + } catch (KeeperException.NodeExistsException e) { + // this can happen if the leader election service was faster + } + } + + // The faulty leader should be corrected on ZooKeeper + leaderRetrievalDriver = + ZooKeeperUtils.createLeaderRetrievalDriverFactory( + client, + configuration, + ZooKeeperHaServices.SHUFFLE_MANAGER_LEADER_RETRIEVAL_PATH, + HaServices.LeaderReceptor.SHUFFLE_WORKER) + .createLeaderRetrievalDriver( + retrievalEventHandler, retrievalEventHandler::handleError); + + if (retrievalEventHandler.waitForNewLeader(timeout).equals(faultyContenderUrl)) { + retrievalEventHandler.waitForNewLeader(timeout); + } + + assertThat( + retrievalEventHandler.getLeaderSessionID(), + is(TEST_LEADER.getLeaderSessionID())); + assertThat(retrievalEventHandler.getAddress(), is(TEST_LEADER.getLeaderAddress())); + } finally { + if (leaderElectionDriver != null) { + leaderElectionDriver.close(); + } + if (leaderRetrievalDriver != null) { + leaderRetrievalDriver.close(); + } + if (anotherClient != null) { + anotherClient.close(); + } + } + } + + /** + * Test that errors in the {@link LeaderElectionService} are correctly forwarded to the {@link + * LeaderContender}. 
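+     *
+     * <p>The Curator client is mocked so that the {@code create()...forPath(...)} call chain
+     * throws a test exception; the driver is then expected to surface that exception through the
+     * handler's {@code handleError} callback instead of swallowing it.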
+ */ + @Test + public void testExceptionForwarding() throws Exception { + LeaderElectionDriver leaderElectionDriver = null; + final TestingLeaderElectionEventHandler electionEventHandler = + new TestingLeaderElectionEventHandler(TEST_LEADER); + + CuratorFramework client = null; + final CreateBuilder mockCreateBuilder = + mock(CreateBuilder.class, Mockito.RETURNS_DEEP_STUBS); + final String exMsg = "Test exception"; + final Exception testException = new Exception(exMsg); + + try { + client = spy(ZooKeeperUtils.startCuratorFramework(configuration)); + + doAnswer(invocation -> mockCreateBuilder).when(client).create(); + + when(mockCreateBuilder + .creatingParentsIfNeeded() + .withMode(Matchers.any(CreateMode.class)) + .forPath(anyString(), any(byte[].class))) + .thenThrow(testException); + + leaderElectionDriver = createAndInitLeaderElectionDriver(client, electionEventHandler); + + electionEventHandler.waitForError(timeout); + + assertNotNull(electionEventHandler.getError()); + assertTrue(electionEventHandler.getError().getCause().getMessage().contains(exMsg)); + } finally { + if (leaderElectionDriver != null) { + leaderElectionDriver.close(); + } + + if (client != null) { + client.close(); + } + } + } + + /** + * Tests that there is no information left in the ZooKeeper cluster after the ZooKeeper client + * has terminated. In other words, checks that the ZooKeeperLeaderElection service uses + * ephemeral nodes. + */ + @Test + public void testEphemeralZooKeeperNodes() throws Exception { + LeaderElectionDriver leaderElectionDriver = null; + LeaderRetrievalDriver leaderRetrievalDriver = null; + final TestingLeaderElectionEventHandler electionEventHandler = + new TestingLeaderElectionEventHandler(TEST_LEADER); + final TestingLeaderRetrievalEventHandler retrievalEventHandler = + new TestingLeaderRetrievalEventHandler(); + + CuratorFramework client = null; + CuratorFramework client2 = null; + NodeCache cache = null; + + try { + client = ZooKeeperUtils.startCuratorFramework(configuration); + client2 = ZooKeeperUtils.startCuratorFramework(configuration); + + leaderElectionDriver = createAndInitLeaderElectionDriver(client, electionEventHandler); + leaderRetrievalDriver = + ZooKeeperUtils.createLeaderRetrievalDriverFactory( + client2, + configuration, + ZooKeeperHaServices.SHUFFLE_MANAGER_LEADER_RETRIEVAL_PATH, + HaServices.LeaderReceptor.SHUFFLE_WORKER) + .createLeaderRetrievalDriver( + retrievalEventHandler, retrievalEventHandler::handleError); + + final String leaderPath = + configuration.getString(ClusterOptions.REMOTE_SHUFFLE_CLUSTER_ID) + + ZooKeeperHaServices.SHUFFLE_MANAGER_LEADER_RETRIEVAL_PATH; + cache = new NodeCache(client2, leaderPath); + + ExistsCacheListener existsListener = new ExistsCacheListener(cache); + DeletedCacheListener deletedCacheListener = new DeletedCacheListener(cache); + + cache.getListenable().addListener(existsListener); + cache.start(); + + electionEventHandler.waitForLeader(timeout); + + retrievalEventHandler.waitForNewLeader(timeout); + + Future existsFuture = existsListener.nodeExists(); + + existsFuture.get(timeout, TimeUnit.MILLISECONDS); + + cache.getListenable().addListener(deletedCacheListener); + + leaderElectionDriver.close(); + + // now stop the underlying client + client.close(); + + Future deletedFuture = deletedCacheListener.nodeDeleted(); + + // make sure that the leader node has been deleted + deletedFuture.get(timeout, TimeUnit.MILLISECONDS); + + try { + retrievalEventHandler.waitForNewLeader(1000L); + + fail( + "TimeoutException was expected because 
there is no leader registered and "
+                                + "thus there shouldn't be any leader information in ZooKeeper.");
+            } catch (TimeoutException e) {
+                // that was expected
+            }
+        } finally {
+            if (leaderRetrievalDriver != null) {
+                leaderRetrievalDriver.close();
+            }
+
+            if (cache != null) {
+                cache.close();
+            }
+
+            if (client2 != null) {
+                client2.close();
+            }
+        }
+    }
+
+    @Test
+    public void testNotLeaderShouldNotCleanUpTheLeaderInformation() throws Exception {
+
+        final TestingLeaderElectionEventHandler electionEventHandler =
+                new TestingLeaderElectionEventHandler(TEST_LEADER);
+        final TestingLeaderRetrievalEventHandler retrievalEventHandler =
+                new TestingLeaderRetrievalEventHandler();
+        ZooKeeperLeaderElectionDriver leaderElectionDriver = null;
+        LeaderRetrievalDriver leaderRetrievalDriver = null;
+
+        try {
+            leaderElectionDriver = createAndInitLeaderElectionDriver(client, electionEventHandler);
+
+            electionEventHandler.waitForLeader(timeout);
+            assertThat(electionEventHandler.getConfirmedLeaderInformation(), is(TEST_LEADER));
+
+            // Leader is revoked
+            leaderElectionDriver.notLeader();
+            electionEventHandler.waitForRevokeLeader(timeout);
+            assertThat(
+                    electionEventHandler.getConfirmedLeaderInformation(),
+                    is(LeaderInformation.empty()));
+            // The data on ZooKeeper should not be cleared
+            leaderRetrievalDriver =
+                    ZooKeeperUtils.createLeaderRetrievalDriverFactory(
+                                    client,
+                                    configuration,
+                                    ZooKeeperHaServices.SHUFFLE_MANAGER_LEADER_RETRIEVAL_PATH,
+                                    HaServices.LeaderReceptor.SHUFFLE_WORKER)
+                            .createLeaderRetrievalDriver(
+                                    retrievalEventHandler, retrievalEventHandler::handleError);
+
+            retrievalEventHandler.waitForNewLeader(timeout);
+
+            assertThat(
+                    retrievalEventHandler.getLeaderSessionID(),
+                    is(TEST_LEADER.getLeaderSessionID()));
+            assertThat(retrievalEventHandler.getAddress(), is(TEST_LEADER.getLeaderAddress()));
+        } finally {
+            if (leaderElectionDriver != null) {
+                leaderElectionDriver.close();
+            }
+            if (leaderRetrievalDriver != null) {
+                leaderRetrievalDriver.close();
+            }
+        }
+    }
+
+    private static class ExistsCacheListener implements NodeCacheListener {
+
+        final CompletableFuture<Boolean> existsPromise = new CompletableFuture<>();
+
+        final NodeCache cache;
+
+        public ExistsCacheListener(final NodeCache cache) {
+            this.cache = cache;
+        }
+
+        public Future<Boolean> nodeExists() {
+            return existsPromise;
+        }
+
+        @Override
+        public void nodeChanged() throws Exception {
+            ChildData data = cache.getCurrentData();
+
+            if (data != null && !existsPromise.isDone()) {
+                existsPromise.complete(true);
+                cache.getListenable().removeListener(this);
+            }
+        }
+    }
+
+    private static class DeletedCacheListener implements NodeCacheListener {
+
+        final CompletableFuture<Boolean> deletedPromise = new CompletableFuture<>();
+
+        final NodeCache cache;
+
+        public DeletedCacheListener(final NodeCache cache) {
+            this.cache = cache;
+        }
+
+        public Future<Boolean> nodeDeleted() {
+            return deletedPromise;
+        }
+
+        @Override
+        public void nodeChanged() throws Exception {
+            ChildData data = cache.getCurrentData();
+
+            if (data == null && !deletedPromise.isDone()) {
+                deletedPromise.complete(true);
+                cache.getListenable().removeListener(this);
+            }
+        }
+    }
+
+    private ZooKeeperLeaderElectionDriver createAndInitLeaderElectionDriver(
+            CuratorFramework client, TestingLeaderElectionEventHandler electionEventHandler)
+            throws Exception {
+
+        final ZooKeeperLeaderElectionDriver leaderElectionDriver =
+                ZooKeeperUtils.createLeaderElectionDriverFactory(
+                                client,
+                                configuration,
+                                ZooKeeperHaServices.SHUFFLE_MANAGER_LEADER_LATCH_PATH,
+
ZooKeeperHaServices.SHUFFLE_MANAGER_LEADER_RETRIEVAL_PATH) + .createLeaderElectionDriver( + electionEventHandler, electionEventHandler::handleError, TEST_URL); + electionEventHandler.init(leaderElectionDriver); + return leaderElectionDriver; + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/ZooKeeperLeaderRetrievalTest.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/ZooKeeperLeaderRetrievalTest.java new file mode 100644 index 00000000..2a801f2c --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderelection/ZooKeeperLeaderRetrievalTest.java @@ -0,0 +1,367 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.leaderelection; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionService; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalService; +import com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper.ZooKeeperHaServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper.ZooKeeperUtils; +import com.alibaba.flink.shuffle.coordinator.utils.LeaderRetrievalUtils; +import com.alibaba.flink.shuffle.core.config.ClusterOptions; +import com.alibaba.flink.shuffle.core.config.HighAvailabilityOptions; +import com.alibaba.flink.shuffle.core.utils.TestLogger; +import com.alibaba.flink.shuffle.rpc.utils.AkkaRpcServiceUtils; + +import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFramework; + +import org.apache.curator.test.TestingServer; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.SocketAddress; +import java.net.UnknownHostException; +import java.time.Duration; +import java.util.UUID; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** Tests for the ZooKeeper based leader election and retrieval. 
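+ *
+ * <p>Each test pairs a {@link LeaderElectionService} with a {@link LeaderRetrievalService}
+ * obtained from the same {@link ZooKeeperHaServices}; a minimal sketch of that pairing, using the
+ * names that appear in the tests below:
+ *
+ * <pre>{@code
+ * LeaderElectionService leaderElectionService = haServices.createLeaderElectionService();
+ * leaderElectionService.start(new TestingContender("test leader address", leaderElectionService));
+ *
+ * LeaderRetrievalService leaderRetrievalService =
+ *         haServices.createLeaderRetrievalService(leaderReceptor);
+ * leaderRetrievalService.start(new TestingListener());
+ * }</pre>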
*/ +@RunWith(Parameterized.class) +public class ZooKeeperLeaderRetrievalTest extends TestLogger { + + private TestingServer testingServer; + + private HaServices haServices; + + public final HaServices.LeaderReceptor leaderReceptor; + + @Parameterized.Parameters(name = "leader receptor ={0}") + public static Object[] parameter() { + return HaServices.LeaderReceptor.values(); + } + + public ZooKeeperLeaderRetrievalTest(HaServices.LeaderReceptor leaderReceptor) { + this.leaderReceptor = leaderReceptor; + } + + @Before + public void before() throws Exception { + testingServer = new TestingServer(); + haServices = createHaService(ClusterOptions.REMOTE_SHUFFLE_CLUSTER_ID.defaultValue()); + AkkaRpcServiceUtils.loadRpcSystem(new Configuration()); + } + + @After + public void after() throws Exception { + if (haServices != null) { + haServices.closeAndCleanupAllData(); + + haServices = null; + } + + if (testingServer != null) { + testingServer.stop(); + + testingServer = null; + } + + AkkaRpcServiceUtils.closeRpcSystem(); + } + + private HaServices createHaService(String clusterID) { + Configuration config = new Configuration(); + return new ZooKeeperHaServices(config, createZooKeeperClient(config, clusterID)); + } + + private CuratorFramework createZooKeeperClient(Configuration config, String clusterID) { + config.setString(HighAvailabilityOptions.HA_MODE, "zookeeper"); + config.setString( + HighAvailabilityOptions.HA_ZOOKEEPER_QUORUM, testingServer.getConnectString()); + config.setString(ClusterOptions.REMOTE_SHUFFLE_CLUSTER_ID, clusterID); + return ZooKeeperUtils.startCuratorFramework(config); + } + + @Test + public void testLeaderRetrievalAfterLeaderElection() throws Exception { + LeaderElectionService leaderElectionService = haServices.createLeaderElectionService(); + String address = "test leader address"; + TestingContender testingContender = new TestingContender(address, leaderElectionService); + leaderElectionService.start(testingContender); + + LeaderRetrievalService leaderRetrievalService = + haServices.createLeaderRetrievalService(leaderReceptor); + TestingListener testingListener = new TestingListener(); + leaderRetrievalService.start(testingListener); + + assertNotNull(address, testingListener.waitForNewLeader(60000)); + + leaderElectionService.stop(); + leaderRetrievalService.stop(); + } + + @Test + public void testLeaderRetrievalBeforeLeaderElection() throws Exception { + LeaderRetrievalService leaderRetrievalService = + haServices.createLeaderRetrievalService(leaderReceptor); + TestingListener testingListener = new TestingListener(); + leaderRetrievalService.start(testingListener); + + LeaderElectionService leaderElectionService = haServices.createLeaderElectionService(); + String address = "test leader address"; + TestingContender testingContender = new TestingContender(address, leaderElectionService); + leaderElectionService.start(testingContender); + + assertNotNull(address, testingListener.waitForNewLeader(60000)); + + leaderElectionService.stop(); + leaderRetrievalService.stop(); + } + + @Test + public void testLeaderChange() throws Exception { + LeaderElectionService leaderElectionService1 = haServices.createLeaderElectionService(); + String address1 = "test leader address1"; + TestingContender testingContender1 = new TestingContender(address1, leaderElectionService1); + leaderElectionService1.start(testingContender1); + + LeaderRetrievalService leaderRetrievalService = + haServices.createLeaderRetrievalService(leaderReceptor); + TestingListener testingListener = new 
TestingListener();
+        leaderRetrievalService.start(testingListener);
+
+        assertNotNull(address1, testingListener.waitForNewLeader(60000));
+
+        LeaderElectionService leaderElectionService2 = haServices.createLeaderElectionService();
+        String address2 = "test leader address2";
+        TestingContender testingContender2 = new TestingContender(address2, leaderElectionService2);
+        leaderElectionService2.start(testingContender2);
+
+        leaderElectionService1.stop();
+        assertNotNull(address2, testingListener.waitForNewLeader(60000));
+    }
+
+    private void writeLeaderInformationToZooKeeper(
+            CuratorFramework client, String retrievalPath, LeaderInformation leaderInfo)
+            throws Exception {
+        final byte[] data = leaderInfo.toByteArray();
+        if (client.checkExists().forPath(retrievalPath) != null) {
+            client.setData().forPath(retrievalPath, data);
+        } else {
+            client.create().creatingParentsIfNeeded().forPath(retrievalPath, data);
+        }
+    }
+
+    @Test
+    public void testMultipleLeaderSelection() throws Exception {
+        int numLeaders = 10;
+        int leaderIndex = 5;
+        CuratorFramework client = createZooKeeperClient(new Configuration(), "ignored");
+        for (int i = 0; i < numLeaders; ++i) {
+            LeaderInformation leaderInfo =
+                    new LeaderInformation(
+                            i,
+                            i <= leaderIndex ? 0 : leaderIndex,
+                            UUID.randomUUID(),
+                            "test address " + i);
+            writeLeaderInformationToZooKeeper(
+                    client,
+                    "/cluster-" + i + ZooKeeperHaServices.SHUFFLE_MANAGER_LEADER_RETRIEVAL_PATH,
+                    leaderInfo);
+        }
+
+        Thread.sleep(2000);
+
+        LeaderRetrievalService leaderRetrievalService =
+                haServices.createLeaderRetrievalService(HaServices.LeaderReceptor.SHUFFLE_CLIENT);
+        TestingListener testingListener = new TestingListener();
+        leaderRetrievalService.start(testingListener);
+        assertEquals("test address " + leaderIndex, testingListener.waitForNewLeader(60000));
+
+        String leaderPath =
+                "/cluster-"
+                        + leaderIndex
+                        + ZooKeeperHaServices.SHUFFLE_MANAGER_LEADER_RETRIEVAL_PATH;
+        LeaderInformation newLeaderInfo = new LeaderInformation(UUID.randomUUID(), "new address");
+        client.setData().forPath(leaderPath, newLeaderInfo.toByteArray());
+        assertEquals("new address", testingListener.waitForNewLeader(60000));
+
+        // remove the leader node
+        client.delete().forPath(leaderPath);
+        assertEquals("test address " + (leaderIndex - 1), testingListener.waitForNewLeader(60000));
+
+        leaderRetrievalService.stop();
+        client.close();
+    }
+
+    /**
+     * Tests that LeaderRetrievalUtils.findConnectingAddress finds the correct connecting address in
+     * case of an old leader address in ZooKeeper and a subsequent election of a new leader. The
+     * findConnectingAddress should block until the new leader has been elected and its address has
+     * been written to ZooKeeper.
+     */
+    @Test
+    public void testConnectingAddressRetrievalWithDelayedLeaderElection() throws Exception {
+        Duration timeout = Duration.ofMinutes(1L);
+
+        long sleepingTime = 1000;
+
+        LeaderElectionService leaderElectionService = null;
+        LeaderElectionService faultyLeaderElectionService;
+
+        ServerSocket serverSocket;
+        InetAddress localHost;
+
+        Thread thread;
+
+        try {
+            String wrongAddress =
+                    AkkaRpcServiceUtils.getRpcUrl(
+                            "1.1.1.1", 1234, "foobar", AkkaRpcServiceUtils.AkkaProtocol.TCP);
+
+            try {
+                localHost = InetAddress.getLocalHost();
+                serverSocket = new ServerSocket(0, 50, localHost);
+            } catch (UnknownHostException e) {
+                // may happen if disconnected. skip test.
+                System.err.println(
+                        "Skipping 'testConnectingAddressRetrievalWithDelayedLeaderElection' test.");
+                return;
+            } catch (IOException e) {
+                // may happen in certain test setups, skip test.
+                System.err.println(
+                        "Skipping 'testConnectingAddressRetrievalWithDelayedLeaderElection' test.");
+                return;
+            }
+
+            InetSocketAddress correctInetSocketAddress =
+                    new InetSocketAddress(localHost, serverSocket.getLocalPort());
+
+            String correctAddress =
+                    AkkaRpcServiceUtils.getRpcUrl(
+                            localHost.getHostName(),
+                            correctInetSocketAddress.getPort(),
+                            "Test",
+                            AkkaRpcServiceUtils.AkkaProtocol.TCP);
+
+            faultyLeaderElectionService = haServices.createLeaderElectionService();
+            TestingContender wrongLeaderAddressContender =
+                    new TestingContender(wrongAddress, faultyLeaderElectionService);
+
+            faultyLeaderElectionService.start(wrongLeaderAddressContender);
+
+            FindConnectingAddress findConnectingAddress =
+                    new FindConnectingAddress(
+                            timeout, haServices.createLeaderRetrievalService(leaderReceptor));
+
+            thread = new Thread(findConnectingAddress);
+
+            thread.start();
+
+            leaderElectionService = haServices.createLeaderElectionService();
+            TestingContender correctLeaderAddressContender =
+                    new TestingContender(correctAddress, leaderElectionService);
+
+            Thread.sleep(sleepingTime);
+
+            faultyLeaderElectionService.stop();
+
+            leaderElectionService.start(correctLeaderAddressContender);
+
+            thread.join();
+
+            InetAddress result = findConnectingAddress.getInetAddress();
+
+            // check that we can connect to the localHost
+            Socket socket = new Socket();
+            try {
+                // port 0 = let the OS choose the port on this machine
+                SocketAddress bindP = new InetSocketAddress(result, 0);
+                socket.bind(bindP);
+                socket.connect(correctInetSocketAddress, 1000);
+            } finally {
+                socket.close();
+            }
+        } finally {
+            if (leaderElectionService != null) {
+                leaderElectionService.stop();
+            }
+        }
+    }
+
+    /**
+     * Tests that the LeaderRetrievalUtils.findConnectingAddress stops trying to find the
+     * connecting address if no leader address has been specified. The call should then return
+     * InetAddress.getLocalHost().
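+     *
+     * <p>Put differently, with a short timeout and no elected leader the lookup should fall back
+     * to the local host; a minimal sketch of that expectation (mirroring the test body):
+     *
+     * <pre>{@code
+     * InetAddress result =
+     *         LeaderRetrievalUtils.findConnectingAddress(leaderRetrievalService, timeout);
+     * assertEquals(InetAddress.getLocalHost(), result);
+     * }</pre>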
+     */
+    @Test
+    public void testTimeoutOfFindConnectingAddress() throws Exception {
+        Duration timeout = Duration.ofSeconds(1L);
+
+        LeaderRetrievalService leaderRetrievalService =
+                haServices.createLeaderRetrievalService(leaderReceptor);
+        InetAddress result =
+                LeaderRetrievalUtils.findConnectingAddress(leaderRetrievalService, timeout);
+
+        assertEquals(InetAddress.getLocalHost(), result);
+    }
+
+    static class FindConnectingAddress implements Runnable {
+
+        private final Duration timeout;
+        private final LeaderRetrievalService leaderRetrievalService;
+
+        private InetAddress result;
+        private Exception exception;
+
+        public FindConnectingAddress(
+                Duration timeout, LeaderRetrievalService leaderRetrievalService) {
+            this.timeout = timeout;
+            this.leaderRetrievalService = leaderRetrievalService;
+        }
+
+        @Override
+        public void run() {
+            try {
+                result =
+                        LeaderRetrievalUtils.findConnectingAddress(leaderRetrievalService, timeout);
+            } catch (Exception e) {
+                exception = e;
+            }
+        }
+
+        public InetAddress getInetAddress() throws Exception {
+            if (exception != null) {
+                throw exception;
+            } else {
+                return result;
+            }
+        }
+    }
+}
diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderretrieval/DefaultLeaderRetrievalServiceTest.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderretrieval/DefaultLeaderRetrievalServiceTest.java
new file mode 100644
index 00000000..41897a9b
--- /dev/null
+++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderretrieval/DefaultLeaderRetrievalServiceTest.java
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.leaderretrieval;
+
+import com.alibaba.flink.shuffle.common.functions.RunnableWithException;
+import com.alibaba.flink.shuffle.coordinator.highavailability.DefaultLeaderElectionService;
+import com.alibaba.flink.shuffle.coordinator.highavailability.DefaultLeaderRetrievalService;
+import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation;
+import com.alibaba.flink.shuffle.coordinator.leaderelection.TestingListener;
+import com.alibaba.flink.shuffle.core.utils.TestLogger;
+
+import org.junit.Test;
+
+import java.util.UUID;
+import java.util.concurrent.TimeoutException;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.notNullValue;
+import static org.hamcrest.Matchers.nullValue;
+import static org.junit.Assert.assertSame;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/** Tests for {@link DefaultLeaderRetrievalService}.
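+ *
+ * <p>The service is exercised against a {@link TestingLeaderRetrievalDriver} rather than a real
+ * ZooKeeper backend; a minimal sketch of the pattern every test follows (names as in the
+ * {@code Context} helper below):
+ *
+ * <pre>{@code
+ * leaderRetrievalService.start(testingListener);
+ * testingLeaderRetrievalDriver.onUpdate(new LeaderInformation(UUID.randomUUID(), TEST_URL));
+ * testingListener.waitForNewLeader(timeout);
+ * }</pre>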
+ */
+public class DefaultLeaderRetrievalServiceTest extends TestLogger {
+
+    private static final String TEST_URL = "akka//user/shufflemanager";
+    private static final long timeout = 50L;
+
+    @Test
+    public void testNotifyLeaderAddress() throws Exception {
+        new Context() {
+            {
+                runTest(
+                        () -> {
+                            final LeaderInformation newLeader =
+                                    new LeaderInformation(UUID.randomUUID(), TEST_URL);
+                            testingLeaderRetrievalDriver.onUpdate(newLeader);
+                            testingListener.waitForNewLeader(timeout);
+                            assertThat(
+                                    testingListener.getLeaderSessionID(),
+                                    is(newLeader.getLeaderSessionID()));
+                            assertThat(
+                                    testingListener.getAddress(), is(newLeader.getLeaderAddress()));
+                        });
+            }
+        };
+    }
+
+    @Test
+    public void testNotifyLeaderAddressEmpty() throws Exception {
+        new Context() {
+            {
+                runTest(
+                        () -> {
+                            final LeaderInformation newLeader =
+                                    new LeaderInformation(UUID.randomUUID(), TEST_URL);
+                            testingLeaderRetrievalDriver.onUpdate(newLeader);
+                            testingListener.waitForNewLeader(timeout);
+
+                            testingLeaderRetrievalDriver.onUpdate(LeaderInformation.empty());
+                            testingListener.waitForEmptyLeaderInformation(timeout);
+                            assertSame(testingListener.getLeader(), LeaderInformation.empty());
+                        });
+            }
+        };
+    }
+
+    @Test
+    public void testErrorForwarding() throws Exception {
+        new Context() {
+            {
+                runTest(
+                        () -> {
+                            final Exception testException = new Exception("test exception");
+
+                            testingLeaderRetrievalDriver.onFatalError(testException);
+
+                            testingListener.waitForError(timeout);
+                            assertTrue(
+                                    checkNotNull(testingListener.getError())
+                                            .getMessage()
+                                            .contains("test exception"));
+                        });
+            }
+        };
+    }
+
+    @Test
+    public void testErrorIsIgnoredAfterBeingStopped() throws Exception {
+        new Context() {
+            {
+                runTest(
+                        () -> {
+                            final Exception testException = new Exception("test exception");
+
+                            leaderRetrievalService.stop();
+                            testingLeaderRetrievalDriver.onFatalError(testException);
+
+                            try {
+                                testingListener.waitForError(timeout);
+                                fail(
+                                        "We expect a timeout here because no error should be "
+                                                + "passed to the listener.");
+                            } catch (TimeoutException ex) {
+                                // noop
+                            }
+                            assertThat(testingListener.getError(), is(nullValue()));
+                        });
+            }
+        };
+    }
+
+    @Test
+    public void testNotifyLeaderAddressOnlyWhenLeaderTrulyChanged() throws Exception {
+        new Context() {
+            {
+                runTest(
+                        () -> {
+                            final LeaderInformation newLeader =
+                                    new LeaderInformation(UUID.randomUUID(), TEST_URL);
+                            testingLeaderRetrievalDriver.onUpdate(newLeader);
+                            assertThat(testingListener.getLeaderEventQueueSize(), is(1));
+
+                            // Same leader information should not be notified twice.
+                            testingLeaderRetrievalDriver.onUpdate(newLeader);
+                            assertThat(testingListener.getLeaderEventQueueSize(), is(1));
+
+                            // Leader truly changed.
+ testingLeaderRetrievalDriver.onUpdate( + new LeaderInformation(UUID.randomUUID(), TEST_URL + 1)); + assertThat(testingListener.getLeaderEventQueueSize(), is(2)); + }); + } + }; + } + + private class Context { + private final TestingLeaderRetrievalDriver.TestingLeaderRetrievalDriverFactory + leaderRetrievalDriverFactory = + new TestingLeaderRetrievalDriver.TestingLeaderRetrievalDriverFactory(); + final DefaultLeaderRetrievalService leaderRetrievalService = + new DefaultLeaderRetrievalService(leaderRetrievalDriverFactory); + final TestingListener testingListener = new TestingListener(); + + TestingLeaderRetrievalDriver testingLeaderRetrievalDriver; + + void runTest(RunnableWithException testMethod) throws Exception { + leaderRetrievalService.start(testingListener); + + testingLeaderRetrievalDriver = leaderRetrievalDriverFactory.getCurrentRetrievalDriver(); + assertThat(testingLeaderRetrievalDriver, is(notNullValue())); + testMethod.run(); + + leaderRetrievalService.stop(); + } + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderretrieval/SettableLeaderRetrievalService.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderretrieval/SettableLeaderRetrievalService.java new file mode 100644 index 00000000..3b57aa6b --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderretrieval/SettableLeaderRetrievalService.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.leaderretrieval; + +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalListener; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalService; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** + * {@link LeaderRetrievalService} implementation which directly forwards calls of notifyListener to + * the listener. 
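+ *
+ * <p>Typical usage in a test, a minimal sketch (see {@code SettableLeaderRetrievalServiceTest}
+ * for real examples; {@code listener} and {@code leaderSessionId} below are placeholders for
+ * whatever the test provides):
+ *
+ * <pre>{@code
+ * SettableLeaderRetrievalService service = new SettableLeaderRetrievalService();
+ * service.start(listener);
+ * service.notifyListener(new LeaderInformation(leaderSessionId, "localhost"));
+ * }</pre>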
+ */ +public class SettableLeaderRetrievalService implements LeaderRetrievalService { + + private LeaderInformation leaderInfo; + + private LeaderRetrievalListener listener; + + public SettableLeaderRetrievalService() { + this(LeaderInformation.empty()); + } + + public SettableLeaderRetrievalService(LeaderInformation leaderInfo) { + this.leaderInfo = leaderInfo; + } + + @Override + public synchronized void start(LeaderRetrievalListener listener) throws Exception { + this.listener = checkNotNull(listener); + + if (leaderInfo != LeaderInformation.empty()) { + listener.notifyLeaderAddress(leaderInfo); + } + } + + @Override + public void stop() throws Exception {} + + public synchronized void notifyListener(LeaderInformation leaderInfo) { + this.leaderInfo = leaderInfo; + + if (listener != null) { + listener.notifyLeaderAddress(leaderInfo); + } + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderretrieval/SettableLeaderRetrievalServiceTest.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderretrieval/SettableLeaderRetrievalServiceTest.java new file mode 100644 index 00000000..68b87088 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderretrieval/SettableLeaderRetrievalServiceTest.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.leaderretrieval; + +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.leaderelection.TestingListener; +import com.alibaba.flink.shuffle.core.utils.TestLogger; + +import org.junit.Before; +import org.junit.Test; + +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; + +/** Tests for {@link SettableLeaderRetrievalService}. 
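+ *
+ * <p>The two tests below cover both orderings: the leader being set before the listener is
+ * started, and the listener being started before the leader is set; in both cases the listener
+ * should end up with the same address and leader session ID.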
*/ +public class SettableLeaderRetrievalServiceTest extends TestLogger { + + private SettableLeaderRetrievalService settableLeaderRetrievalService; + + @Before + public void setup() { + settableLeaderRetrievalService = new SettableLeaderRetrievalService(); + } + + @Test + public void testNotifyListenerLater() throws Exception { + final String localhost = "localhost"; + settableLeaderRetrievalService.notifyListener( + new LeaderInformation(HaServices.DEFAULT_LEADER_ID, localhost)); + + final TestingListener listener = new TestingListener(); + settableLeaderRetrievalService.start(listener); + + assertThat(listener.getAddress(), equalTo(localhost)); + assertThat(listener.getLeaderSessionID(), equalTo(HaServices.DEFAULT_LEADER_ID)); + } + + @Test + public void testNotifyListenerImmediately() throws Exception { + final TestingListener listener = new TestingListener(); + settableLeaderRetrievalService.start(listener); + + final String localhost = "localhost"; + settableLeaderRetrievalService.notifyListener( + new LeaderInformation(HaServices.DEFAULT_LEADER_ID, localhost)); + + assertThat(listener.getAddress(), equalTo(localhost)); + assertThat(listener.getLeaderSessionID(), equalTo(HaServices.DEFAULT_LEADER_ID)); + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderretrieval/TestingLeaderRetrievalDriver.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderretrieval/TestingLeaderRetrievalDriver.java new file mode 100644 index 00000000..48714bb8 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderretrieval/TestingLeaderRetrievalDriver.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.leaderretrieval; + +import com.alibaba.flink.shuffle.common.handler.FatalErrorHandler; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalDriver; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalDriverFactory; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalEventHandler; + +import javax.annotation.Nullable; + +/** + * {@link LeaderRetrievalDriver} implementation which provides some convenience functions for + * testing purposes. 
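+ *
+ * <p>A test typically hands the nested factory to the service under test and then injects events
+ * by hand; a minimal sketch of that flow (assuming a service built around a
+ * {@link LeaderRetrievalDriverFactory}):
+ *
+ * <pre>{@code
+ * TestingLeaderRetrievalDriverFactory factory = new TestingLeaderRetrievalDriverFactory();
+ * // ... start the service that was created with this factory, then:
+ * TestingLeaderRetrievalDriver driver = factory.getCurrentRetrievalDriver();
+ * driver.onUpdate(new LeaderInformation(UUID.randomUUID(), "some leader address"));
+ * }</pre>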
+ */
+public class TestingLeaderRetrievalDriver implements LeaderRetrievalDriver {
+
+    private final LeaderRetrievalEventHandler leaderRetrievalEventHandler;
+    private final FatalErrorHandler fatalErrorHandler;
+
+    private TestingLeaderRetrievalDriver(
+            LeaderRetrievalEventHandler leaderRetrievalEventHandler,
+            FatalErrorHandler fatalErrorHandler) {
+        this.leaderRetrievalEventHandler = leaderRetrievalEventHandler;
+        this.fatalErrorHandler = fatalErrorHandler;
+    }
+
+    @Override
+    public void close() throws Exception {
+        // noop
+    }
+
+    public void onUpdate(LeaderInformation newLeader) {
+        leaderRetrievalEventHandler.notifyLeaderAddress(newLeader);
+    }
+
+    public void onFatalError(Throwable throwable) {
+        fatalErrorHandler.onFatalError(throwable);
+    }
+
+    /** Factory for creating {@link TestingLeaderRetrievalDriver} instances. */
+    public static class TestingLeaderRetrievalDriverFactory
+            implements LeaderRetrievalDriverFactory {
+
+        private TestingLeaderRetrievalDriver currentDriver;
+
+        @Override
+        public LeaderRetrievalDriver createLeaderRetrievalDriver(
+                LeaderRetrievalEventHandler leaderEventHandler,
+                FatalErrorHandler fatalErrorHandler) {
+            currentDriver = new TestingLeaderRetrievalDriver(leaderEventHandler, fatalErrorHandler);
+            return currentDriver;
+        }
+
+        @Nullable
+        public TestingLeaderRetrievalDriver getCurrentRetrievalDriver() {
+            return currentDriver;
+        }
+    }
+}
diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderretrieval/TestingLeaderRetrievalEventHandler.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderretrieval/TestingLeaderRetrievalEventHandler.java
new file mode 100644
index 00000000..f8c96032
--- /dev/null
+++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/leaderretrieval/TestingLeaderRetrievalEventHandler.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.leaderretrieval;
+
+import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation;
+import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalEventHandler;
+import com.alibaba.flink.shuffle.coordinator.leaderelection.TestingRetrievalBase;
+
+/**
+ * Test {@link LeaderRetrievalEventHandler} implementation which offers some convenience functions
+ * for testing purposes.
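+ *
+ * <p>Every {@code notifyLeaderAddress} call is funneled into the queue maintained by {@code
+ * TestingRetrievalBase}, so a test can simply block on the next leader change via {@code
+ * waitForNewLeader(timeout)}.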
+ */ +public class TestingLeaderRetrievalEventHandler extends TestingRetrievalBase + implements LeaderRetrievalEventHandler { + + @Override + public void notifyLeaderAddress(LeaderInformation leaderInfo) { + offerToLeaderQueue(leaderInfo); + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleManagerHATest.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleManagerHATest.java new file mode 100644 index 00000000..85022675 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleManagerHATest.java @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.manager; + +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionService; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalService; +import com.alibaba.flink.shuffle.coordinator.leaderelection.TestingLeaderElectionService; +import com.alibaba.flink.shuffle.coordinator.manager.assignmenttracker.AssignmentTracker; +import com.alibaba.flink.shuffle.coordinator.manager.assignmenttracker.ChangedWorkerStatus; +import com.alibaba.flink.shuffle.coordinator.utils.TestingFatalErrorHandler; +import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerGateway; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.ids.RegistrationID; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcService; +import com.alibaba.flink.shuffle.rpc.test.TestingRpcService; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ForkJoinPool; + +/** Tests for the ShuffleManager HA. 
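+ *
+ * <p>Leadership callbacks are driven by hand through a {@link TestingLeaderElectionService}; a
+ * minimal sketch of the grant/revoke cycle exercised below:
+ *
+ * <pre>{@code
+ * leaderElectionService.isLeader(UUID.randomUUID()); // grant leadership
+ * leaderElectionService.notLeader(); // revoke it again
+ * }</pre>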
*/ +public class ShuffleManagerHATest { + + @Test + public void testGrantAndRevokeLeadership() throws Exception { + final InstanceID rmInstanceID = new InstanceID(); + final RemoteShuffleRpcService rpcService = new TestingRpcService(); + + final CompletableFuture leaderSessionIdFuture = new CompletableFuture<>(); + + final TestingLeaderElectionService leaderElectionService = + new TestingLeaderElectionService() { + @Override + public void confirmLeadership(LeaderInformation leaderInfo) { + leaderSessionIdFuture.complete(leaderInfo.getLeaderSessionID()); + } + }; + + final HaServices highAvailabilityServices = + new TestHaService() { + @Override + public LeaderElectionService createLeaderElectionService() { + return leaderElectionService; + } + }; + + final TestingFatalErrorHandler testingFatalErrorHandler = new TestingFatalErrorHandler(); + + final CompletableFuture revokedLeaderIdFuture = new CompletableFuture<>(); + + final ShuffleManager shuffleManager = + new ShuffleManager( + rpcService, + rmInstanceID, + highAvailabilityServices, + testingFatalErrorHandler, + ForkJoinPool.commonPool(), + new HeartbeatServices(100, 200), + new HeartbeatServices(100, 200), + new TestAssignmentTracker()) { + + @Override + public void revokeLeadership() { + super.revokeLeadership(); + runAsyncWithoutFencing( + () -> revokedLeaderIdFuture.complete(getFencingToken())); + } + }; + + try { + shuffleManager.start(); + + Assert.assertNull(shuffleManager.getFencingToken()); + final UUID leaderId = UUID.randomUUID(); + leaderElectionService.isLeader(leaderId); + // after grant leadership, ShuffleManager's leaderId has value + Assert.assertEquals(leaderId, leaderSessionIdFuture.get()); + // then revoke leadership, ShuffleManager's leaderId should be different + leaderElectionService.notLeader(); + Assert.assertNotEquals(leaderId, revokedLeaderIdFuture.get()); + + if (testingFatalErrorHandler.hasExceptionOccurred()) { + testingFatalErrorHandler.rethrowError(); + } + } finally { + rpcService.stopService().get(); + } + } + + private static class TestHaService implements HaServices { + + @Override + public LeaderRetrievalService createLeaderRetrievalService(LeaderReceptor receptor) { + return null; + } + + @Override + public LeaderElectionService createLeaderElectionService() { + return null; + } + + @Override + public void closeAndCleanupAllData() throws Exception {} + + @Override + public void close() throws Exception {} + } + + private static class TestAssignmentTracker implements AssignmentTracker { + + @Override + public boolean isWorkerRegistered(RegistrationID registrationID) { + return false; + } + + @Override + public void registerWorker( + InstanceID workerID, + RegistrationID registrationID, + ShuffleWorkerGateway gateway, + String externalAddress, + int dataPort) {} + + @Override + public void workerReportDataPartitionReleased( + RegistrationID registrationID, + JobID jobID, + DataSetID dataSetID, + DataPartitionID dataPartitionID) {} + + @Override + public void synchronizeWorkerDataPartitions( + RegistrationID registrationID, List dataPartitionStatuses) {} + + @Override + public void unregisterWorker(RegistrationID registrationID) {} + + @Override + public boolean isJobRegistered(JobID jobID) { + return false; + } + + @Override + public void registerJob(JobID jobID) {} + + @Override + public void unregisterJob(JobID jobID) {} + + @Override + public ShuffleResource requestShuffleResource( + JobID jobID, + DataSetID dataSetID, + MapPartitionID mapPartitionID, + int numberOfConsumers, + String 
dataPartitionFactoryName) { + return null; + } + + @Override + public void releaseShuffleResource( + JobID jobID, DataSetID dataSetID, MapPartitionID mapPartitionID) {} + + @Override + public ChangedWorkerStatus computeChangedWorkers( + JobID jobID, + Collection cachedWorkerList, + boolean considerUnrelatedWorkers) { + return null; + } + + @Override + public List listJobs() { + return null; + } + + @Override + public int getNumberOfWorkers() { + return 0; + } + + @Override + public Map getDataPartitionDistribution(JobID jobID) { + return null; + } + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleManagerTest.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleManagerTest.java new file mode 100644 index 00000000..4084ecff --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/manager/ShuffleManagerTest.java @@ -0,0 +1,687 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.coordinator.manager; + +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatManagerImpl; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatServices; +import com.alibaba.flink.shuffle.coordinator.heartbeat.NoOpHeartbeatManager; +import com.alibaba.flink.shuffle.coordinator.highavailability.TestingHighAvailabilityServices; +import com.alibaba.flink.shuffle.coordinator.leaderelection.TestingLeaderElectionService; +import com.alibaba.flink.shuffle.coordinator.manager.assignmenttracker.AssignmentTrackerImpl; +import com.alibaba.flink.shuffle.coordinator.registration.RegistrationResponse; +import com.alibaba.flink.shuffle.coordinator.utils.RandomIDUtils; +import com.alibaba.flink.shuffle.coordinator.utils.TestingFatalErrorHandler; +import com.alibaba.flink.shuffle.coordinator.utils.TestingUtils; +import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerGateway; +import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerMetrics; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.ids.RegistrationID; +import com.alibaba.flink.shuffle.core.utils.OneShotLatch; +import com.alibaba.flink.shuffle.core.utils.TestLogger; +import com.alibaba.flink.shuffle.rpc.message.Acknowledge; +import com.alibaba.flink.shuffle.rpc.test.TestingRpcService; + +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** Test ShuffleManager. 
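+ *
+ * <p>Most tests follow the same skeleton: register a worker or client against the gateway and
+ * then assert on the {@code AssignmentTrackerImpl} state; a minimal sketch, using the names
+ * defined in this class:
+ *
+ * <pre>{@code
+ * RegistrationResponse response =
+ *         shuffleManagerGateway
+ *                 .registerWorker(TestShuffleWorkerGateway.createShuffleWorkerRegistration())
+ *                 .get(TIMEOUT, TimeUnit.MILLISECONDS);
+ * assertTrue(response instanceof ShuffleWorkerRegistrationSuccess);
+ * }</pre>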
+ */
+public class ShuffleManagerTest extends TestLogger {
+
+    private static final String partitionFactoryName =
+            "com.alibaba.flink.shuffle.storage.partition.LocalFileMapPartitionFactory";
+
+    private static final long TIMEOUT = 10000L;
+
+    private static final long HEARTBEAT_TIMEOUT = 5000;
+
+    private static TestingRpcService rpcService;
+
+    private TestingLeaderElectionService leaderElectionService;
+
+    private ShuffleManager shuffleManager;
+
+    private TestingFatalErrorHandler testingFatalErrorHandler;
+
+    private ShuffleManagerGateway shuffleManagerGateway;
+
+    private ShuffleManagerGateway wronglyFencedGateway;
+
+    private AssignmentTrackerImpl testAssignmentTracker;
+
+    @BeforeClass
+    public static void setupClass() {
+        rpcService = new TestingRpcService();
+    }
+
+    @AfterClass
+    public static void teardownClass() throws Exception {
+        if (rpcService != null) {
+            rpcService.stopService().get(TIMEOUT, TimeUnit.MILLISECONDS);
+        }
+    }
+
+    @Before
+    public void setup() throws Exception {
+        shuffleManagerGateway = initializeServicesAndShuffleManagerGateway();
+        createShuffleWorkerGateway();
+    }
+
+    @After
+    public void teardown() throws Exception {
+        if (shuffleManager != null) {
+            shuffleManager.closeAsync().get(TIMEOUT, TimeUnit.MILLISECONDS);
+        }
+
+        if (testingFatalErrorHandler != null && testingFatalErrorHandler.hasExceptionOccurred()) {
+            testingFatalErrorHandler.rethrowError();
+        }
+
+        rpcService.clearGateways();
+    }
+
+    @Test
+    public void testRegisterShuffleWorker()
+            throws InterruptedException, ExecutionException, TimeoutException {
+
+        assertEquals(0, testAssignmentTracker.getWorkers().size());
+        CompletableFuture<RegistrationResponse> successfulFuture =
+                shuffleManagerGateway.registerWorker(
+                        TestShuffleWorkerGateway.createShuffleWorkerRegistration());
+        RegistrationResponse response = successfulFuture.get(TIMEOUT, TimeUnit.MILLISECONDS);
+        assertTrue(response instanceof ShuffleWorkerRegistrationSuccess);
+        assertEquals(
+                Collections.singleton(
+                        ((ShuffleWorkerRegistrationSuccess) response).getRegistrationID()),
+                testAssignmentTracker.getWorkers().keySet());
+
+        // A duplicate registration from the shuffle worker should also succeed, but with a
+        // registration ID different from the previous one.
+        CompletableFuture<RegistrationResponse> duplicateFuture =
+                shuffleManagerGateway.registerWorker(
+                        TestShuffleWorkerGateway.createShuffleWorkerRegistration());
+
+        RegistrationResponse duplicateResponse = duplicateFuture.get();
+        assertTrue(duplicateResponse instanceof ShuffleWorkerRegistrationSuccess);
+        assertNotEquals(
+                ((ShuffleWorkerRegistrationSuccess) response).getRegistrationID(),
+                ((ShuffleWorkerRegistrationSuccess) duplicateResponse).getRegistrationID());
+
+        assertEquals(
+                Collections.singleton(
+                        ((ShuffleWorkerRegistrationSuccess) duplicateResponse).getRegistrationID()),
+                testAssignmentTracker.getWorkers().keySet());
+    }
+
+    @Test(timeout = 20000)
+    public void testRevokeAndGrantLeadership() throws Exception {
+        assertNotEquals(
+                NoOpHeartbeatManager.class, shuffleManager.getWorkerHeartbeatManager().getClass());
+        assertNotEquals(
+                NoOpHeartbeatManager.class, shuffleManager.getJobHeartbeatManager().getClass());
+
+        ShuffleWorkerRegistration registration =
+                TestShuffleWorkerGateway.createShuffleWorkerRegistration();
+        CompletableFuture<RegistrationResponse> successfulFuture =
+                shuffleManagerGateway.registerWorker(registration);
+        RegistrationResponse response = successfulFuture.get(TIMEOUT, TimeUnit.MILLISECONDS);
+        assertTrue(response instanceof ShuffleWorkerRegistrationSuccess);
+
+        // The initial state
+        assertEquals(
Collections.singleton(registration.getWorkerID()), + shuffleManager.getShuffleWorkers().keySet()); + assertEquals( + Collections.singleton( + ((ShuffleWorkerRegistrationSuccess) successfulFuture.get()) + .getRegistrationID()), + testAssignmentTracker.getWorkers().keySet()); + assertTrue( + ((HeartbeatManagerImpl) shuffleManager.getWorkerHeartbeatManager()) + .getHeartbeatTargets() + .containsKey(registration.getWorkerID())); + + leaderElectionService.notLeader(); + grantLeadership(leaderElectionService); + + // Right after re-grant leadership + assertEquals( + Collections.singleton(registration.getWorkerID()), + shuffleManager.getShuffleWorkers().keySet()); + assertEquals( + Collections.singleton( + ((ShuffleWorkerRegistrationSuccess) successfulFuture.get()) + .getRegistrationID()), + testAssignmentTracker.getWorkers().keySet()); + assertTrue( + ((HeartbeatManagerImpl) shuffleManager.getWorkerHeartbeatManager()) + .getHeartbeatTargets() + .containsKey(registration.getWorkerID())); + + // The shuffle worker would then come to disconnect itself. + shuffleManager.disconnectWorker(registration.getWorkerID(), new RuntimeException("Test")); + + // After the worker disconnect + assertEquals(0, shuffleManager.getShuffleWorkers().size()); + assertEquals( + 0, + ((AssignmentTrackerImpl) shuffleManager.getAssignmentTracker()) + .getWorkers() + .size()); + assertFalse( + ((HeartbeatManagerImpl) shuffleManager.getWorkerHeartbeatManager()) + .getHeartbeatTargets() + .containsKey(registration.getWorkerID())); + } + + /** + * Tests delayed registration of shuffle worker where the delay is introduced during connection + * from shuffle manager to the registering shuffle worker. + */ + @Test + public void testDelayedRegisterShuffleWorker() throws Exception { + try { + final OneShotLatch startConnection = new OneShotLatch(); + final OneShotLatch finishConnection = new OneShotLatch(); + + // first registration is with blocking connection + rpcService.setRpcGatewayFutureFunction( + rpcGateway -> + CompletableFuture.supplyAsync( + () -> { + startConnection.trigger(); + try { + finishConnection.await(); + } catch (InterruptedException ignored) { + } + return rpcGateway; + }, + TestingUtils.defaultScheduledExecutor())); + + CompletableFuture firstFuture = + shuffleManagerGateway.registerWorker( + TestShuffleWorkerGateway.createShuffleWorkerRegistration()); + try { + firstFuture.get(); + fail( + "Should have failed because connection to shuffle worker is delayed beyond timeout"); + } catch (Exception e) { + final Throwable cause = ExceptionUtils.stripException(e, ExecutionException.class); + assertTrue(cause instanceof TimeoutException); + assertTrue(cause.getMessage().contains("ShuffleManagerGateway.registerWorker")); + } + + startConnection.await(); + + // second registration after timeout is with no delay, expecting it to be succeeded + rpcService.resetRpcGatewayFutureFunction(); + CompletableFuture secondFuture = + shuffleManagerGateway.registerWorker( + TestShuffleWorkerGateway.createShuffleWorkerRegistration()); + RegistrationResponse response = secondFuture.get(); + assertTrue(response instanceof ShuffleWorkerRegistrationSuccess); + + // on success, send data partition report for shuffle manager + + shuffleManagerGateway + .reportDataPartitionStatus( + TestShuffleWorkerGateway.getShuffleWorkerID(), + ((ShuffleWorkerRegistrationSuccess) response).getRegistrationID(), + Collections.singletonList(createDataPartitionStatus())) + .get(); + + // let the remaining part of the first registration proceed + 
finishConnection.trigger();
+            Thread.sleep(1L);
+
+            // verify that the latest registration is still valid and has not been unregistered
+            // by the delayed one
+            assertEquals(1, testAssignmentTracker.getWorkers().size());
+        } finally {
+            rpcService.resetRpcGatewayFutureFunction();
+        }
+    }
+
+    /** Tests that a shuffle worker can disconnect from the shuffle manager. */
+    @Test
+    public void testDisconnectShuffleWorker() throws Exception {
+        CompletableFuture<RegistrationResponse> successfulFuture =
+                shuffleManagerGateway.registerWorker(
+                        TestShuffleWorkerGateway.createShuffleWorkerRegistration());
+
+        RegistrationResponse response = successfulFuture.get(TIMEOUT, TimeUnit.MILLISECONDS);
+        assertTrue(response instanceof ShuffleWorkerRegistrationSuccess);
+        assertEquals(1, testAssignmentTracker.getWorkers().size());
+
+        shuffleManagerGateway
+                .disconnectWorker(
+                        TestShuffleWorkerGateway.getShuffleWorkerID(),
+                        new RuntimeException("testDisconnectShuffleWorker"))
+                .get();
+
+        assertEquals(0, testAssignmentTracker.getWorkers().size());
+    }
+
+    /**
+     * Tests receiving a registration with an unmatched leader session id from a shuffle worker.
+     */
+    @Test
+    public void testRegisterShuffleWorkerWithUnmatchedLeaderSessionId() throws Exception {
+        // registering through a wrongly fenced gateway, i.e. with an unmatched leaderSessionId,
+        // should fail with a fencing token mismatch
+        CompletableFuture<RegistrationResponse> unMatchedLeaderFuture =
+                wronglyFencedGateway.registerWorker(
+                        TestShuffleWorkerGateway.createShuffleWorkerRegistration());
+
+        try {
+            unMatchedLeaderFuture.get(TIMEOUT, TimeUnit.MILLISECONDS);
+            fail("Should have failed because we are using a wrongly fenced ShuffleManagerGateway.");
+        } catch (ExecutionException e) {
+            assertTrue(
+                    ExceptionUtils.stripException(e, ExecutionException.class)
+                            .getMessage()
+                            .contains("Fencing token mismatch"));
+        }
+    }
+
+    /** Tests receiving a registration with an invalid address from a shuffle worker. */
+    @Test
+    public void testRegisterShuffleWorkerFromInvalidAddress() throws Exception {
+        // a registration from a shuffle worker with an invalid address should be declined
+        String invalidAddress = "/shuffleworker2";
+
+        CompletableFuture<RegistrationResponse> invalidAddressFuture =
+                shuffleManagerGateway.registerWorker(
+                        TestShuffleWorkerGateway.createShuffleWorkerRegistration(invalidAddress));
+        assertTrue(
+                invalidAddressFuture.get(TIMEOUT, TimeUnit.MILLISECONDS)
+                        instanceof RegistrationResponse.Decline);
+    }
+
+    @Test(timeout = 30000L)
+    public void testShuffleClientRegisterAndUnregister()
+            throws ExecutionException, InterruptedException {
+        assertEquals(0, testAssignmentTracker.getJobs().size());
+        final JobID jobID = RandomIDUtils.randomJobId();
+        final InstanceID instanceId = new InstanceID();
+        shuffleManagerGateway.registerClient(jobID, instanceId).get();
+        assertEquals(1, testAssignmentTracker.getJobs().size());
+
+        assertNotNull(testAssignmentTracker.getJobs().get(jobID));
+
+        shuffleManagerGateway.unregisterClient(jobID, instanceId).get();
+        assertEquals(Collections.singleton(jobID), testAssignmentTracker.getJobs().keySet());
+
+        // Here we remove the instance id mapping.
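+        // (unregisterClient drops the client-to-instance-id mapping right away, as asserted just
+        // below; the job entry itself is only released from the assignment tracker once the
+        // client heartbeat times out, which the polling loop at the end of this test waits for)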
+ assertEquals(0, shuffleManager.getRegisteredClients().size()); + + // Wait till the client finally get unregistered via timeout + while (shuffleManager.getAssignmentTracker().isJobRegistered(jobID)) { + Thread.sleep(1000); + } + } + + @Test + public void testShuffleClientRegisterAndUnregisterAndReconnect() + throws ExecutionException, InterruptedException { + assertEquals(0, testAssignmentTracker.getJobs().size()); + + // A hacky way to register a worker + shuffleManager + .getAssignmentTracker() + .registerWorker( + new InstanceID("worker1"), + new RegistrationID(), + new TestShuffleWorkerGateway(), + "localhost", + 10240); + + final JobID jobID = RandomIDUtils.randomJobId(); + final InstanceID instanceId = new InstanceID(); + shuffleManagerGateway.registerClient(jobID, instanceId).get(); + assertEquals(1, testAssignmentTracker.getJobs().size()); + + assertNotNull(testAssignmentTracker.getJobs().get(jobID)); + DataSetID dataSetID = RandomIDUtils.randomDataSetId(); + MapPartitionID dataPartitionID = RandomIDUtils.randomMapPartitionId(); + shuffleManagerGateway + .requestShuffleResource( + jobID, instanceId, dataSetID, dataPartitionID, 1, partitionFactoryName) + .get(); + + shuffleManagerGateway.unregisterClient(jobID, instanceId).get(); + assertEquals(Collections.singleton(jobID), testAssignmentTracker.getJobs().keySet()); + + InstanceID instanceID2 = new InstanceID(); + shuffleManager.registerClient(jobID, instanceID2).get(); + ManagerToJobHeartbeatPayload payload = + shuffleManager + .heartbeatFromClient(jobID, instanceID2, Collections.emptySet()) + .get(); + assertEquals( + Collections.singleton(new InstanceID("worker1")), + payload.getJobChangedWorkerStatus().getRelevantWorkers().keySet()); + } + + @Test + public void testAllocateShuffleResource() + throws InterruptedException, ExecutionException, TimeoutException { + // register shuffle worker + CompletableFuture successfulFuture = + shuffleManagerGateway.registerWorker( + TestShuffleWorkerGateway.createShuffleWorkerRegistration()); + RegistrationResponse response = successfulFuture.get(TIMEOUT, TimeUnit.MILLISECONDS); + assertTrue(response instanceof ShuffleWorkerRegistrationSuccess); + + // register shuffle client + final JobID jobID = RandomIDUtils.randomJobId(); + final InstanceID instanceID = new InstanceID(); + final DataSetID dataSetID = RandomIDUtils.randomDataSetId(); + final MapPartitionID mapPartitionID = RandomIDUtils.randomMapPartitionId(); + shuffleManagerGateway.registerClient(jobID, instanceID).get(); + + // allocate shuffle resource + ShuffleResource shuffleResource = + shuffleManagerGateway + .requestShuffleResource( + jobID, + instanceID, + dataSetID, + mapPartitionID, + 128, + partitionFactoryName) + .get(); + assertTrue(shuffleResource instanceof DefaultShuffleResource); + DefaultShuffleResource defaultShuffleResource = (DefaultShuffleResource) shuffleResource; + assertEquals( + TestShuffleWorkerGateway.getShuffleWorkerID(), + defaultShuffleResource.getMapPartitionLocation().getWorkerId()); + assertNotNull(testAssignmentTracker.getJobs().get(jobID)); + assertEquals(1, testAssignmentTracker.getJobs().get(jobID).getDataPartitions().size()); + + // release allocated resources + + shuffleManagerGateway + .releaseShuffleResource(jobID, instanceID, dataSetID, mapPartitionID) + .get(); + assertEquals(0, testAssignmentTracker.getJobs().get(jobID).getDataPartitions().size()); + } + + @Test + public void testInconsistentInstanceId() throws ExecutionException, InterruptedException { + // register shuffle client + final 
JobID jobID = RandomIDUtils.randomJobId(); + final InstanceID instanceID = new InstanceID(); + shuffleManagerGateway.registerClient(jobID, instanceID).get(); + + { + CompletableFuture result = + shuffleManagerGateway.heartbeatFromClient( + jobID, new InstanceID(), Collections.emptySet()); + assertFailedWithInconsistentInstanceId(result); + } + + { + CompletableFuture result = + shuffleManagerGateway.requestShuffleResource( + jobID, + new InstanceID(), + RandomIDUtils.randomDataSetId(), + RandomIDUtils.randomMapPartitionId(), + 2, + partitionFactoryName); + assertFailedWithInconsistentInstanceId(result); + } + + { + CompletableFuture result = + shuffleManagerGateway.releaseShuffleResource( + jobID, + new InstanceID(), + RandomIDUtils.randomDataSetId(), + RandomIDUtils.randomMapPartitionId()); + assertFailedWithInconsistentInstanceId(result); + } + + { + CompletableFuture result = + shuffleManagerGateway.unregisterClient(jobID, new InstanceID()); + assertFailedWithInconsistentInstanceId(result); + } + } + + @Test + public void testShuffleWorkerReportDataPartitionWithoutPendingJob() throws Exception { + ShuffleWorkerRegistration worker = + TestShuffleWorkerGateway.createShuffleWorkerRegistration(); + ShuffleWorkerRegistrationSuccess response = + (ShuffleWorkerRegistrationSuccess) shuffleManager.registerWorker(worker).get(); + + DataPartitionStatus dataPartitionStatus = + new DataPartitionStatus( + RandomIDUtils.randomJobId(), + new DataPartitionCoordinate( + RandomIDUtils.randomDataSetId(), + RandomIDUtils.randomMapPartitionId())); + DataPartitionStatus releasingDataPartitionStatus = + new DataPartitionStatus( + RandomIDUtils.randomJobId(), + new DataPartitionCoordinate( + RandomIDUtils.randomDataSetId(), + RandomIDUtils.randomMapPartitionId())); + + shuffleManager + .reportDataPartitionStatus( + worker.getWorkerID(), + response.getRegistrationID(), + Arrays.asList(dataPartitionStatus, releasingDataPartitionStatus)) + .get(); + + assertTrue( + shuffleManager + .getAssignmentTracker() + .isJobRegistered(dataPartitionStatus.getJobId())); + assertTrue( + shuffleManager + .getAssignmentTracker() + .isJobRegistered(releasingDataPartitionStatus.getJobId())); + } + + @Test + public void testShuffleWorkerHeartbeatWithoutPendingJobs() throws Exception { + ShuffleWorkerRegistration worker = + TestShuffleWorkerGateway.createShuffleWorkerRegistration(); + ShuffleWorkerRegistrationSuccess response = + (ShuffleWorkerRegistrationSuccess) shuffleManager.registerWorker(worker).get(); + + DataPartitionStatus dataPartitionStatus = + new DataPartitionStatus( + RandomIDUtils.randomJobId(), + new DataPartitionCoordinate( + RandomIDUtils.randomDataSetId(), + RandomIDUtils.randomMapPartitionId())); + DataPartitionStatus releasingDataPartitionStatus = + new DataPartitionStatus( + RandomIDUtils.randomJobId(), + new DataPartitionCoordinate( + RandomIDUtils.randomDataSetId(), + RandomIDUtils.randomMapPartitionId())); + + shuffleManager.heartbeatFromWorker( + worker.getWorkerID(), + new WorkerToManagerHeartbeatPayload( + Arrays.asList(dataPartitionStatus, releasingDataPartitionStatus))); + + assertTrue( + shuffleManager + .getAssignmentTracker() + .isJobRegistered(dataPartitionStatus.getJobId())); + assertTrue( + shuffleManager + .getAssignmentTracker() + .isJobRegistered(releasingDataPartitionStatus.getJobId())); + } + + private void assertFailedWithInconsistentInstanceId(CompletableFuture result) { + try { + result.get(); + } catch (Exception e) { + assertTrue(ExceptionUtils.findThrowable(e, 
IllegalStateException.class).isPresent()); + } + } + + private ShuffleWorkerGateway createShuffleWorkerGateway() { + final ShuffleWorkerGateway shuffleWorkerGateway = new TestShuffleWorkerGateway(); + rpcService.registerGateway(shuffleWorkerGateway.getAddress(), shuffleWorkerGateway); + return shuffleWorkerGateway; + } + + private ShuffleManagerGateway initializeServicesAndShuffleManagerGateway() + throws InterruptedException, ExecutionException, TimeoutException { + testingFatalErrorHandler = new TestingFatalErrorHandler(); + InstanceID shuffleManagerInstanceID = new InstanceID(); + testAssignmentTracker = new AssignmentTrackerImpl(); + + leaderElectionService = new TestingLeaderElectionService(); + final TestingHighAvailabilityServices highAvailabilityServices = + new TestingHighAvailabilityServices(); + highAvailabilityServices.setShuffleManagerLeaderElectionService(leaderElectionService); + final HeartbeatServices heartbeatServices = new HeartbeatServices(1000L, HEARTBEAT_TIMEOUT); + shuffleManager = + new ShuffleManager( + rpcService, + shuffleManagerInstanceID, + highAvailabilityServices, + testingFatalErrorHandler, + ForkJoinPool.commonPool(), + heartbeatServices, + heartbeatServices, + testAssignmentTracker); + + shuffleManager.start(); + + wronglyFencedGateway = + rpcService + .connectTo( + shuffleManager.getAddress(), + UUID.randomUUID(), + ShuffleManagerGateway.class) + .get(TIMEOUT, TimeUnit.MILLISECONDS); + grantLeadership(leaderElectionService).get(TIMEOUT, TimeUnit.MILLISECONDS); + return shuffleManager.getSelfGateway(ShuffleManagerGateway.class); + } + + private CompletableFuture grantLeadership( + TestingLeaderElectionService leaderElectionService) { + UUID leaderSessionId = UUID.randomUUID(); + return leaderElectionService.isLeader(leaderSessionId); + } + + private DataPartitionStatus createDataPartitionStatus() { + final DataPartitionCoordinate dataPartitionCoordinate = + new DataPartitionCoordinate( + RandomIDUtils.randomDataSetId(), RandomIDUtils.randomMapPartitionId()); + final DataPartitionStatus dataPartitionStatus = + new DataPartitionStatus(RandomIDUtils.randomJobId(), dataPartitionCoordinate); + return dataPartitionStatus; + } + + private static class TestShuffleWorkerGateway implements ShuffleWorkerGateway { + + private static final String rpcAddress = "foobar:1234"; + + private static final String hostName = "foobar"; + + private static final int dataPort = 1234; + + private static final InstanceID shuffleWorkerID = new InstanceID(); + + private static final int processId = 12345; + + static ShuffleWorkerRegistration createShuffleWorkerRegistration() { + return new ShuffleWorkerRegistration( + rpcAddress, hostName, shuffleWorkerID, dataPort, processId); + } + + static ShuffleWorkerRegistration createShuffleWorkerRegistration(final String rpcAddress) { + return new ShuffleWorkerRegistration( + rpcAddress, hostName, shuffleWorkerID, dataPort, processId); + } + + static InstanceID getShuffleWorkerID() { + return shuffleWorkerID; + } + + @Override + public void heartbeatFromManager(InstanceID managerID) {} + + @Override + public void disconnectManager(Exception cause) {} + + @Override + public CompletableFuture releaseDataPartition( + JobID jobID, DataSetID dataSetID, DataPartitionID dataPartitionID) { + return null; + } + + @Override + public CompletableFuture removeReleasedDataPartitionMeta( + JobID jobID, DataSetID dataSetID, DataPartitionID dataPartitionID) { + return null; + } + + @Override + public CompletableFuture getWorkerMetrics() { + return 
CompletableFuture.completedFuture(null); + } + + @Override + public String getAddress() { + return rpcAddress; + } + + @Override + public String getHostname() { + return null; + } + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/manager/assignmenttracker/AssignmentTrackerTest.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/manager/assignmenttracker/AssignmentTrackerTest.java new file mode 100644 index 00000000..0e0a981d --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/manager/assignmenttracker/AssignmentTrackerTest.java @@ -0,0 +1,796 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.manager.assignmenttracker; + +import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionCoordinate; +import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionStatus; +import com.alibaba.flink.shuffle.coordinator.manager.DefaultShuffleResource; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleResource; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleWorkerDescriptor; +import com.alibaba.flink.shuffle.coordinator.utils.EmptyShuffleWorkerGateway; +import com.alibaba.flink.shuffle.coordinator.utils.RandomIDUtils; +import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerGateway; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.ids.RegistrationID; +import com.alibaba.flink.shuffle.core.storage.DataPartition; +import com.alibaba.flink.shuffle.rpc.message.Acknowledge; +import com.alibaba.flink.shuffle.storage.partition.HDDOnlyLocalFileMapPartitionFactory; +import com.alibaba.flink.shuffle.storage.partition.SSDOnlyLocalFileMapPartitionFactory; + +import org.apache.commons.lang3.tuple.Triple; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; + +import static com.alibaba.flink.shuffle.coordinator.utils.RandomIDUtils.randomDataSetId; +import static com.alibaba.flink.shuffle.coordinator.utils.RandomIDUtils.randomJobId; +import static com.alibaba.flink.shuffle.coordinator.utils.RandomIDUtils.randomMapPartitionId; +import static org.hamcrest.MatcherAssert.assertThat; +import static 
org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** Tests the assignment tracker implementation by {@link AssignmentTrackerImpl}. */ +public class AssignmentTrackerTest { + private static final Logger LOG = LoggerFactory.getLogger(AssignmentTrackerTest.class); + + private static final String partitionFactory = + "com.alibaba.flink.shuffle.storage.partition.LocalFileMapPartitionFactory"; + + @Test + public void testWorkerRegistration() { + RegistrationID registrationID = new RegistrationID(); + + AssignmentTracker assignmentTracker = new AssignmentTrackerImpl(); + assertFalse(assignmentTracker.isWorkerRegistered(registrationID)); + + assignmentTracker.registerWorker( + new InstanceID("test"), registrationID, new EmptyShuffleWorkerGateway(), "", 1024); + assertTrue(assignmentTracker.isWorkerRegistered(registrationID)); + } + + @Test + public void testJobRegistration() { + JobID jobId = randomJobId(); + + AssignmentTracker assignmentTracker = new AssignmentTrackerImpl(); + assertFalse(assignmentTracker.isJobRegistered(jobId)); + + assignmentTracker.registerJob(jobId); + assertTrue(assignmentTracker.isJobRegistered(jobId)); + } + + @Test(expected = ShuffleResourceAllocationException.class) + public void testRequestResourceWithoutWorker() throws Exception { + JobID jobId = randomJobId(); + + AssignmentTracker assignmentTracker = new AssignmentTrackerImpl(); + assignmentTracker.registerJob(jobId); + + assignmentTracker.requestShuffleResource( + jobId, randomDataSetId(), randomMapPartitionId(), 3, partitionFactory); + } + + @Test + public void testRequestResourceWithWorkers() throws Exception { + JobID jobId = randomJobId(); + + AssignmentTracker assignmentTracker = new AssignmentTrackerImpl(); + assignmentTracker.registerJob(jobId); + + // Registers two workers + RegistrationID worker1 = new RegistrationID(); + RegistrationID worker2 = new RegistrationID(); + + assignmentTracker.registerWorker( + new InstanceID("worker1"), + worker1, + new EmptyShuffleWorkerGateway(), + "worker1", + 1024); + assignmentTracker.registerWorker( + new InstanceID("worker2"), + worker2, + new EmptyShuffleWorkerGateway(), + "worker2", + 1026); + + MapPartitionID dataPartitionId1 = randomMapPartitionId(); + MapPartitionID dataPartitionId2 = randomMapPartitionId(); + + List allocatedResources = new ArrayList<>(); + allocatedResources.add( + assignmentTracker.requestShuffleResource( + jobId, randomDataSetId(), dataPartitionId1, 2, partitionFactory)); + allocatedResources.add( + assignmentTracker.requestShuffleResource( + jobId, randomDataSetId(), dataPartitionId2, 2, partitionFactory)); + assertThat( + allocatedResources.stream() + .map( + resource -> + ((DefaultShuffleResource) resource) + .getMapPartitionLocation()) + .collect(Collectors.toList()), + containsInAnyOrder( + new ShuffleWorkerDescriptor(new InstanceID("worker1"), "worker1", 1024), + new ShuffleWorkerDescriptor(new InstanceID("worker2"), "worker2", 1026))); + } + + @Test + public void testReAllocation() throws Exception { + JobID jobId = randomJobId(); + RegistrationID worker1 = new RegistrationID(); + + AssignmentTracker assignmentTracker = + createAssignmentTracker( + jobId, worker1, new InstanceID("worker1"), new EmptyShuffleWorkerGateway()); + + DataSetID dataSetId = randomDataSetId(); + MapPartitionID dataPartitionId = 
randomMapPartitionId(); + ShuffleResource shuffleResource1 = + assignmentTracker.requestShuffleResource( + jobId, dataSetId, dataPartitionId, 2, partitionFactory); + + // reallocating the same data partition on the same worker should leave the assignment + // unchanged + ShuffleResource shuffleResource2 = + assignmentTracker.requestShuffleResource( + jobId, dataSetId, dataPartitionId, 2, partitionFactory); + assertEquals(shuffleResource1, shuffleResource2); + } + + @Test + public void testSynchronizeStatusFromWorkerWithMissedDataPartitions() throws Exception { + JobID jobId = randomJobId(); + DataSetID dataSetID = randomDataSetId(); + MapPartitionID dataPartitionID = randomMapPartitionId(); + + RegistrationID worker1 = new RegistrationID(); + ReleaseRecordingShuffleWorkerGateway shuffleWorkerGateway = + new ReleaseRecordingShuffleWorkerGateway(); + + AssignmentTrackerImpl assignmentTracker = + createAssignmentTracker( + jobId, worker1, new InstanceID("worker1"), shuffleWorkerGateway); + + assignmentTracker.synchronizeWorkerDataPartitions( + worker1, + Collections.singletonList( + new DataPartitionStatus( + jobId, + new DataPartitionCoordinate(dataSetID, dataPartitionID), + false))); + + assertThat( + assignmentTracker.getJobs().get(jobId).getDataPartitions().keySet().stream() + .map(DataPartitionCoordinate::getDataPartitionId) + .collect(Collectors.toList()), + containsInAnyOrder(dataPartitionID)); + } + + @Test + public void testSynchronizedStatusWorkerReleasingAndManagerNot() throws Exception { + JobID jobId = randomJobId(); + DataSetID dataSetID = randomDataSetId(); + MapPartitionID dataPartitionID = randomMapPartitionId(); + DataPartitionCoordinate coordinate = + new DataPartitionCoordinate(dataSetID, dataPartitionID); + + RegistrationID worker1 = new RegistrationID(); + ReleaseRecordingShuffleWorkerGateway shuffleWorkerGateway = + new ReleaseRecordingShuffleWorkerGateway(); + + AssignmentTrackerImpl assignmentTracker = + createAssignmentTracker( + jobId, worker1, new InstanceID("worker1"), shuffleWorkerGateway); + + assignmentTracker.requestShuffleResource( + jobId, dataSetID, dataPartitionID, 2, partitionFactory); + assignmentTracker.synchronizeWorkerDataPartitions( + worker1, + Collections.singletonList( + new DataPartitionStatus( + jobId, + new DataPartitionCoordinate(dataSetID, dataPartitionID), + true))); + + assertEquals(0, assignmentTracker.getJobs().get(jobId).getDataPartitions().size()); + WorkerStatus workerStatus = assignmentTracker.getWorkers().get(worker1); + assertTrue(workerStatus.getDataPartitions().get(coordinate).isReleasing()); + assertEquals(1, shuffleWorkerGateway.getReleaseMetaPartitions().size()); + assertEquals( + Triple.of(jobId, dataSetID, dataPartitionID), + shuffleWorkerGateway.getReleaseMetaPartitions().get(0)); + } + + @Test + public void testSynchronizedStatusManagerReleasingAndWorkerNot() throws Exception { + JobID jobId = randomJobId(); + DataSetID dataSetID = randomDataSetId(); + MapPartitionID dataPartitionID = randomMapPartitionId(); + DataPartitionCoordinate coordinate = + new DataPartitionCoordinate(dataSetID, dataPartitionID); + + RegistrationID worker1 = new RegistrationID(); + ReleaseRecordingShuffleWorkerGateway shuffleWorkerGateway = + new ReleaseRecordingShuffleWorkerGateway(); + + AssignmentTrackerImpl assignmentTracker = + createAssignmentTracker( + jobId, worker1, new InstanceID("worker1"), shuffleWorkerGateway); + + assignmentTracker.requestShuffleResource( + jobId, dataSetID, dataPartitionID, 2, partitionFactory); + 
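// The manager releases the partition first; the worker's later report still claims it is live. +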
assignmentTracker.releaseShuffleResource(jobId, dataSetID, dataPartitionID); + shuffleWorkerGateway.reset(); + + assignmentTracker.synchronizeWorkerDataPartitions( + worker1, + Collections.singletonList( + new DataPartitionStatus( + jobId, + new DataPartitionCoordinate(dataSetID, dataPartitionID), + false))); + + assertEquals(0, assignmentTracker.getJobs().get(jobId).getDataPartitions().size()); + WorkerStatus workerStatus = assignmentTracker.getWorkers().get(worker1); + assertTrue(workerStatus.getDataPartitions().get(coordinate).isReleasing()); + assertEquals(1, shuffleWorkerGateway.getReleasedPartitions().size()); + assertEquals( + Triple.of(jobId, dataSetID, dataPartitionID), + shuffleWorkerGateway.getReleasedPartitions().get(0)); + } + + @Test + public void testSynchronizedStatusManagerReleasingAndWorkerReleasing() throws Exception { + JobID jobId = randomJobId(); + DataSetID dataSetID = randomDataSetId(); + MapPartitionID dataPartitionID = randomMapPartitionId(); + DataPartitionCoordinate coordinate = + new DataPartitionCoordinate(dataSetID, dataPartitionID); + + RegistrationID worker1 = new RegistrationID(); + ReleaseRecordingShuffleWorkerGateway shuffleWorkerGateway = + new ReleaseRecordingShuffleWorkerGateway(); + + AssignmentTrackerImpl assignmentTracker = + createAssignmentTracker( + jobId, worker1, new InstanceID("worker1"), shuffleWorkerGateway); + + assignmentTracker.requestShuffleResource( + jobId, dataSetID, dataPartitionID, 2, partitionFactory); + assignmentTracker.releaseShuffleResource(jobId, dataSetID, dataPartitionID); + shuffleWorkerGateway.reset(); + + assignmentTracker.synchronizeWorkerDataPartitions( + worker1, + Collections.singletonList( + new DataPartitionStatus( + jobId, + new DataPartitionCoordinate(dataSetID, dataPartitionID), + true))); + + assertEquals(0, assignmentTracker.getJobs().get(jobId).getDataPartitions().size()); + WorkerStatus workerStatus = assignmentTracker.getWorkers().get(worker1); + assertTrue(workerStatus.getDataPartitions().get(coordinate).isReleasing()); + assertEquals(1, shuffleWorkerGateway.getReleaseMetaPartitions().size()); + assertEquals( + Triple.of(jobId, dataSetID, dataPartitionID), + shuffleWorkerGateway.getReleaseMetaPartitions().get(0)); + } + + @Test + public void testClientReleaseDataPartition() throws Exception { + JobID jobId = randomJobId(); + RegistrationID worker1 = new RegistrationID(); + ReleaseRecordingShuffleWorkerGateway shuffleWorkerGateway = + new ReleaseRecordingShuffleWorkerGateway(); + + AssignmentTrackerImpl assignmentTracker = + createAssignmentTracker( + jobId, worker1, new InstanceID("worker1"), shuffleWorkerGateway); + + DataSetID dataSetId = randomDataSetId(); + MapPartitionID dataPartitionId = randomMapPartitionId(); + DataPartitionCoordinate coordinate = + new DataPartitionCoordinate(dataSetId, dataPartitionId); + + assignmentTracker.requestShuffleResource( + jobId, dataSetId, dataPartitionId, 2, partitionFactory); + + // Step 1. The client asks to release the data partition + assignmentTracker.releaseShuffleResource(jobId, dataSetId, dataPartitionId); + assertEquals(1, shuffleWorkerGateway.getReleasedPartitions().size()); + assertEquals( + Triple.of(jobId, dataSetId, dataPartitionId), + shuffleWorkerGateway.getReleasedPartitions().get(0)); + + WorkerStatus workerStatus = assignmentTracker.getWorkers().get(worker1); + assertTrue(workerStatus.getDataPartitions().get(coordinate).isReleasing()); + + // Step 2. If the data partition has not been removed by the next synchronization, the + // manager tries to remove it again. 
+ assignmentTracker.synchronizeWorkerDataPartitions( + worker1, + Collections.singletonList( + new DataPartitionStatus( + jobId, new DataPartitionCoordinate(dataSetId, dataPartitionId)))); + assertEquals(2, shuffleWorkerGateway.getReleasedPartitions().size()); + assertEquals( + Triple.of(jobId, dataSetId, dataPartitionId), + shuffleWorkerGateway.getReleasedPartitions().get(1)); + + // Step 3. The worker removes the data and notifies the manager + assignmentTracker.workerReportDataPartitionReleased( + worker1, jobId, dataSetId, dataPartitionId); + assertEquals(1, shuffleWorkerGateway.getReleaseMetaPartitions().size()); + assertEquals( + Triple.of(jobId, dataSetId, dataPartitionId), + shuffleWorkerGateway.getReleaseMetaPartitions().get(0)); + + // The records are removed after the worker has removed the data. + assignmentTracker.synchronizeWorkerDataPartitions(worker1, Collections.emptyList()); + assertFalse( + assignmentTracker.getJobs().get(jobId).getDataPartitions().containsKey(coordinate)); + assertFalse( + assignmentTracker + .getWorkers() + .get(worker1) + .getDataPartitions() + .containsKey(coordinate)); + } + + @Test + public void testWorkerReleaseDataPartition() throws Exception { + JobID jobId = randomJobId(); + + AssignmentTrackerImpl assignmentTracker = new AssignmentTrackerImpl(); + assignmentTracker.registerJob(jobId); + + RegistrationID worker1 = new RegistrationID(); + ReleaseRecordingShuffleWorkerGateway shuffleWorkerGateway = + new ReleaseRecordingShuffleWorkerGateway(); + assignmentTracker.registerWorker( + new InstanceID("worker1"), worker1, shuffleWorkerGateway, "worker1", 1024); + + DataSetID dataSetId = randomDataSetId(); + MapPartitionID dataPartitionId = randomMapPartitionId(); + DataPartitionCoordinate coordinate = + new DataPartitionCoordinate(dataSetId, dataPartitionId); + + assignmentTracker.requestShuffleResource( + jobId, dataSetId, dataPartitionId, 2, partitionFactory); + + // Step 1. The worker reports that the data is released + assignmentTracker.workerReportDataPartitionReleased( + worker1, jobId, dataSetId, dataPartitionId); + assertEquals(1, shuffleWorkerGateway.getReleaseMetaPartitions().size()); + assertEquals( + Triple.of(jobId, dataSetId, dataPartitionId), + shuffleWorkerGateway.getReleaseMetaPartitions().get(0)); + + WorkerStatus workerStatus = assignmentTracker.getWorkers().get(worker1); + assertTrue(workerStatus.getDataPartitions().get(coordinate).isReleasing()); + + // Step 2. If the data partition meta has not been removed by the next synchronization, + // the manager tries to remove it again. + assignmentTracker.synchronizeWorkerDataPartitions( + worker1, + Collections.singletonList( + new DataPartitionStatus( + jobId, + new DataPartitionCoordinate(dataSetId, dataPartitionId), + true))); + assertEquals(2, shuffleWorkerGateway.getReleaseMetaPartitions().size()); + assertEquals( + Triple.of(jobId, dataSetId, dataPartitionId), + shuffleWorkerGateway.getReleaseMetaPartitions().get(1)); + + // Step 3. The records are removed once the worker has removed the data. 
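+ // An empty status report means the worker no longer holds any partition for this registration.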
+ assignmentTracker.synchronizeWorkerDataPartitions(worker1, Collections.emptyList()); + assertFalse( + assignmentTracker.getJobs().get(jobId).getDataPartitions().containsKey(coordinate)); + assertFalse( + assignmentTracker + .getWorkers() + .get(worker1) + .getDataPartitions() + .containsKey(coordinate)); + } + + @Test + public void testComputeChangedWorker() throws ShuffleResourceAllocationException { + JobID jobId = randomJobId(); + RegistrationID worker1 = new RegistrationID(); + ReleaseRecordingShuffleWorkerGateway shuffleWorkerGateway = + new ReleaseRecordingShuffleWorkerGateway(); + + AssignmentTrackerImpl assignmentTracker = + createAssignmentTracker( + jobId, worker1, new InstanceID("worker1"), shuffleWorkerGateway); + assignmentTracker.registerWorker( + new InstanceID("worker2"), + new RegistrationID(), + new EmptyShuffleWorkerGateway(), + "worker2", + 1024); + + DataSetID dataSetID = RandomIDUtils.randomDataSetId(); + MapPartitionID mapPartitionID = RandomIDUtils.randomMapPartitionId(); + + // Requesting two shuffle resources, which would be assigned in the two workers + InstanceID firstWorkerId = + ((DefaultShuffleResource) + assignmentTracker.requestShuffleResource( + jobId, dataSetID, mapPartitionID, 2, partitionFactory)) + .getMapPartitionLocation() + .getWorkerId(); + InstanceID secondWorkerId = + ((DefaultShuffleResource) + assignmentTracker.requestShuffleResource( + jobId, + RandomIDUtils.randomDataSetId(), + RandomIDUtils.randomMapPartitionId(), + 2, + partitionFactory)) + .getMapPartitionLocation() + .getWorkerId(); + assertNotEquals(secondWorkerId, firstWorkerId); + + ChangedWorkerStatus changedWorkerStatus = + assignmentTracker.computeChangedWorkers( + jobId, + new HashSet<>(Arrays.asList(new InstanceID("dummy"), secondWorkerId)), + true); + assertEquals( + Collections.singletonList(new InstanceID("dummy")), + changedWorkerStatus.getIrrelevantWorkers()); + assertEquals(1, changedWorkerStatus.getRelevantWorkers().size()); + Set dataPartitions = + changedWorkerStatus.getRelevantWorkers().get(firstWorkerId); + assertEquals( + Collections.singleton(new DataPartitionCoordinate(dataSetID, mapPartitionID)), + dataPartitions); + + ChangedWorkerStatus noUnrelatedWorkerStatus = + assignmentTracker.computeChangedWorkers( + jobId, + new HashSet<>(Arrays.asList(new InstanceID("dummy"), secondWorkerId)), + false); + assertEquals(0, noUnrelatedWorkerStatus.getIrrelevantWorkers().size()); + assertEquals(1, noUnrelatedWorkerStatus.getRelevantWorkers().size()); + assertEquals( + Collections.singleton(new DataPartitionCoordinate(dataSetID, mapPartitionID)), + dataPartitions); + } + + @Test + public void testComputeChangeWorkerSafeGuardTest() throws ShuffleResourceAllocationException { + JobID jobId = randomJobId(); + RegistrationID worker1 = new RegistrationID("worker1"); + InstanceID worker1InstanceId = new InstanceID(); + ReleaseRecordingShuffleWorkerGateway shuffleWorkerGateway = + new ReleaseRecordingShuffleWorkerGateway(); + AssignmentTrackerImpl assignmentTracker = + createAssignmentTracker(jobId, worker1, worker1InstanceId, shuffleWorkerGateway); + + DataSetID dataSetID = RandomIDUtils.randomDataSetId(); + MapPartitionID mapPartitionID = RandomIDUtils.randomMapPartitionId(); + + assignmentTracker.requestShuffleResource( + jobId, dataSetID, mapPartitionID, 2, partitionFactory); + + // This is a safe guard, in reality the following should not happen, but + // we still want to have this. 
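+ // Simulate that inconsistency by dropping the worker record behind the tracker's back.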
+ assignmentTracker.getWorkers().remove(worker1); + + ChangedWorkerStatus workerStatus = + assignmentTracker.computeChangedWorkers( + jobId, Collections.singleton(worker1InstanceId), true); + assertEquals(0, workerStatus.getRelevantWorkers().size()); + assertEquals( + Collections.singletonList(worker1InstanceId), workerStatus.getIrrelevantWorkers()); + } + + @Test + public void testUnregisterJob() throws Exception { + JobID jobId = randomJobId(); + DataSetID dataSetId = randomDataSetId(); + MapPartitionID dataPartitionID1 = randomMapPartitionId(); + MapPartitionID dataPartitionID2 = randomMapPartitionId(); + + RegistrationID worker1 = new RegistrationID(); + ReleaseRecordingShuffleWorkerGateway shuffleWorkerGateway = + new ReleaseRecordingShuffleWorkerGateway(); + AssignmentTrackerImpl assignmentTracker = + createAssignmentTracker( + jobId, worker1, new InstanceID("worker1"), shuffleWorkerGateway); + + assignmentTracker.requestShuffleResource( + jobId, dataSetId, dataPartitionID1, 2, partitionFactory); + assignmentTracker.requestShuffleResource( + jobId, dataSetId, dataPartitionID2, 2, partitionFactory); + + assignmentTracker.unregisterJob(jobId); + + assertEquals(0, assignmentTracker.getJobs().size()); + + WorkerStatus workerStatus = assignmentTracker.getWorkers().get(worker1); + assertEquals(2, workerStatus.getDataPartitions().size()); + for (MapPartitionID partitionId : Arrays.asList(dataPartitionID1, dataPartitionID2)) { + assertTrue( + workerStatus + .getDataPartitions() + .get(new DataPartitionCoordinate(dataSetId, partitionId)) + .isReleasing()); + } + } + + @Test + public void testUnregisterWorker() throws ShuffleResourceAllocationException { + JobID jobId = randomJobId(); + DataSetID dataSetId = randomDataSetId(); + MapPartitionID dataPartitionID1 = randomMapPartitionId(); + MapPartitionID dataPartitionID2 = randomMapPartitionId(); + + RegistrationID worker1 = new RegistrationID(); + ReleaseRecordingShuffleWorkerGateway shuffleWorkerGateway = + new ReleaseRecordingShuffleWorkerGateway(); + AssignmentTrackerImpl assignmentTracker = + createAssignmentTracker( + jobId, worker1, new InstanceID("worker1"), shuffleWorkerGateway); + + assignmentTracker.requestShuffleResource( + jobId, dataSetId, dataPartitionID1, 2, partitionFactory); + assignmentTracker.requestShuffleResource( + jobId, dataSetId, dataPartitionID2, 2, partitionFactory); + + assignmentTracker.unregisterWorker(worker1); + + assertEquals(0, assignmentTracker.getJobs().get(jobId).getDataPartitions().size()); + assertEquals(0, assignmentTracker.getWorkers().size()); + } + + @Test + public void testUnregisterWorkerWithReleasingPartitions() + throws ShuffleResourceAllocationException { + JobID jobId = randomJobId(); + DataSetID dataSetId = randomDataSetId(); + MapPartitionID dataPartitionID1 = randomMapPartitionId(); + MapPartitionID dataPartitionID2 = randomMapPartitionId(); + + RegistrationID worker1 = new RegistrationID(); + ReleaseRecordingShuffleWorkerGateway shuffleWorkerGateway = + new ReleaseRecordingShuffleWorkerGateway(); + AssignmentTrackerImpl assignmentTracker = + createAssignmentTracker( + jobId, worker1, new InstanceID("worker1"), shuffleWorkerGateway); + + assignmentTracker.requestShuffleResource( + jobId, dataSetId, dataPartitionID1, 2, partitionFactory); + assignmentTracker.requestShuffleResource( + jobId, dataSetId, dataPartitionID2, 2, partitionFactory); + + assignmentTracker.releaseShuffleResource(jobId, dataSetId, dataPartitionID1); + + assignmentTracker.unregisterWorker(worker1); + assertEquals(0, 
assignmentTracker.getJobs().get(jobId).getDataPartitions().size()); + assertEquals(0, assignmentTracker.getWorkers().size()); + } + + @Test + public void testShuffleWorkerRestartedBeforeLastTimeout() + throws ShuffleResourceAllocationException { + JobID jobId = randomJobId(); + DataSetID dataSetId = randomDataSetId(); + MapPartitionID dataPartitionId = randomMapPartitionId(); + DataPartitionCoordinate coordinate = + new DataPartitionCoordinate(dataSetId, dataPartitionId); + + RegistrationID worker1 = new RegistrationID(); + ReleaseRecordingShuffleWorkerGateway shuffleWorkerGateway1 = + new ReleaseRecordingShuffleWorkerGateway(); + AssignmentTrackerImpl assignmentTracker = + createAssignmentTracker( + jobId, worker1, new InstanceID("worker1"), shuffleWorkerGateway1); + + assignmentTracker.requestShuffleResource( + jobId, dataSetId, dataPartitionId, 2, partitionFactory); + assertEquals(1, assignmentTracker.getWorkers().get(worker1).getDataPartitions().size()); + assertEquals(1, assignmentTracker.getJobs().get(jobId).getDataPartitions().size()); + + // Now simulates the first worker exit and the second worker registered before the first + // timeout. + RegistrationID worker2 = new RegistrationID(); + ReleaseRecordingShuffleWorkerGateway shuffleWorkerGateway2 = + new ReleaseRecordingShuffleWorkerGateway(); + assignmentTracker.registerWorker( + new InstanceID("worker2"), worker2, shuffleWorkerGateway2, "xx", 12345); + assignmentTracker.synchronizeWorkerDataPartitions( + worker2, Collections.singletonList(new DataPartitionStatus(jobId, coordinate))); + assertEquals(0, assignmentTracker.getWorkers().get(worker1).getDataPartitions().size()); + assertEquals(1, assignmentTracker.getWorkers().get(worker2).getDataPartitions().size()); + assertEquals(1, assignmentTracker.getJobs().get(jobId).getDataPartitions().size()); + assertEquals( + worker2, + assignmentTracker + .getJobs() + .get(jobId) + .getDataPartitions() + .get(coordinate) + .getRegistrationID()); + + // Now the first worker timeout, it should not affect the current status + assignmentTracker.unregisterWorker(worker1); + assertEquals(1, assignmentTracker.getWorkers().get(worker2).getDataPartitions().size()); + assertEquals(1, assignmentTracker.getJobs().get(jobId).getDataPartitions().size()); + assertEquals( + worker2, + assignmentTracker + .getJobs() + .get(jobId) + .getDataPartitions() + .get(coordinate) + .getRegistrationID()); + + assignmentTracker.unregisterJob(jobId); + assertThat( + shuffleWorkerGateway2.getReleasedPartitions(), + containsInAnyOrder(Triple.of(jobId, dataSetId, dataPartitionId))); + } + + // ------------------------------- Utilities ---------------------------------------------- + + private AssignmentTrackerImpl createAssignmentTracker( + JobID jobId, + RegistrationID workerRegistrationId, + InstanceID workerInstanceID, + ShuffleWorkerGateway gateway) { + AssignmentTrackerImpl assignmentTracker = new AssignmentTrackerImpl(); + assignmentTracker.registerJob(jobId); + + assignmentTracker.registerWorker( + workerInstanceID, workerRegistrationId, gateway, "worker1", 1024); + return assignmentTracker; + } + + private static class ReleaseRecordingShuffleWorkerGateway extends EmptyShuffleWorkerGateway { + private final List> releasedPartitions = + new ArrayList<>(); + + private final List> releaseMetaPartitions = + new ArrayList<>(); + + @Override + public CompletableFuture releaseDataPartition( + JobID jobID, DataSetID dataSetID, DataPartitionID dataPartitionID) { + releasedPartitions.add(Triple.of(jobID, dataSetID, 
dataPartitionID)); + return CompletableFuture.completedFuture(Acknowledge.get()); + } + + @Override + public CompletableFuture removeReleasedDataPartitionMeta( + JobID jobID, DataSetID dataSetID, DataPartitionID dataPartitionID) { + releaseMetaPartitions.add(Triple.of(jobID, dataSetID, dataPartitionID)); + return CompletableFuture.completedFuture(Acknowledge.get()); + } + + public List> getReleasedPartitions() { + return releasedPartitions; + } + + public List> getReleaseMetaPartitions() { + return releaseMetaPartitions; + } + + public void reset() { + releasedPartitions.clear(); + releaseMetaPartitions.clear(); + } + } + + @Test + public void testGetDataPartitionTypeSuccess0() { + JobID jobId = randomJobId(); + RegistrationID worker1 = new RegistrationID(); + + AssignmentTrackerImpl assignmentTracker = + createAssignmentTracker( + jobId, worker1, new InstanceID("worker1"), new EmptyShuffleWorkerGateway()); + + DataPartition.DataPartitionType dataPartitionType = null; + try { + dataPartitionType = assignmentTracker.getDataPartitionType(partitionFactory); + } catch (Throwable th) { + LOG.error("Get data partition type failed, ", th); + fail(); + } + assertEquals(DataPartition.DataPartitionType.MAP_PARTITION, dataPartitionType); + } + + @Test + public void testGetDataPartitionTypeSuccess1() { + JobID jobId = randomJobId(); + RegistrationID worker1 = new RegistrationID(); + + AssignmentTrackerImpl assignmentTracker = + createAssignmentTracker( + jobId, worker1, new InstanceID("worker1"), new EmptyShuffleWorkerGateway()); + + DataPartition.DataPartitionType dataPartitionType = null; + try { + dataPartitionType = + assignmentTracker.getDataPartitionType( + SSDOnlyLocalFileMapPartitionFactory.class.getCanonicalName()); + } catch (Throwable th) { + LOG.error("Get data partition type failed, ", th); + fail(); + } + assertEquals(DataPartition.DataPartitionType.MAP_PARTITION, dataPartitionType); + } + + @Test + public void testGetDataPartitionTypeSuccess2() { + JobID jobId = randomJobId(); + RegistrationID worker1 = new RegistrationID(); + + AssignmentTrackerImpl assignmentTracker = + createAssignmentTracker( + jobId, worker1, new InstanceID("worker1"), new EmptyShuffleWorkerGateway()); + + DataPartition.DataPartitionType dataPartitionType = null; + try { + dataPartitionType = + assignmentTracker.getDataPartitionType( + HDDOnlyLocalFileMapPartitionFactory.class.getCanonicalName()); + } catch (Throwable th) { + LOG.error("Get data partition type failed, ", th); + fail(); + } + assertEquals(DataPartition.DataPartitionType.MAP_PARTITION, dataPartitionType); + } + + @Test(expected = ShuffleResourceAllocationException.class) + public void testGetDataPartitionTypeFailed() throws ShuffleResourceAllocationException { + JobID jobId = randomJobId(); + RegistrationID worker1 = new RegistrationID(); + + AssignmentTrackerImpl assignmentTracker = + createAssignmentTracker( + jobId, worker1, new InstanceID("worker1"), new EmptyShuffleWorkerGateway()); + + assignmentTracker.getDataPartitionType("a.b.c.d"); + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/registration/RegisteredRpcConnectionTest.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/registration/RegisteredRpcConnectionTest.java new file mode 100644 index 00000000..33273296 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/registration/RegisteredRpcConnectionTest.java @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.registration; + +import com.alibaba.flink.shuffle.core.utils.TestLogger; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcService; +import com.alibaba.flink.shuffle.rpc.test.TestingRpcService; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.slf4j.LoggerFactory; + +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** Tests for RegisteredRpcConnection, validating the successful, failure and close behavior. */ +public class RegisteredRpcConnectionTest extends TestLogger { + + private TestingRpcService rpcService; + + @Before + public void setup() { + rpcService = new TestingRpcService(); + } + + @After + public void tearDown() throws ExecutionException, InterruptedException { + if (rpcService != null) { + rpcService.stopService().get(); + } + } + + @Test + public void testSuccessfulRpcConnection() throws Exception { + final String testRpcConnectionEndpointAddress = ""; + final UUID leaderId = UUID.randomUUID(); + final String connectionID = "Test RPC Connection ID"; + + // an endpoint that immediately returns success + TestRegistrationGateway testGateway = + new TestRegistrationGateway( + new RetryingRegistrationTest.TestRegistrationSuccess(connectionID)); + + try { + rpcService.registerGateway(testRpcConnectionEndpointAddress, testGateway); + + TestRpcConnection connection = + new TestRpcConnection( + testRpcConnectionEndpointAddress, + leaderId, + rpcService.getExecutor(), + rpcService); + connection.start(); + + // wait for connection established + final String actualConnectionId = connection.getConnectionFuture().get(); + + // validate correct invocation and result + assertTrue(connection.isConnected()); + assertEquals(testRpcConnectionEndpointAddress, connection.getTargetAddress()); + assertEquals(leaderId, connection.getTargetLeaderId()); + assertEquals(testGateway, connection.getTargetGateway()); + assertEquals(connectionID, actualConnectionId); + } finally { + testGateway.stop(); + } + } + + @Test + public void testRpcConnectionFailures() throws Exception { + final String connectionFailureMessage = "Test RPC Connection failure"; + final String testRpcConnectionEndpointAddress = ""; + final UUID leaderId = UUID.randomUUID(); + + // gateway that upon calls Throw an exception + TestRegistrationGateway 
testGateway = mock(TestRegistrationGateway.class); + final RuntimeException registrationException = + new RuntimeException(connectionFailureMessage); + when(testGateway.registrationCall(any(UUID.class))).thenThrow(registrationException); + + rpcService.registerGateway(testRpcConnectionEndpointAddress, testGateway); + + TestRpcConnection connection = + new TestRpcConnection( + testRpcConnectionEndpointAddress, + leaderId, + rpcService.getExecutor(), + rpcService); + connection.start(); + + // wait for connection failure + try { + connection.getConnectionFuture().get(); + fail("expected failure."); + } catch (ExecutionException ee) { + assertEquals(registrationException, ee.getCause()); + } + + // validate correct invocation and result + assertFalse(connection.isConnected()); + assertEquals(testRpcConnectionEndpointAddress, connection.getTargetAddress()); + assertEquals(leaderId, connection.getTargetLeaderId()); + assertNull(connection.getTargetGateway()); + } + + @Test + public void testRpcConnectionClose() { + final String testRpcConnectionEndpointAddress = ""; + final UUID leaderId = UUID.randomUUID(); + final String connectionID = "Test RPC Connection ID"; + + TestRegistrationGateway testGateway = + new TestRegistrationGateway( + new RetryingRegistrationTest.TestRegistrationSuccess(connectionID)); + + try { + rpcService.registerGateway(testRpcConnectionEndpointAddress, testGateway); + + TestRpcConnection connection = + new TestRpcConnection( + testRpcConnectionEndpointAddress, + leaderId, + rpcService.getExecutor(), + rpcService); + connection.start(); + // close the connection + connection.close(); + + // validate connection is closed + assertEquals(testRpcConnectionEndpointAddress, connection.getTargetAddress()); + assertEquals(leaderId, connection.getTargetLeaderId()); + assertTrue(connection.isClosed()); + } finally { + testGateway.stop(); + } + } + + @Test + public void testReconnect() throws Exception { + final String connectionId1 = "Test RPC Connection ID 1"; + final String connectionId2 = "Test RPC Connection ID 2"; + final String testRpcConnectionEndpointAddress = ""; + final UUID leaderId = UUID.randomUUID(); + final TestRegistrationGateway testGateway = + new TestRegistrationGateway( + new RetryingRegistrationTest.TestRegistrationSuccess(connectionId1), + new RetryingRegistrationTest.TestRegistrationSuccess(connectionId2)); + + rpcService.registerGateway(testRpcConnectionEndpointAddress, testGateway); + + TestRpcConnection connection = + new TestRpcConnection( + testRpcConnectionEndpointAddress, + leaderId, + rpcService.getExecutor(), + rpcService); + connection.start(); + + final String actualConnectionId1 = connection.getConnectionFuture().get(); + + assertEquals(actualConnectionId1, connectionId1); + + assertTrue(connection.tryReconnect()); + + final String actualConnectionId2 = connection.getConnectionFuture().get(); + + assertEquals(actualConnectionId2, connectionId2); + } + + // ------------------------------------------------------------------------ + // test RegisteredRpcConnection + // ------------------------------------------------------------------------ + + private static class TestRpcConnection + extends RegisteredRpcConnection< + UUID, + TestRegistrationGateway, + RetryingRegistrationTest.TestRegistrationSuccess> { + + private final Object lock = new Object(); + + private final RemoteShuffleRpcService rpcService; + + private CompletableFuture connectionFuture; + + public TestRpcConnection( + String targetAddress, + UUID targetLeaderId, + Executor executor, + 
RemoteShuffleRpcService rpcService) { + super( + LoggerFactory.getLogger(RegisteredRpcConnectionTest.class), + targetAddress, + targetLeaderId, + executor); + this.rpcService = rpcService; + this.connectionFuture = new CompletableFuture<>(); + } + + @Override + protected RetryingRegistration< + UUID, + TestRegistrationGateway, + RetryingRegistrationTest.TestRegistrationSuccess> + generateRegistration() { + return new RetryingRegistrationTest.TestRetryingRegistration( + rpcService, getTargetAddress(), getTargetLeaderId()); + } + + @Override + protected void onRegistrationSuccess( + RetryingRegistrationTest.TestRegistrationSuccess success) { + synchronized (lock) { + connectionFuture.complete(success.getCorrelationId()); + } + } + + @Override + protected void onRegistrationFailure(Throwable failure) { + synchronized (lock) { + connectionFuture.completeExceptionally(failure); + } + } + + @Override + public boolean tryReconnect() { + synchronized (lock) { + connectionFuture.cancel(false); + connectionFuture = new CompletableFuture<>(); + } + return super.tryReconnect(); + } + + public CompletableFuture getConnectionFuture() { + synchronized (lock) { + return connectionFuture; + } + } + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/registration/RetryingRegistrationConfigurationTest.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/registration/RetryingRegistrationConfigurationTest.java new file mode 100644 index 00000000..ed0e4939 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/registration/RetryingRegistrationConfigurationTest.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.registration; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.core.config.ClusterOptions; + +import org.junit.Test; + +import java.time.Duration; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; + +/** Tests for the {@link RetryingRegistrationConfiguration}. 
*/ +public class RetryingRegistrationConfigurationTest { + + @Test + public void testConfigurationParsing() { + final Configuration configuration = new Configuration(); + final Duration refusedRegistrationDelay = Duration.ofMillis(3); + final Duration errorRegistrationDelay = Duration.ofMillis(4); + + configuration.setDuration( + ClusterOptions.REFUSED_REGISTRATION_DELAY, refusedRegistrationDelay); + configuration.setDuration(ClusterOptions.ERROR_REGISTRATION_DELAY, errorRegistrationDelay); + + final RetryingRegistrationConfiguration retryingRegistrationConfiguration = + RetryingRegistrationConfiguration.fromConfiguration(configuration); + assertThat( + retryingRegistrationConfiguration.getRefusedDelayMillis(), + is(refusedRegistrationDelay.toMillis())); + assertThat( + retryingRegistrationConfiguration.getErrorDelayMillis(), + is(errorRegistrationDelay.toMillis())); + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/registration/RetryingRegistrationTest.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/registration/RetryingRegistrationTest.java new file mode 100644 index 00000000..36d8e341 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/registration/RetryingRegistrationTest.java @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.registration; + +import com.alibaba.flink.shuffle.common.utils.FutureUtils; +import com.alibaba.flink.shuffle.core.utils.TestLogger; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcService; +import com.alibaba.flink.shuffle.rpc.test.TestingRpcService; + +import org.apache.commons.lang3.tuple.Pair; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.slf4j.LoggerFactory; + +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyLong; +import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.atMost; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for the generic retrying registration class, validating the failure, retry, and back-off + * behavior. 
+ */ +public class RetryingRegistrationTest extends TestLogger { + + private TestingRpcService rpcService; + + @Before + public void setup() { + rpcService = new TestingRpcService(); + } + + @After + public void tearDown() throws ExecutionException, InterruptedException { + if (rpcService != null) { + rpcService.stopService().get(); + } + } + + @Test + public void testSimpleSuccessfulRegistration() throws Exception { + final String testId = "laissez les bon temps roulez"; + final String testEndpointAddress = ""; + final UUID leaderId = UUID.randomUUID(); + + // an endpoint that immediately returns success + TestRegistrationGateway testGateway = + new TestRegistrationGateway(new TestRegistrationSuccess(testId)); + + try { + rpcService.registerGateway(testEndpointAddress, testGateway); + + TestRetryingRegistration registration = + new TestRetryingRegistration(rpcService, testEndpointAddress, leaderId); + registration.startRegistration(); + + CompletableFuture<Pair<TestRegistrationGateway, TestRegistrationSuccess>> future = + registration.getFuture(); + assertNotNull(future); + + // multiple accesses return the same future + assertEquals(future, registration.getFuture()); + + Pair<TestRegistrationGateway, TestRegistrationSuccess> success = + future.get(10L, TimeUnit.SECONDS); + + // validate correct invocation and result + assertEquals(testId, success.getRight().getCorrelationId()); + assertEquals(leaderId, testGateway.getInvocations().take().leaderId()); + } finally { + testGateway.stop(); + } + } + + @Test + public void testPropagateFailures() throws Exception { + final String testExceptionMessage = "testExceptionMessage"; + + // RPC service that fails with exception upon the connection + RemoteShuffleRpcService rpcService = mock(RemoteShuffleRpcService.class); + when(rpcService.connectTo(anyString(), any(Class.class))) + .thenThrow(new RuntimeException(testExceptionMessage)); + + TestRetryingRegistration registration = + new TestRetryingRegistration(rpcService, "testaddress", UUID.randomUUID()); + registration.startRegistration(); + + CompletableFuture<Pair<TestRegistrationGateway, TestRegistrationSuccess>> future = + registration.getFuture(); + assertTrue(future.isDone()); + + try { + future.get(); + + fail("We expected an ExecutionException."); + } catch (ExecutionException e) { + assertEquals(testExceptionMessage, e.getCause().getMessage()); + } + } + + @Test + public void testRetryConnectOnFailure() throws Exception { + final String testId = "laissez les bon temps roulez"; + final UUID leaderId = UUID.randomUUID(); + + ExecutorService executor = Executors.newSingleThreadScheduledExecutor(); + TestRegistrationGateway testGateway = + new TestRegistrationGateway(new TestRegistrationSuccess(testId)); + + try { + // RPC service that fails upon the first connection, but succeeds on the second + RemoteShuffleRpcService rpcService = mock(RemoteShuffleRpcService.class); + when(rpcService.connectTo(anyString(), any(Class.class))) + .thenReturn( + FutureUtils.completedExceptionally( + new Exception( + "test connect failure")), // first connection attempt fails + CompletableFuture.completedFuture( + testGateway) // second connection attempt succeeds + ); + when(rpcService.getExecutor()).thenReturn(executor); + when(rpcService.scheduleRunnable(any(Runnable.class), anyLong(), any(TimeUnit.class))) + .thenAnswer( + (InvocationOnMock invocation) -> { + final Runnable runnable = invocation.getArgument(0); + final long delay = invocation.getArgument(1); + final TimeUnit timeUnit = invocation.getArgument(2); + return Executors.newSingleThreadScheduledExecutor() + .schedule(runnable, delay, timeUnit); + }); + + TestRetryingRegistration registration = + new 
TestRetryingRegistration(rpcService, "foobar address", leaderId); + + long start = System.currentTimeMillis(); + + registration.startRegistration(); + + Pair<TestRegistrationGateway, TestRegistrationSuccess> success = + registration.getFuture().get(10L, TimeUnit.SECONDS); + + // measure the duration of the registration --> should be longer than the error delay + long duration = System.currentTimeMillis() - start; + + assertTrue( + "The registration should have failed the first time. Thus the duration should be longer than at least a single error delay.", + duration > TestRetryingRegistration.DELAY_ON_ERROR); + + // validate correct invocation and result + assertEquals(testId, success.getRight().getCorrelationId()); + assertEquals(leaderId, testGateway.getInvocations().take().leaderId()); + } finally { + testGateway.stop(); + } + } + + @Test(timeout = 10000) + public void testRetriesOnTimeouts() throws Exception { + final String testId = "rien ne va plus"; + final String testEndpointAddress = ""; + final UUID leaderId = UUID.randomUUID(); + + // an endpoint that first returns futures which time out, then a successful future + TestRegistrationGateway testGateway = + new TestRegistrationGateway( + null, // timeout + null, // timeout + new TestRegistrationSuccess(testId) // success + ); + + try { + rpcService.registerGateway(testEndpointAddress, testGateway); + TestRetryingRegistration registration = + new TestRetryingRegistration( + rpcService, + testEndpointAddress, + leaderId, + new RetryingRegistrationConfiguration( + 15000L, // make sure that we timeout in case of an error + 15000L)); + registration.startRegistration(); + + CompletableFuture<Pair<TestRegistrationGateway, TestRegistrationSuccess>> future = + registration.getFuture(); + Pair<TestRegistrationGateway, TestRegistrationSuccess> success = + future.get(10L, TimeUnit.SECONDS); + + // validate correct invocation and result + assertEquals(testId, success.getRight().getCorrelationId()); + assertEquals(leaderId, testGateway.getInvocations().take().leaderId()); + } finally { + testGateway.stop(); + } + } + + @Test + public void testDecline() throws Exception { + final String testId = "qui a coupe le fromage"; + final String testEndpointAddress = ""; + final UUID leaderId = UUID.randomUUID(); + + TestRegistrationGateway testGateway = + new TestRegistrationGateway( + null, // timeout + new RegistrationResponse.Decline("no reason "), + null, // timeout + new TestRegistrationSuccess(testId) // success + ); + + try { + rpcService.registerGateway(testEndpointAddress, testGateway); + + TestRetryingRegistration registration = + new TestRetryingRegistration(rpcService, testEndpointAddress, leaderId); + registration.startRegistration(); + + CompletableFuture<Pair<TestRegistrationGateway, TestRegistrationSuccess>> future = + registration.getFuture(); + Pair<TestRegistrationGateway, TestRegistrationSuccess> success = + future.get(10L, TimeUnit.SECONDS); + + // validate correct invocation and result + assertEquals(testId, success.getRight().getCorrelationId()); + assertEquals(leaderId, testGateway.getInvocations().take().leaderId()); + } finally { + testGateway.stop(); + } + } + + @Test + @SuppressWarnings("unchecked") + public void testRetryOnError() throws Exception { + final String testId = "Petit a petit, l'oiseau fait son nid"; + final String testEndpointAddress = ""; + final UUID leaderId = UUID.randomUUID(); + + // gateway that first responds to calls with a failure, then with a success + TestRegistrationGateway testGateway = mock(TestRegistrationGateway.class); + + when(testGateway.registrationCall(any(UUID.class))) + .thenReturn( + FutureUtils.completedExceptionally(new Exception("test exception")), + CompletableFuture.completedFuture(new 
TestRegistrationSuccess(testId)));
+
+        rpcService.registerGateway(testEndpointAddress, testGateway);
+
+        TestRetryingRegistration registration =
+                new TestRetryingRegistration(rpcService, testEndpointAddress, leaderId);
+
+        long started = System.nanoTime();
+        registration.startRegistration();
+
+        CompletableFuture<Pair<TestRegistrationGateway, TestRegistrationSuccess>> future =
+                registration.getFuture();
+        Pair<TestRegistrationGateway, TestRegistrationSuccess> success =
+                future.get(10, TimeUnit.SECONDS);
+
+        long finished = System.nanoTime();
+        long elapsedMillis = (finished - started) / 1000000;
+
+        assertEquals(testId, success.getRight().getCorrelationId());
+
+        // validate that some retry-delay / back-off behavior happened
+        assertTrue(
+                "retries did not properly back off",
+                elapsedMillis >= TestRetryingRegistration.DELAY_ON_ERROR);
+    }
+
+    @Test
+    public void testCancellation() throws Exception {
+        final String testEndpointAddress = "my-test-address";
+        final UUID leaderId = UUID.randomUUID();
+
+        CompletableFuture<RegistrationResponse> result = new CompletableFuture<>();
+
+        TestRegistrationGateway testGateway = mock(TestRegistrationGateway.class);
+        when(testGateway.registrationCall(any(UUID.class))).thenReturn(result);
+
+        rpcService.registerGateway(testEndpointAddress, testGateway);
+
+        TestRetryingRegistration registration =
+                new TestRetryingRegistration(rpcService, testEndpointAddress, leaderId);
+        registration.startRegistration();
+
+        // cancel and fail the current registration attempt
+        registration.cancel();
+        result.completeExceptionally(new TimeoutException());
+
+        // there should not be a second registration attempt
+        verify(testGateway, atMost(1)).registrationCall(any(UUID.class));
+    }
+
+    // ------------------------------------------------------------------------
+    //  test registration
+    // ------------------------------------------------------------------------
+
+    static class TestRegistrationSuccess extends RegistrationResponse.Success {
+        private static final long serialVersionUID = 5542698790917150604L;
+
+        private final String correlationId;
+
+        public TestRegistrationSuccess(String correlationId) {
+            this.correlationId = correlationId;
+        }
+
+        public String getCorrelationId() {
+            return correlationId;
+        }
+    }
+
+    static class TestRetryingRegistration
+            extends RetryingRegistration<TestRegistrationGateway, TestRegistrationSuccess> {
+
+        // we use shorter timeouts here to speed up the tests
+        static final long DELAY_ON_ERROR = 200;
+        static final long DELAY_ON_DECLINE = 200;
+        static final RetryingRegistrationConfiguration RETRYING_REGISTRATION_CONFIGURATION =
+                new RetryingRegistrationConfiguration(DELAY_ON_ERROR, DELAY_ON_DECLINE);
+
+        public TestRetryingRegistration(
+                RemoteShuffleRpcService rpcService, String targetAddress, UUID leaderId) {
+            this(rpcService, targetAddress, leaderId, RETRYING_REGISTRATION_CONFIGURATION);
+        }
+
+        public TestRetryingRegistration(
+                RemoteShuffleRpcService rpcService,
+                String targetAddress,
+                UUID leaderId,
+                RetryingRegistrationConfiguration retryingRegistrationConfiguration) {
+            super(
+                    LoggerFactory.getLogger(RetryingRegistrationTest.class),
+                    rpcService,
+                    "TestEndpoint",
+                    TestRegistrationGateway.class,
+                    targetAddress,
+                    leaderId,
+                    retryingRegistrationConfiguration);
+        }
+
+        @Override
+        protected CompletableFuture<RegistrationResponse> invokeRegistration(
+                TestRegistrationGateway gateway, UUID leaderId) {
+            return gateway.registrationCall(leaderId);
+        }
+    }
+}
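The class under test follows a start / future / cancel lifecycle: startRegistration() kicks off the retry loop, getFuture() exposes the eventual (gateway, success) pair, and cancel() aborts an in-flight attempt. A minimal driver sketch reusing the test helpers above ("target-address" is an illustrative placeholder, not a real endpoint):

    TestRetryingRegistration registration =
            new TestRetryingRegistration(rpcService, "target-address", UUID.randomUUID());
    registration.startRegistration();
    registration
            .getFuture()
            .thenAccept(
                    pair -> {
                        TestRegistrationGateway gateway = pair.getLeft(); // connected gateway
                        TestRegistrationSuccess success = pair.getRight(); // registration ack
                    });
    // a leader change or shutdown aborts any in-flight attempt:
    registration.cancel();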
diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/registration/TestRegistrationGateway.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/registration/TestRegistrationGateway.java
new file mode 100644
index 00000000..c4deb27a
--- /dev/null
+++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/registration/TestRegistrationGateway.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.registration;
+
+import com.alibaba.flink.shuffle.common.utils.CommonUtils;
+
+import java.util.UUID;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.LinkedBlockingQueue;
+
+/** Mock gateway for {@link RegistrationResponse}. */
+public class TestRegistrationGateway extends TestingGatewayBase {
+
+    private final BlockingQueue<RegistrationCall> invocations;
+
+    private final RegistrationResponse[] responses;
+
+    private int pos;
+
+    public TestRegistrationGateway(RegistrationResponse... responses) {
+        CommonUtils.checkArgument(responses != null && responses.length > 0);
+
+        this.invocations = new LinkedBlockingQueue<>();
+        this.responses = responses;
+    }
+
+    // ------------------------------------------------------------------------
+
+    public CompletableFuture<RegistrationResponse> registrationCall(UUID leaderId) {
+        invocations.add(new RegistrationCall(leaderId, 10));
+
+        RegistrationResponse response = responses[pos];
+        if (pos < responses.length - 1) {
+            pos++;
+        }
+
+        // return a completed future (for a proper value), or one that never completes and will time
+        // out (for null)
+        return response != null
+                ? CompletableFuture.completedFuture(response)
+                : futureWithTimeout(10);
+    }
+
+    public BlockingQueue<RegistrationCall> getInvocations() {
+        return invocations;
+    }
+
+    // ------------------------------------------------------------------------
+
+    /** Invocation parameters. */
+    public static class RegistrationCall {
+        private final UUID leaderId;
+        private final long timeout;
+
+        public RegistrationCall(UUID leaderId, long timeout) {
+            this.leaderId = leaderId;
+            this.timeout = timeout;
+        }
+
+        public UUID leaderId() {
+            return leaderId;
+        }
+
+        public long timeout() {
+            return timeout;
+        }
+    }
+}
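The scripted-response design makes the retry paths above easy to stage: a null entry yields a future that only times out, while concrete responses complete immediately. For instance, a gateway that times out once, declines once, and then succeeds (a sketch mirroring testDecline):

    TestRegistrationGateway gateway =
            new TestRegistrationGateway(
                    null, // first call: future that never completes, only times out
                    new RegistrationResponse.Decline("busy"), // second call: declined
                    new TestRegistrationSuccess("id")); // third call: success
    // every call is recorded, so tests can assert on the leader id that was used:
    UUID seenLeaderId = gateway.getInvocations().take().leaderId();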
diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/registration/TestingGatewayBase.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/registration/TestingGatewayBase.java
new file mode 100644
index 00000000..8570493d
--- /dev/null
+++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/registration/TestingGatewayBase.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.registration;
+
+import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcGateway;
+
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+/** Utility base class for testing gateways. */
+public abstract class TestingGatewayBase implements RemoteShuffleRpcGateway {
+
+    private final ScheduledExecutorService executor;
+
+    private final String address;
+
+    protected TestingGatewayBase(final String address) {
+        this.executor = Executors.newSingleThreadScheduledExecutor();
+        this.address = address;
+    }
+
+    protected TestingGatewayBase() {
+        this("localhost");
+    }
+
+    // ------------------------------------------------------------------------
+    //  shutdown
+    // ------------------------------------------------------------------------
+
+    public void stop() {
+        executor.shutdownNow();
+    }
+
+    @Override
+    protected void finalize() throws Throwable {
+        super.finalize();
+        executor.shutdownNow();
+    }
+
+    // ------------------------------------------------------------------------
+    //  Base class methods
+    // ------------------------------------------------------------------------
+
+    @Override
+    public String getAddress() {
+        return address;
+    }
+
+    @Override
+    public String getHostname() {
+        return address;
+    }
+
+    // ------------------------------------------------------------------------
+    //  utilities
+    // ------------------------------------------------------------------------
+
+    public <T> CompletableFuture<T> futureWithTimeout(long timeoutMillis) {
+        CompletableFuture<T> future = new CompletableFuture<>();
+        executor.schedule(new FutureTimeout(future), timeoutMillis, TimeUnit.MILLISECONDS);
+        return future;
+    }
+
+    // ------------------------------------------------------------------------
+
+    private static final class FutureTimeout implements Runnable {
+
+        private final CompletableFuture<?> promise;
+
+        private FutureTimeout(CompletableFuture<?> promise) {
+            this.promise = promise;
+        }
+
+        @Override
+        public void run() {
+            try {
+                promise.completeExceptionally(new TimeoutException());
+            } catch (Throwable t) {
+                System.err.println("CAUGHT AN ERROR IN THE TEST: " + t.getMessage());
+                t.printStackTrace();
+            }
+        }
+    }
+}
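futureWithTimeout is the piece that turns a scripted null response into a hanging RPC: the returned future is completed exceptionally with a TimeoutException after the given delay, so a caller observes an ExecutionException instead of blocking forever. A short sketch, assuming a gateway built on this base class:

    CompletableFuture<RegistrationResponse> pending = gateway.futureWithTimeout(10);
    try {
        pending.get(1, TimeUnit.SECONDS);
    } catch (ExecutionException e) {
        // e.getCause() is the TimeoutException set by FutureTimeout
    }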
diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/CommonTestUtils.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/CommonTestUtils.java
new file mode 100644
index 00000000..61b64159
--- /dev/null
+++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/CommonTestUtils.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.utils;
+
+import com.alibaba.flink.shuffle.common.functions.SupplierWithException;
+
+import java.lang.management.ManagementFactory;
+import java.lang.management.RuntimeMXBean;
+import java.util.concurrent.TimeoutException;
+
+/** This class contains auxiliary methods for unit tests. */
+public class CommonTestUtils {
+
+    private static final long RETRY_INTERVAL = 100L;
+
+    /**
+     * Gets the classpath with which the current JVM was started.
+     *
+     * @return The classpath with which the current JVM was started.
+     */
+    public static String getCurrentClasspath() {
+        RuntimeMXBean bean = ManagementFactory.getRuntimeMXBean();
+        return bean.getClassPath();
+    }
+
+    public static void waitUntilCondition(
+            SupplierWithException<Boolean, Exception> condition, long timeoutMillis)
+            throws Exception {
+        waitUntilCondition(condition, timeoutMillis, RETRY_INTERVAL);
+    }
+
+    public static void waitUntilCondition(
+            SupplierWithException<Boolean, Exception> condition,
+            long timeoutMillis,
+            long retryIntervalMillis)
+            throws Exception {
+        waitUntilCondition(
+                condition,
+                timeoutMillis,
+                retryIntervalMillis,
+                "Condition was not met in given timeout.");
+    }
+
+    public static void waitUntilCondition(
+            SupplierWithException<Boolean, Exception> condition,
+            long timeoutMillis,
+            String errorMsg)
+            throws Exception {
+        waitUntilCondition(condition, timeoutMillis, RETRY_INTERVAL, errorMsg);
+    }
+
+    public static void waitUntilCondition(
+            SupplierWithException<Boolean, Exception> condition,
+            long timeoutMillis,
+            long retryIntervalMillis,
+            String errorMsg)
+            throws Exception {
+        long timeLeft = timeoutMillis;
+        long deadlineTime = System.nanoTime() / 1000000 + timeoutMillis;
+        while (timeLeft > 0 && !condition.get()) {
+            Thread.sleep(Math.min(retryIntervalMillis, timeLeft));
+            timeLeft = deadlineTime - System.nanoTime() / 1000000;
+        }
+
+        if (timeLeft <= 0) {
+            throw new TimeoutException(errorMsg);
+        }
+    }
+}
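waitUntilCondition polls a boolean supplier until it holds or the deadline passes, which is how the coordinator tests below wait for asynchronous state to converge. A typical call (workerRegistered() is a hypothetical predicate over coordinator state, not part of this patch):

    CommonTestUtils.waitUntilCondition(
            () -> workerRegistered(), // hypothetical test helper
            30_000L,
            "Worker was never registered.");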
diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/EmptyMetaStore.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/EmptyMetaStore.java
new file mode 100644
index 00000000..54d5ded3
--- /dev/null
+++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/EmptyMetaStore.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.utils;
+
+import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionCoordinate;
+import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionStatus;
+import com.alibaba.flink.shuffle.coordinator.worker.metastore.Metastore;
+import com.alibaba.flink.shuffle.core.ids.JobID;
+import com.alibaba.flink.shuffle.core.storage.DataPartitionMeta;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.BiConsumer;
+
+/** An empty meta store implementation. */
+public class EmptyMetaStore implements Metastore {
+
+    @Override
+    public void setPartitionRemovedConsumer(
+            BiConsumer<JobID, DataPartitionCoordinate> partitionRemovedConsumer) {}
+
+    @Override
+    public List<DataPartitionStatus> listDataPartitions() throws Exception {
+        return new ArrayList<>();
+    }
+
+    @Override
+    public void removeReleasingDataPartition(DataPartitionCoordinate dataPartitionCoordinate) {}
+
+    @Override
+    public int getSize() {
+        return 0;
+    }
+
+    @Override
+    public void onPartitionCreated(DataPartitionMeta partitionMeta) throws Exception {}
+
+    @Override
+    public void onPartitionRemoved(DataPartitionMeta partitionMeta) {}
+
+    @Override
+    public void close() throws Exception {}
+}
diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/EmptyPartitionedDataStore.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/EmptyPartitionedDataStore.java
new file mode 100644
index 00000000..7bab7c89
--- /dev/null
+++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/EmptyPartitionedDataStore.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package com.alibaba.flink.shuffle.coordinator.utils; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.core.executor.SingleThreadExecutorPool; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.memory.BufferDispatcher; +import com.alibaba.flink.shuffle.core.storage.DataPartitionMeta; +import com.alibaba.flink.shuffle.core.storage.DataPartitionReadingView; +import com.alibaba.flink.shuffle.core.storage.DataPartitionWritingView; +import com.alibaba.flink.shuffle.core.storage.PartitionedDataStore; +import com.alibaba.flink.shuffle.core.storage.ReadingViewContext; +import com.alibaba.flink.shuffle.core.storage.StorageMeta; +import com.alibaba.flink.shuffle.core.storage.WritingViewContext; + +import javax.annotation.Nullable; + +/** An empty partitioned data store used for tests. */ +public class EmptyPartitionedDataStore implements PartitionedDataStore { + + @Override + public DataPartitionWritingView createDataPartitionWritingView(WritingViewContext context) { + return null; + } + + @Override + public DataPartitionReadingView createDataPartitionReadingView(ReadingViewContext context) { + return null; + } + + @Override + public boolean isDataPartitionConsumable(DataPartitionMeta partitionMeta) { + return false; + } + + @Override + public void addDataPartition(DataPartitionMeta partitionMeta) throws Exception {} + + @Override + public void removeDataPartition(DataPartitionMeta partitionMeta) {} + + @Override + public void releaseDataPartition( + DataSetID dataSetID, DataPartitionID partitionID, @Nullable Throwable throwable) {} + + @Override + public void releaseDataSet(DataSetID dataSetID, @Nullable Throwable throwable) {} + + @Override + public void releaseDataByJobID(JobID jobID, @Nullable Throwable throwable) {} + + @Override + public void shutDown(boolean releaseData) {} + + @Override + public boolean isShutDown() { + return false; + } + + @Override + public Configuration getConfiguration() { + return null; + } + + @Override + public BufferDispatcher getWritingBufferDispatcher() { + return null; + } + + @Override + public BufferDispatcher getReadingBufferDispatcher() { + return null; + } + + @Override + public SingleThreadExecutorPool getExecutorPool(StorageMeta storageMeta) { + return null; + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/EmptyShuffleWorkerGateway.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/EmptyShuffleWorkerGateway.java new file mode 100644 index 00000000..c2e3917c --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/EmptyShuffleWorkerGateway.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.coordinator.utils;
+
+import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerGateway;
+import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerMetrics;
+import com.alibaba.flink.shuffle.core.ids.DataPartitionID;
+import com.alibaba.flink.shuffle.core.ids.DataSetID;
+import com.alibaba.flink.shuffle.core.ids.InstanceID;
+import com.alibaba.flink.shuffle.core.ids.JobID;
+import com.alibaba.flink.shuffle.rpc.message.Acknowledge;
+
+import java.util.concurrent.CompletableFuture;
+
+/** A test empty shuffle worker gateway implementation. */
+public class EmptyShuffleWorkerGateway implements ShuffleWorkerGateway {
+
+    @Override
+    public void heartbeatFromManager(InstanceID managerID) {}
+
+    @Override
+    public void disconnectManager(Exception cause) {}
+
+    @Override
+    public CompletableFuture<Acknowledge> releaseDataPartition(
+            JobID jobID, DataSetID dataSetID, DataPartitionID dataPartitionID) {
+        return CompletableFuture.completedFuture(Acknowledge.get());
+    }
+
+    @Override
+    public CompletableFuture<Acknowledge> removeReleasedDataPartitionMeta(
+            JobID jobID, DataSetID dataSetID, DataPartitionID dataPartitionID) {
+        return CompletableFuture.completedFuture(Acknowledge.get());
+    }
+
+    @Override
+    public CompletableFuture<ShuffleWorkerMetrics> getWorkerMetrics() {
+        return CompletableFuture.completedFuture(null);
+    }
+
+    @Override
+    public String getAddress() {
+        return "";
+    }
+
+    @Override
+    public String getHostname() {
+        return "";
+    }
+}
diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/EnvironmentInformationTest.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/EnvironmentInformationTest.java
new file mode 100644
index 00000000..93507f79
--- /dev/null
+++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/EnvironmentInformationTest.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package com.alibaba.flink.shuffle.coordinator.utils; + +import com.alibaba.flink.shuffle.core.utils.TestLogger; + +import org.junit.Test; +import org.mockito.Mockito; +import org.slf4j.Logger; + +import java.util.Arrays; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** Test for {@link EnvironmentInformation}. */ +public class EnvironmentInformationTest extends TestLogger { + + @Test + public void testJavaMemory() { + try { + long fullHeap = EnvironmentInformation.getMaxJvmHeapMemory(); + long freeWithGC = EnvironmentInformation.getSizeOfFreeHeapMemoryWithDefrag(); + + assertTrue(fullHeap > 0); + assertTrue(freeWithGC >= 0); + + try { + long free = EnvironmentInformation.getSizeOfFreeHeapMemory(); + assertTrue(free >= 0); + } catch (RuntimeException e) { + // this may only occur if the Xmx is not set + assertEquals(Long.MAX_VALUE, EnvironmentInformation.getMaxJvmHeapMemory()); + } + + // we cannot make these assumptions, because the test JVM may grow / shrink during the + // GC + // assertTrue(free <= fullHeap); + // assertTrue(freeWithGC <= fullHeap); + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testEnvironmentMethods() { + try { + assertNotNull(EnvironmentInformation.getJvmStartupOptions()); + assertNotNull(EnvironmentInformation.getJvmStartupOptionsArray()); + assertNotNull(EnvironmentInformation.getJvmVersion()); + assertNotNull(EnvironmentInformation.getRevisionInformation()); + assertNotNull(EnvironmentInformation.getVersion()); + assertNotNull(EnvironmentInformation.getBuildTime()); + assertNotNull(EnvironmentInformation.getBuildTimeString()); + assertNotNull(EnvironmentInformation.getGitCommitId()); + assertNotNull(EnvironmentInformation.getGitCommitIdAbbrev()); + assertNotNull(EnvironmentInformation.getGitCommitTime()); + assertNotNull(EnvironmentInformation.getGitCommitTimeString()); + assertTrue(EnvironmentInformation.getOpenFileHandlesLimit() >= -1); + + if (log.isInfoEnabled()) { + // Visual inspection of the available Environment variables + // To actually see it set "rootLogger.level = INFO" in "log4j2-test.properties" + log.info( + "JvmStartupOptions : {}", + EnvironmentInformation.getJvmStartupOptions()); + log.info( + "JvmStartupOptionsArray : {}", + Arrays.asList(EnvironmentInformation.getJvmStartupOptionsArray())); + log.info("JvmVersion : {}", EnvironmentInformation.getJvmVersion()); + log.info( + "RevisionInformation : {}", + EnvironmentInformation.getRevisionInformation()); + log.info("Version : {}", EnvironmentInformation.getVersion()); + log.info("BuildTime : {}", EnvironmentInformation.getBuildTime()); + log.info( + "BuildTimeString : {}", EnvironmentInformation.getBuildTimeString()); + log.info("GitCommitId : {}", EnvironmentInformation.getGitCommitId()); + log.info( + "GitCommitIdAbbrev : {}", + EnvironmentInformation.getGitCommitIdAbbrev()); + log.info("GitCommitTime : {}", EnvironmentInformation.getGitCommitTime()); + log.info( + "GitCommitTimeString : {}", + EnvironmentInformation.getGitCommitTimeString()); + log.info( + "OpenFileHandlesLimit : {}", + EnvironmentInformation.getOpenFileHandlesLimit()); + } + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testLogEnvironmentInformation() { + try { + Logger mockLogger = Mockito.mock(Logger.class); + EnvironmentInformation.logEnvironmentInfo(mockLogger, "test", 
new String[0]); + EnvironmentInformation.logEnvironmentInfo(mockLogger, "test", null); + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/RandomIDUtils.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/RandomIDUtils.java new file mode 100644 index 00000000..34f79a59 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/RandomIDUtils.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.utils; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; + +/** Utilities to generate id. */ +public class RandomIDUtils { + + public static JobID randomJobId() { + return new JobID(CommonUtils.randomBytes(16)); + } + + public static DataSetID randomDataSetId() { + return new DataSetID(CommonUtils.randomBytes(16)); + } + + public static MapPartitionID randomMapPartitionId() { + return new MapPartitionID(CommonUtils.randomBytes(16)); + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/RecordingHeartbeatServices.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/RecordingHeartbeatServices.java new file mode 100644 index 00000000..8bd6c3a8 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/RecordingHeartbeatServices.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.coordinator.utils; + +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatListener; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatManager; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatManagerImpl; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatManagerSenderImpl; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatServices; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatTarget; +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.rpc.executor.ScheduledExecutor; + +import org.slf4j.Logger; + +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; + +/** Special {@link HeartbeatServices} which records the unmonitored targets. */ +public class RecordingHeartbeatServices extends HeartbeatServices { + + private final BlockingQueue unmonitoredTargets; + + private final BlockingQueue monitoredTargets; + + public RecordingHeartbeatServices(long heartbeatInterval, long heartbeatTimeout) { + super(heartbeatInterval, heartbeatTimeout); + + this.unmonitoredTargets = new ArrayBlockingQueue<>(1); + this.monitoredTargets = new ArrayBlockingQueue<>(1); + } + + @Override + public HeartbeatManager createHeartbeatManager( + InstanceID instanceID, + HeartbeatListener heartbeatListener, + ScheduledExecutor mainThreadExecutor, + Logger log) { + return new RecordingHeartbeatManagerImpl<>( + heartbeatTimeout, + instanceID, + heartbeatListener, + mainThreadExecutor, + log, + unmonitoredTargets, + monitoredTargets); + } + + @Override + public HeartbeatManager createHeartbeatManagerSender( + InstanceID instanceID, + HeartbeatListener heartbeatListener, + ScheduledExecutor mainThreadExecutor, + Logger log) { + return new RecordingHeartbeatManagerSenderImpl<>( + heartbeatInterval, + heartbeatTimeout, + instanceID, + heartbeatListener, + mainThreadExecutor, + log, + unmonitoredTargets, + monitoredTargets); + } + + public BlockingQueue getUnmonitoredTargets() { + return unmonitoredTargets; + } + + public BlockingQueue getMonitoredTargets() { + return monitoredTargets; + } + + /** {@link HeartbeatManagerImpl} which records the unmonitored targets. */ + private static final class RecordingHeartbeatManagerImpl + extends HeartbeatManagerImpl { + + private final BlockingQueue unmonitoredTargets; + + private final BlockingQueue monitoredTargets; + + public RecordingHeartbeatManagerImpl( + long heartbeatTimeoutIntervalMs, + InstanceID ownInstanceID, + HeartbeatListener heartbeatListener, + ScheduledExecutor mainThreadExecutor, + Logger log, + BlockingQueue unmonitoredTargets, + BlockingQueue monitoredTargets) { + super( + heartbeatTimeoutIntervalMs, + ownInstanceID, + heartbeatListener, + mainThreadExecutor, + log); + this.unmonitoredTargets = unmonitoredTargets; + this.monitoredTargets = monitoredTargets; + } + + @Override + public void unmonitorTarget(InstanceID instanceID) { + super.unmonitorTarget(instanceID); + unmonitoredTargets.offer(instanceID); + } + + @Override + public void monitorTarget(InstanceID instanceID, HeartbeatTarget heartbeatTarget) { + super.monitorTarget(instanceID, heartbeatTarget); + monitoredTargets.offer(instanceID); + } + } + + /** {@link HeartbeatManagerSenderImpl} which records the unmonitored targets. 
*/ + private static final class RecordingHeartbeatManagerSenderImpl + extends HeartbeatManagerSenderImpl { + + private final BlockingQueue unmonitoredTargets; + + private final BlockingQueue monitoredTargets; + + public RecordingHeartbeatManagerSenderImpl( + long heartbeatPeriod, + long heartbeatTimeout, + InstanceID ownInstanceID, + HeartbeatListener heartbeatListener, + ScheduledExecutor mainThreadExecutor, + Logger log, + BlockingQueue unmonitoredTargets, + BlockingQueue monitoredTargets) { + super( + heartbeatPeriod, + heartbeatTimeout, + ownInstanceID, + heartbeatListener, + mainThreadExecutor, + log); + this.unmonitoredTargets = unmonitoredTargets; + this.monitoredTargets = monitoredTargets; + } + + @Override + public void unmonitorTarget(InstanceID instanceID) { + super.unmonitorTarget(instanceID); + unmonitoredTargets.offer(instanceID); + } + + @Override + public void monitorTarget(InstanceID instanceID, HeartbeatTarget heartbeatTarget) { + super.monitorTarget(instanceID, heartbeatTarget); + monitoredTargets.offer(instanceID); + } + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/TestingFatalErrorHandler.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/TestingFatalErrorHandler.java new file mode 100644 index 00000000..59468b80 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/TestingFatalErrorHandler.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.utils; + +import com.alibaba.flink.shuffle.common.handler.FatalErrorHandler; +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** + * Testing fatal error handler which records the occurred exceptions during the execution of the + * tests. Captured exceptions are thrown as a {@link TestingException}. 
+ */
+public class TestingFatalErrorHandler implements FatalErrorHandler {
+    private static final Logger LOG = LoggerFactory.getLogger(TestingFatalErrorHandler.class);
+    private CompletableFuture<Throwable> errorFuture;
+
+    public TestingFatalErrorHandler() {
+        errorFuture = new CompletableFuture<>();
+    }
+
+    public synchronized void rethrowError() throws TestingException {
+        final Throwable throwable = getException();
+
+        if (throwable != null) {
+            throw new TestingException(throwable);
+        }
+    }
+
+    public synchronized boolean hasExceptionOccurred() {
+        return errorFuture.isDone();
+    }
+
+    @Nullable
+    public synchronized Throwable getException() {
+        if (errorFuture.isDone()) {
+            Throwable throwable;
+
+            try {
+                throwable = errorFuture.get();
+            } catch (InterruptedException ie) {
+                Thread.currentThread().interrupt();
+                throw new RuntimeException(
+                        "This should never happen since the future was completed.");
+            } catch (ExecutionException e) {
+                throwable = ExceptionUtils.stripException(e, ExecutionException.class);
+            }
+
+            return throwable;
+        } else {
+            return null;
+        }
+    }
+
+    public synchronized CompletableFuture<Throwable> getErrorFuture() {
+        return errorFuture;
+    }
+
+    public synchronized void clearError() {
+        errorFuture = new CompletableFuture<>();
+    }
+
+    @Override
+    public synchronized void onFatalError(@Nonnull Throwable exception) {
+        LOG.error("OnFatalError:", exception);
+
+        if (!errorFuture.complete(exception)) {
+            final Throwable throwable = checkNotNull(getException());
+            throwable.addSuppressed(exception);
+        }
+    }
+
+    // ------------------------------------------------------------------
+    //  static utility classes
+    // ------------------------------------------------------------------
+
+    private static final class TestingException extends Exception {
+        public TestingException(String message) {
+            super(message);
+        }
+
+        public TestingException(String message, Throwable cause) {
+            super(message, cause);
+        }
+
+        public TestingException(Throwable cause) {
+            super(cause);
+        }
+
+        private static final long serialVersionUID = -4648195335470914498L;
+    }
+}
diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/TestingShuffleManagerGateway.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/TestingShuffleManagerGateway.java
new file mode 100644
index 00000000..d20029ff
--- /dev/null
+++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/TestingShuffleManagerGateway.java
@@ -0,0 +1,366 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package com.alibaba.flink.shuffle.coordinator.utils; + +import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionStatus; +import com.alibaba.flink.shuffle.coordinator.manager.JobDataPartitionDistribution; +import com.alibaba.flink.shuffle.coordinator.manager.ManagerToJobHeartbeatPayload; +import com.alibaba.flink.shuffle.coordinator.manager.RegistrationSuccess; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleManagerGateway; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleResource; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleWorkerRegistration; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleWorkerRegistrationSuccess; +import com.alibaba.flink.shuffle.coordinator.manager.WorkerToManagerHeartbeatPayload; +import com.alibaba.flink.shuffle.coordinator.manager.assignmenttracker.ChangedWorkerStatus; +import com.alibaba.flink.shuffle.coordinator.registration.RegistrationResponse; +import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerMetrics; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.ids.RegistrationID; +import com.alibaba.flink.shuffle.rpc.message.Acknowledge; +import com.alibaba.flink.shuffle.utils.QuadFunction; +import com.alibaba.flink.shuffle.utils.TriFunction; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.lang3.tuple.Triple; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Function; + +/** A testing {@link ShuffleManagerGateway} implementation. 
*/ +public class TestingShuffleManagerGateway implements ShuffleManagerGateway { + + private final String address; + + private final String hostname; + + private final UUID shuffleManagerId; + + private final InstanceID instanceID; + + private volatile Function> + registerShuffleWorkerConsumer; + + private volatile BiFunction< + Pair, + Triple, + CompletableFuture> + workerReportDataPartitionReleasedConsumer; + + private volatile TriFunction< + InstanceID, + RegistrationID, + List, + CompletableFuture> + reportShuffleDataStatusConsumer; + + private BiConsumer + heartbeatFromShuffleWorkerConsumer; + + private BiFunction> + disconnectShuffleWorkerConsumer; + + private Function> registerClientConsumer; + + private Function> unregisterClientConsumer; + + private BiFunction, CompletableFuture> + heartbeatFromClientConsumer; + + private QuadFunction< + JobID, DataSetID, MapPartitionID, Integer, CompletableFuture> + allocateShuffleResourceConsumer; + + private TriFunction> + releaseShuffleResourceConsumer; + + public TestingShuffleManagerGateway() { + this("localhost/" + UUID.randomUUID(), "localhost", UUID.randomUUID(), new InstanceID()); + } + + public TestingShuffleManagerGateway( + String address, String hostname, UUID shuffleManagerId, InstanceID instanceID) { + this.address = address; + this.hostname = hostname; + this.shuffleManagerId = shuffleManagerId; + this.instanceID = instanceID; + } + + // ------------------------------ setters ------------------------------------------ + + public void setRegisterShuffleWorkerConsumer( + Function> + registerShuffleWorkerConsumer) { + this.registerShuffleWorkerConsumer = registerShuffleWorkerConsumer; + } + + public void setWorkerReportDataPartitionReleasedConsumer( + BiFunction< + Pair, + Triple, + CompletableFuture> + workerReportDataPartitionReleasedConsumer) { + this.workerReportDataPartitionReleasedConsumer = workerReportDataPartitionReleasedConsumer; + } + + public void setReportShuffleDataStatusConsumer( + TriFunction< + InstanceID, + RegistrationID, + List, + CompletableFuture> + reportShuffleDataStatusConsumer) { + this.reportShuffleDataStatusConsumer = reportShuffleDataStatusConsumer; + } + + public void setHeartbeatFromShuffleWorkerConsumer( + BiConsumer + heartbeatFromShuffleWorkerConsumer) { + this.heartbeatFromShuffleWorkerConsumer = heartbeatFromShuffleWorkerConsumer; + } + + public void setDisconnectShuffleWorkerConsumer( + BiFunction> + disconnectShuffleWorkerConsumer) { + this.disconnectShuffleWorkerConsumer = disconnectShuffleWorkerConsumer; + } + + public void setRegisterClientConsumer( + Function> registerClientConsumer) { + this.registerClientConsumer = registerClientConsumer; + } + + public void setUnregisterClientConsumer( + Function> unregisterClientConsumer) { + this.unregisterClientConsumer = unregisterClientConsumer; + } + + public void setHeartbeatFromClientConsumer( + BiFunction, CompletableFuture> + heartbeatFromClientConsumer) { + this.heartbeatFromClientConsumer = heartbeatFromClientConsumer; + } + + public void setAllocateShuffleResourceConsumer( + QuadFunction< + JobID, + DataSetID, + MapPartitionID, + Integer, + CompletableFuture> + allocateShuffleResourceConsumer) { + this.allocateShuffleResourceConsumer = allocateShuffleResourceConsumer; + } + + public void setReleaseShuffleResourceConsumer( + TriFunction> + releaseShuffleResourceConsumer) { + this.releaseShuffleResourceConsumer = releaseShuffleResourceConsumer; + } + + // ------------------------------ shuffle worker gateway --------------------------- + + 
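+    // Note: each gateway method below first checks whether the test installed a consumer via
+    // the setters above and delegates to it; otherwise it answers with a benign default. A
+    // test would typically stub behavior like this (sketch using this class's own API):
+    //
+    //   TestingShuffleManagerGateway smGateway = new TestingShuffleManagerGateway();
+    //   smGateway.setRegisterShuffleWorkerConsumer(
+    //           registration ->
+    //                   CompletableFuture.completedFuture(
+    //                           new ShuffleWorkerRegistrationSuccess(
+    //                                   new RegistrationID(), smGateway.getInstanceID())));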
@Override + public CompletableFuture registerWorker( + ShuffleWorkerRegistration workerRegistration) { + final Function> + currentConsumer = registerShuffleWorkerConsumer; + if (currentConsumer != null) { + return currentConsumer.apply(workerRegistration); + } + + return CompletableFuture.completedFuture( + new ShuffleWorkerRegistrationSuccess(new RegistrationID(), instanceID)); + } + + @Override + public CompletableFuture workerReportDataPartitionReleased( + InstanceID workerID, + RegistrationID registrationID, + JobID jobID, + DataSetID dataSetID, + DataPartitionID dataPartitionID) { + final BiFunction< + Pair, + Triple, + CompletableFuture> + currentConsumer = workerReportDataPartitionReleasedConsumer; + if (currentConsumer != null) { + return currentConsumer.apply( + Pair.of(workerID, registrationID), + Triple.of(jobID, dataSetID, dataPartitionID)); + } + + return CompletableFuture.completedFuture(Acknowledge.get()); + } + + @Override + public CompletableFuture reportDataPartitionStatus( + InstanceID workerID, + RegistrationID registrationID, + List dataPartitionStatuses) { + final TriFunction< + InstanceID, + RegistrationID, + List, + CompletableFuture> + currentConsumer = reportShuffleDataStatusConsumer; + if (currentConsumer != null) { + return currentConsumer.apply(workerID, registrationID, dataPartitionStatuses); + } + + return CompletableFuture.completedFuture(Acknowledge.get()); + } + + @Override + public void heartbeatFromWorker(InstanceID workerID, WorkerToManagerHeartbeatPayload payload) { + BiConsumer currentConsumer = + heartbeatFromShuffleWorkerConsumer; + if (currentConsumer != null) { + currentConsumer.accept(workerID, payload); + } + } + + @Override + public CompletableFuture disconnectWorker(InstanceID workerID, Exception cause) { + BiFunction> currentConsumer = + disconnectShuffleWorkerConsumer; + if (currentConsumer != null) { + return currentConsumer.apply(workerID, cause); + } + + return CompletableFuture.completedFuture(Acknowledge.get()); + } + + @Override + public CompletableFuture registerClient( + JobID jobID, InstanceID clientID) { + Function> currentConsumer = + registerClientConsumer; + if (currentConsumer != null) { + return currentConsumer.apply(jobID); + } + + return CompletableFuture.completedFuture(new RegistrationSuccess(getInstanceID())); + } + + @Override + public CompletableFuture unregisterClient(JobID jobID, InstanceID clientID) { + Function> currentConsumer = unregisterClientConsumer; + if (currentConsumer != null) { + return unregisterClientConsumer.apply(jobID); + } + + return CompletableFuture.completedFuture(Acknowledge.get()); + } + + @Override + public CompletableFuture heartbeatFromClient( + JobID jobID, InstanceID clientID, Set cachedWorkerList) { + BiFunction, CompletableFuture> + currentConsumer = heartbeatFromClientConsumer; + if (currentConsumer != null) { + return currentConsumer.apply(jobID, cachedWorkerList); + } + + return CompletableFuture.completedFuture( + new ManagerToJobHeartbeatPayload( + clientID, + new ChangedWorkerStatus(Collections.emptyList(), Collections.emptyMap()))); + } + + @Override + public CompletableFuture requestShuffleResource( + JobID jobID, + InstanceID clientID, + DataSetID dataSetID, + MapPartitionID mapPartitionID, + int numberOfConsumers, + String dataPartitionFactoryName) { + QuadFunction> + currentConsumer = allocateShuffleResourceConsumer; + if (currentConsumer != null) { + return currentConsumer.apply(jobID, dataSetID, mapPartitionID, numberOfConsumers); + } + + return 
CompletableFuture.completedFuture(null); + } + + @Override + public CompletableFuture releaseShuffleResource( + JobID jobID, InstanceID clientID, DataSetID dataSetID, MapPartitionID mapPartitionID) { + TriFunction> + currentConsumer = releaseShuffleResourceConsumer; + if (currentConsumer != null) { + currentConsumer.apply(jobID, dataSetID, mapPartitionID); + } + + return CompletableFuture.completedFuture(Acknowledge.get()); + } + + @Override + public CompletableFuture getNumberOfRegisteredWorkers() { + return CompletableFuture.completedFuture(null); + } + + @Override + public CompletableFuture> getShuffleWorkerMetrics() { + return CompletableFuture.completedFuture(null); + } + + @Override + public CompletableFuture> listJobs() { + throw new UnsupportedOperationException("Not supported currently."); + } + + @Override + public CompletableFuture getJobDataPartitionDistribution( + JobID jobID) { + throw new UnsupportedOperationException("Not supported currently."); + } + + public InstanceID getInstanceID() { + return instanceID; + } + + @Override + public UUID getFencingToken() { + return shuffleManagerId; + } + + @Override + public String getAddress() { + return address; + } + + @Override + public String getHostname() { + return hostname; + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/TestingUtils.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/TestingUtils.java new file mode 100644 index 00000000..98ce2909 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/utils/TestingUtils.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.utils; + +import com.alibaba.flink.shuffle.rpc.executor.ScheduledExecutor; +import com.alibaba.flink.shuffle.rpc.executor.ScheduledExecutorServiceAdapter; + +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; + +/** The utils for testing coordinators. 
*/ +public class TestingUtils { + + private static ScheduledExecutorService scheduledExecutor; + + public static synchronized ScheduledExecutor defaultScheduledExecutor() { + if (scheduledExecutor == null || scheduledExecutor.isShutdown()) { + scheduledExecutor = Executors.newSingleThreadScheduledExecutor(); + } + return new ScheduledExecutorServiceAdapter(scheduledExecutor); + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/worker/LocalShuffleWorkerLocation.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/worker/LocalShuffleWorkerLocation.java new file mode 100644 index 00000000..24aec307 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/worker/LocalShuffleWorkerLocation.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.worker; + +import com.alibaba.flink.shuffle.core.ids.InstanceID; + +/** Dummy local shuffle worker unresolved location for testing purposes. */ +public class LocalShuffleWorkerLocation extends ShuffleWorkerLocation { + private static final long serialVersionUID = 1L; + + public LocalShuffleWorkerLocation() { + super("localhost", 42, new InstanceID()); + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorkerRunnerTest.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorkerRunnerTest.java new file mode 100644 index 00000000..6dc58fad --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorkerRunnerTest.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.coordinator.worker; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.core.config.ClusterOptions; +import com.alibaba.flink.shuffle.core.config.ManagerOptions; +import com.alibaba.flink.shuffle.core.config.MemoryOptions; +import com.alibaba.flink.shuffle.core.config.StorageOptions; +import com.alibaba.flink.shuffle.core.config.TransferOptions; +import com.alibaba.flink.shuffle.core.config.WorkerOptions; + +import org.junit.After; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.rules.Timeout; + +import java.io.File; +import java.io.IOException; +import java.time.Duration; +import java.util.Random; + +import static org.junit.Assert.assertEquals; + +/** Tests the behavior of the {@link ShuffleWorkerRunner}. */ +public class ShuffleWorkerRunnerTest { + + @ClassRule public static final TemporaryFolder TEMP_FOLDER = new TemporaryFolder(); + + @Rule public final Timeout timeout = Timeout.seconds(30); + + private ShuffleWorkerRunner shuffleWorkerRunner; + + @After + public void after() throws Exception { + System.setSecurityManager(null); + if (shuffleWorkerRunner != null) { + shuffleWorkerRunner.close(); + } + } + + @Test + public void testShouldShutdownOnFatalError() throws Exception { + Configuration configuration = createConfiguration(); + // very high timeout, to ensure that we don't fail because of registration timeouts + configuration.setDuration(ClusterOptions.REGISTRATION_TIMEOUT, Duration.ofHours(42)); + shuffleWorkerRunner = createShuffleWorkerRunner(configuration); + + shuffleWorkerRunner.onFatalError(new RuntimeException("Test Exception")); + + assertEquals( + ShuffleWorkerRunner.Result.FAILURE, + shuffleWorkerRunner.getTerminationFuture().get()); + } + + @Test + public void testShouldShutdownIfRegistrationWithShuffleManagerFails() throws Exception { + Configuration configuration = createConfiguration(); + configuration.setDuration(ClusterOptions.REGISTRATION_TIMEOUT, Duration.ofMillis(10)); + shuffleWorkerRunner = createShuffleWorkerRunner(configuration); + + assertEquals( + ShuffleWorkerRunner.Result.FAILURE, + shuffleWorkerRunner.getTerminationFuture().get()); + } + + private static Configuration createConfiguration() throws IOException { + Configuration configuration = new Configuration(); + File baseDir = TEMP_FOLDER.newFolder(); + String basePath = baseDir.getAbsolutePath() + "/"; + + configuration.setString(ManagerOptions.RPC_ADDRESS, "localhost"); + configuration.setString(WorkerOptions.HOST, "localhost"); + configuration.setString(StorageOptions.STORAGE_LOCAL_DATA_DIRS, basePath); + + // choose random worker port + Random random = new Random(System.currentTimeMillis()); + int nextPort = random.nextInt(30000) + 20000; + configuration.setInteger(TransferOptions.SERVER_DATA_PORT, nextPort); + + return configuration; + } + + private static ShuffleWorkerRunner createShuffleWorkerRunner(Configuration configuration) + throws Exception { + configuration.setMemorySize( + MemoryOptions.MEMORY_SIZE_FOR_DATA_READING, MemoryOptions.MIN_VALID_MEMORY_SIZE); + configuration.setMemorySize( + MemoryOptions.MEMORY_SIZE_FOR_DATA_WRITING, MemoryOptions.MIN_VALID_MEMORY_SIZE); + ShuffleWorkerRunner shuffleWorkerRunner = new ShuffleWorkerRunner(configuration); + shuffleWorkerRunner.start(); + return shuffleWorkerRunner; + } +} diff --git 
a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorkerTest.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorkerTest.java new file mode 100644 index 00000000..2e7edfb7 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/worker/ShuffleWorkerTest.java @@ -0,0 +1,505 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.worker; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatServices; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatServicesUtils; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.highavailability.TestingHaServices; +import com.alibaba.flink.shuffle.coordinator.leaderretrieval.SettableLeaderRetrievalService; +import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionCoordinate; +import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionStatus; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleWorkerRegistrationSuccess; +import com.alibaba.flink.shuffle.coordinator.manager.WorkerToManagerHeartbeatPayload; +import com.alibaba.flink.shuffle.coordinator.utils.EmptyMetaStore; +import com.alibaba.flink.shuffle.coordinator.utils.EmptyPartitionedDataStore; +import com.alibaba.flink.shuffle.coordinator.utils.RandomIDUtils; +import com.alibaba.flink.shuffle.coordinator.utils.RecordingHeartbeatServices; +import com.alibaba.flink.shuffle.coordinator.utils.TestingFatalErrorHandler; +import com.alibaba.flink.shuffle.coordinator.utils.TestingShuffleManagerGateway; +import com.alibaba.flink.shuffle.coordinator.worker.metastore.LocalShuffleMetaStore; +import com.alibaba.flink.shuffle.coordinator.worker.metastore.LocalShuffleMetaStoreTest; +import com.alibaba.flink.shuffle.core.config.TransferOptions; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.RegistrationID; +import com.alibaba.flink.shuffle.core.storage.DataPartitionMeta; +import com.alibaba.flink.shuffle.core.storage.PartitionedDataStore; +import com.alibaba.flink.shuffle.rpc.message.Acknowledge; +import com.alibaba.flink.shuffle.rpc.test.TestingRpcService; +import com.alibaba.flink.shuffle.transfer.NettyConfig; +import com.alibaba.flink.shuffle.transfer.NettyServer; + +import org.apache.commons.lang3.tuple.Triple; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import 
org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Queue; +import java.util.Random; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** Tests the behavior of the {@link ShuffleWorker}. */ +public class ShuffleWorkerTest { + + @ClassRule public static final TemporaryFolder TEMP_FOLDER = new TemporaryFolder(); + + private static final long timeout = 10000L; + + private TestingRpcService rpcService; + + private Configuration configuration; + + private SettableLeaderRetrievalService shuffleManagerLeaderRetrieveService; + + private TestingHaServices haServices; + + private ShuffleWorkerLocation shuffleWorkerLocation; + + private TestingFatalErrorHandler testingFatalErrorHandler; + + private PartitionedDataStore dataStore; + + private NettyServer nettyServer; + + @Before + public void setup() throws IOException { + rpcService = new TestingRpcService(); + + configuration = new Configuration(); + + shuffleManagerLeaderRetrieveService = new SettableLeaderRetrievalService(); + haServices = new TestingHaServices(); + haServices.setShuffleManagerLeaderRetrieveService(shuffleManagerLeaderRetrieveService); + + shuffleWorkerLocation = new LocalShuffleWorkerLocation(); + + testingFatalErrorHandler = new TestingFatalErrorHandler(); + + dataStore = new EmptyPartitionedDataStore(); + + // choose random worker port + Random random = new Random(System.currentTimeMillis()); + int nextPort = random.nextInt(30000) + 20000; + configuration.setInteger(TransferOptions.SERVER_DATA_PORT, nextPort); + NettyConfig nettyConfig = new NettyConfig(configuration); + nettyServer = new NettyServer(dataStore, nettyConfig); + + // By default, we do not start netty server. 
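+ // A test that needs a live transfer layer could start it explicitly here, + // e.g. nettyServer.start(); (hedged: the start method name is an assumption, + // only shutdown() is exercised in this file).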
+ } + + @After + public void teardown() throws Exception { + if (rpcService != null) { + rpcService.stopService().get(timeout, TimeUnit.MILLISECONDS); + rpcService = null; + } + + if (nettyServer != null) { + nettyServer.shutdown(); + } + } + + @Test + public void testRegisterAndReportDataPartitionStatus() throws Exception { + TestingShuffleManagerGateway smGateway = new TestingShuffleManagerGateway(); + + CompletableFuture<InstanceID> registrationFuture = new CompletableFuture<>(); + ShuffleWorkerRegistrationSuccess registrationSuccess = + new ShuffleWorkerRegistrationSuccess( + new RegistrationID(), smGateway.getInstanceID()); + smGateway.setRegisterShuffleWorkerConsumer( + registration -> { + registrationFuture.complete(registration.getWorkerID()); + return CompletableFuture.completedFuture(registrationSuccess); + }); + + List<DataPartitionStatus> reportedStatuses = + Arrays.asList( + new DataPartitionStatus( + RandomIDUtils.randomJobId(), + new DataPartitionCoordinate( + RandomIDUtils.randomDataSetId(), + RandomIDUtils.randomMapPartitionId()), + false), + new DataPartitionStatus( + RandomIDUtils.randomJobId(), + new DataPartitionCoordinate( + RandomIDUtils.randomDataSetId(), + RandomIDUtils.randomMapPartitionId()), + true)); + CompletableFuture<List<DataPartitionStatus>> reportedDataPartitionsFuture = + new CompletableFuture<>(); + smGateway.setReportShuffleDataStatusConsumer( + (resourceID, instanceID, dataPartitionStatuses) -> { + reportedDataPartitionsFuture.complete(dataPartitionStatuses); + return CompletableFuture.completedFuture(Acknowledge.get()); + }); + + rpcService.registerGateway(smGateway.getAddress(), smGateway); + + try (ShuffleWorker shuffleWorker = + new ShuffleWorker( + rpcService, + ShuffleWorkerConfiguration.fromConfiguration(configuration), + haServices, + HeartbeatServicesUtils.createManagerWorkerHeartbeatServices(configuration), + testingFatalErrorHandler, + shuffleWorkerLocation, + new EmptyMetaStore() { + @Override + public List<DataPartitionStatus> listDataPartitions() throws Exception { + return reportedStatuses; + } + }, + dataStore, + nettyServer)) { + shuffleWorker.start(); + shuffleManagerLeaderRetrieveService.notifyListener( + new LeaderInformation(smGateway.getFencingToken(), smGateway.getAddress())); + assertThat( + registrationFuture.get(timeout, TimeUnit.MILLISECONDS), + equalTo(shuffleWorkerLocation.getWorkerID())); + + assertThat( + reportedDataPartitionsFuture.get(timeout, TimeUnit.MILLISECONDS), + equalTo(reportedStatuses)); + } + } + + @Test + public void testHeartbeatTimeoutWithShuffleManager() throws Exception { + TestingShuffleManagerGateway smGateway = new TestingShuffleManagerGateway(); + + CompletableFuture<InstanceID> registrationFuture = new CompletableFuture<>(); + CountDownLatch registrationAttempts = new CountDownLatch(2); + ShuffleWorkerRegistrationSuccess registrationSuccess = + new ShuffleWorkerRegistrationSuccess( + new RegistrationID(), smGateway.getInstanceID()); + smGateway.setRegisterShuffleWorkerConsumer( + registration -> { + registrationFuture.complete(registration.getWorkerID()); + registrationAttempts.countDown(); + return CompletableFuture.completedFuture(registrationSuccess); + }); + + CompletableFuture<InstanceID> shuffleWorkerDisconnectFuture = new CompletableFuture<>(); + smGateway.setDisconnectShuffleWorkerConsumer( + (resourceID, e) -> { + shuffleWorkerDisconnectFuture.complete(resourceID); + return CompletableFuture.completedFuture(Acknowledge.get()); + }); + + rpcService.registerGateway(smGateway.getAddress(), smGateway); + + HeartbeatServices heartbeatServices = new HeartbeatServices(1L, 3L); + try (ShuffleWorker
shuffleWorker = + new ShuffleWorker( + rpcService, + ShuffleWorkerConfiguration.fromConfiguration(configuration), + haServices, + heartbeatServices, + testingFatalErrorHandler, + shuffleWorkerLocation, + new EmptyMetaStore(), + dataStore, + nettyServer)) { + shuffleWorker.start(); + shuffleManagerLeaderRetrieveService.notifyListener( + new LeaderInformation(smGateway.getFencingToken(), smGateway.getAddress())); + assertThat( + registrationFuture.get(timeout, TimeUnit.MILLISECONDS), + equalTo(shuffleWorkerLocation.getWorkerID())); + + assertThat( + shuffleWorkerDisconnectFuture.get(timeout, TimeUnit.MILLISECONDS), + equalTo(shuffleWorkerLocation.getWorkerID())); + + assertTrue( + "The shuffle worker should try to reconnect to the shuffle manager", + registrationAttempts.await(timeout, TimeUnit.SECONDS)); + } + } + + @Test + public void testHeartbeatReporting() throws Exception { + TestingShuffleManagerGateway smGateway = new TestingShuffleManagerGateway(); + + // Registration future + CompletableFuture<InstanceID> registrationFuture = new CompletableFuture<>(); + CountDownLatch registrationAttempts = new CountDownLatch(2); + ShuffleWorkerRegistrationSuccess registrationSuccess = + new ShuffleWorkerRegistrationSuccess( + new RegistrationID(), smGateway.getInstanceID()); + smGateway.setRegisterShuffleWorkerConsumer( + registration -> { + registrationFuture.complete(registration.getWorkerID()); + registrationAttempts.countDown(); + return CompletableFuture.completedFuture(registrationSuccess); + }); + + // Initial report future + CompletableFuture<List<DataPartitionStatus>> reportedDataPartitionsFuture = + new CompletableFuture<>(); + smGateway.setReportShuffleDataStatusConsumer( + (resourceID, instanceID, dataPartitionStatuses) -> { + reportedDataPartitionsFuture.complete(dataPartitionStatuses); + return CompletableFuture.completedFuture(Acknowledge.get()); + }); + + // Heartbeat future + CompletableFuture<WorkerToManagerHeartbeatPayload> heartbeatPayloadCompletableFuture = + new CompletableFuture<>(); + smGateway.setHeartbeatFromShuffleWorkerConsumer( + (resourceID, shuffleWorkerToManagerHeartbeatPayload) -> + heartbeatPayloadCompletableFuture.complete( + shuffleWorkerToManagerHeartbeatPayload)); + + rpcService.registerGateway(smGateway.getAddress(), smGateway); + + List<DataPartitionStatus> reportedStatuses1 = + Arrays.asList( + new DataPartitionStatus( + RandomIDUtils.randomJobId(), + new DataPartitionCoordinate( + RandomIDUtils.randomDataSetId(), + RandomIDUtils.randomMapPartitionId()), + false), + new DataPartitionStatus( + RandomIDUtils.randomJobId(), + new DataPartitionCoordinate( + RandomIDUtils.randomDataSetId(), + RandomIDUtils.randomMapPartitionId()), + true)); + List<DataPartitionStatus> reportedStatuses2 = + Arrays.asList( + new DataPartitionStatus( + RandomIDUtils.randomJobId(), + new DataPartitionCoordinate( + RandomIDUtils.randomDataSetId(), + RandomIDUtils.randomMapPartitionId()), + false), + new DataPartitionStatus( + RandomIDUtils.randomJobId(), + new DataPartitionCoordinate( + RandomIDUtils.randomDataSetId(), + RandomIDUtils.randomMapPartitionId()), + true)); + Queue<List<DataPartitionStatus>> reportedStatusesQueue = + new ArrayDeque<>(Arrays.asList(reportedStatuses1, reportedStatuses2)); + + HeartbeatServices heartbeatServices = new HeartbeatServices(1000L, 2000L); + try (ShuffleWorker shuffleWorker = + new ShuffleWorker( + rpcService, + ShuffleWorkerConfiguration.fromConfiguration(configuration), + haServices, + heartbeatServices, + testingFatalErrorHandler, + shuffleWorkerLocation, + new EmptyMetaStore() { + @Override + public List<DataPartitionStatus> listDataPartitions() throws Exception { + return reportedStatusesQueue.poll(); + } + },
+ dataStore, + nettyServer)) { + shuffleWorker.start(); + shuffleManagerLeaderRetrieveService.notifyListener( + new LeaderInformation(smGateway.getFencingToken(), smGateway.getAddress())); + assertThat( + registrationFuture.get(timeout, TimeUnit.MILLISECONDS), + equalTo(shuffleWorkerLocation.getWorkerID())); + + assertThat( + reportedDataPartitionsFuture.get(timeout, TimeUnit.MILLISECONDS), + equalTo(reportedStatuses1)); + + ShuffleWorkerGateway shuffleWorkerGateway = + shuffleWorker.getSelfGateway(ShuffleWorkerGateway.class); + shuffleWorkerGateway.heartbeatFromManager(smGateway.getInstanceID()); + assertThat( + heartbeatPayloadCompletableFuture + .get(timeout, TimeUnit.MILLISECONDS) + .getDataPartitionStatuses(), + equalTo(reportedStatuses2)); + } + } + + @Test + public void testUnMonitorShuffleManagerOnLeadershipRevoked() throws Exception { + TestingShuffleManagerGateway smGateway = new TestingShuffleManagerGateway(); + + rpcService.registerGateway(smGateway.getAddress(), smGateway); + + RecordingHeartbeatServices heartbeatServices = new RecordingHeartbeatServices(1L, 100000L); + try (ShuffleWorker shuffleWorker = + new ShuffleWorker( + rpcService, + ShuffleWorkerConfiguration.fromConfiguration(configuration), + haServices, + heartbeatServices, + testingFatalErrorHandler, + shuffleWorkerLocation, + new EmptyMetaStore(), + dataStore, + nettyServer)) { + shuffleWorker.start(); + + BlockingQueue<InstanceID> monitoredTargets = heartbeatServices.getMonitoredTargets(); + BlockingQueue<InstanceID> unmonitoredTargets = + heartbeatServices.getUnmonitoredTargets(); + + shuffleManagerLeaderRetrieveService.notifyListener( + new LeaderInformation(smGateway.getFencingToken(), smGateway.getAddress())); + assertThat( + monitoredTargets.poll(timeout, TimeUnit.MILLISECONDS), + equalTo(smGateway.getInstanceID())); + + shuffleManagerLeaderRetrieveService.notifyListener(LeaderInformation.empty()); + assertThat( + unmonitoredTargets.poll(timeout, TimeUnit.MILLISECONDS), + equalTo(smGateway.getInstanceID())); + } + } + + @Test + public void testReconnectionAttemptIfExplicitlyDisconnected() throws Exception { + TestingShuffleManagerGateway smGateway = new TestingShuffleManagerGateway(); + + BlockingQueue<InstanceID> registrationQueue = new ArrayBlockingQueue<>(1); + ShuffleWorkerRegistrationSuccess registrationSuccess = + new ShuffleWorkerRegistrationSuccess( + new RegistrationID(), smGateway.getInstanceID()); + smGateway.setRegisterShuffleWorkerConsumer( + registration -> { + registrationQueue.add(registration.getWorkerID()); + return CompletableFuture.completedFuture(registrationSuccess); + }); + + rpcService.registerGateway(smGateway.getAddress(), smGateway); + + try (ShuffleWorker shuffleWorker = + new ShuffleWorker( + rpcService, + ShuffleWorkerConfiguration.fromConfiguration(configuration), + haServices, + HeartbeatServicesUtils.createManagerWorkerHeartbeatServices(configuration), + testingFatalErrorHandler, + shuffleWorkerLocation, + new EmptyMetaStore(), + dataStore, + nettyServer)) { + shuffleWorker.start(); + shuffleManagerLeaderRetrieveService.notifyListener( + new LeaderInformation(smGateway.getFencingToken(), smGateway.getAddress())); + InstanceID firstRegistrationAttempt = registrationQueue.take(); + assertEquals(shuffleWorkerLocation.getWorkerID(), firstRegistrationAttempt); + assertEquals(0, registrationQueue.size()); + + ShuffleWorkerGateway shuffleWorkerGateway = + shuffleWorker.getSelfGateway(ShuffleWorkerGateway.class); + shuffleWorkerGateway.disconnectManager(new Exception("Test exception")); + InstanceID secondAttempt =
registrationQueue.take(); + assertEquals(shuffleWorkerLocation.getWorkerID(), secondAttempt); + } + } + + @Test + public void testNotifyManagerOnPartitionRemoved() throws Exception { + TestingShuffleManagerGateway smGateway = new TestingShuffleManagerGateway(); + + // Initial report future + CompletableFuture<List<DataPartitionStatus>> reportedDataPartitionsFuture = + new CompletableFuture<>(); + smGateway.setReportShuffleDataStatusConsumer( + (resourceID, instanceID, dataPartitionStatuses) -> { + reportedDataPartitionsFuture.complete(dataPartitionStatuses); + return CompletableFuture.completedFuture(Acknowledge.get()); + }); + + // Release future + CompletableFuture<Triple<JobID, DataSetID, DataPartitionID>> releasedPartition = + new CompletableFuture<>(); + smGateway.setWorkerReportDataPartitionReleasedConsumer( + (tmIds, dataPartitionIds) -> { + releasedPartition.complete(dataPartitionIds); + return CompletableFuture.completedFuture(Acknowledge.get()); + }); + + rpcService.registerGateway(smGateway.getAddress(), smGateway); + + // Create the test metastore + File baseDir = TEMP_FOLDER.newFolder(); + String basePath = baseDir.getAbsolutePath() + "/"; + LocalShuffleMetaStore metaStore = + new LocalShuffleMetaStore(Collections.singleton(basePath)); + DataPartitionMeta meta = LocalShuffleMetaStoreTest.randomMeta(basePath); + metaStore.onPartitionCreated(meta); + + try (ShuffleWorker shuffleWorker = + new ShuffleWorker( + rpcService, + ShuffleWorkerConfiguration.fromConfiguration(configuration), + haServices, + HeartbeatServicesUtils.createManagerWorkerHeartbeatServices(configuration), + testingFatalErrorHandler, + shuffleWorkerLocation, + metaStore, + dataStore, + nettyServer)) { + shuffleWorker.start(); + shuffleManagerLeaderRetrieveService.notifyListener( + new LeaderInformation(smGateway.getFencingToken(), smGateway.getAddress())); + assertThat( + reportedDataPartitionsFuture.get(timeout, TimeUnit.MILLISECONDS).size(), + equalTo(1)); + + metaStore.onPartitionRemoved(meta); + assertThat( + releasedPartition.get(timeout, TimeUnit.MILLISECONDS), + equalTo( + Triple.of( + meta.getJobID(), + meta.getDataSetID(), + meta.getDataPartitionID()))); + } + } +}
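(A hedged aside on the heartbeat parameters used in the tests above: by analogy with Flink's HeartbeatServices, the two constructor arguments are read here as the heartbeat interval and the heartbeat timeout in milliseconds; this diff never names them, so treat that reading as an assumption. Under it:

    HeartbeatServices fastTimeout = new HeartbeatServices(1L, 3L);      // assumed: 1 ms interval, 3 ms timeout, forces the timeout path
    HeartbeatServices relaxed = new HeartbeatServices(1000L, 2000L);    // assumed: 1 s interval, 2 s timeout, lets reporting complete

which would explain why the first pair drives testHeartbeatTimeoutWithShuffleManager and the second drives testHeartbeatReporting.)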
diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/worker/metastore/LocalShuffleMetaStoreTest.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/worker/metastore/LocalShuffleMetaStoreTest.java new file mode 100644 index 00000000..0a68bd08 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/worker/metastore/LocalShuffleMetaStoreTest.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.worker.metastore; + +import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionCoordinate; +import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionStatus; +import com.alibaba.flink.shuffle.coordinator.utils.RandomIDUtils; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.storage.DataPartitionMeta; +import com.alibaba.flink.shuffle.core.storage.StorageMeta; +import com.alibaba.flink.shuffle.core.storage.StorageType; +import com.alibaba.flink.shuffle.storage.partition.LocalFileMapPartitionMeta; +import com.alibaba.flink.shuffle.storage.partition.LocalMapPartitionFile; +import com.alibaba.flink.shuffle.storage.partition.LocalMapPartitionFileMeta; + +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.io.FileOutputStream; +import java.nio.channels.FileChannel; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** Tests the behavior of {@link LocalShuffleMetaStore}. */ +public class LocalShuffleMetaStoreTest { + + @ClassRule public static final TemporaryFolder TEMP_FOLDER = new TemporaryFolder(); + + @Test + public void testAddDataPartition() throws Exception { + File base = TEMP_FOLDER.newFolder(); + String basePath = base.getAbsolutePath() + "/"; + + LocalShuffleMetaStore metaStore = + new LocalShuffleMetaStore(Collections.singleton(basePath)); + + DataPartitionMeta meta = randomMeta(basePath); + metaStore.onPartitionCreated(meta); + + List<DataPartitionStatus> current = metaStore.listDataPartitions(); + assertEquals(1, current.size()); + assertEquals( + new DataPartitionStatus( + meta.getJobID(), + new DataPartitionCoordinate(meta.getDataSetID(), meta.getDataPartitionID()), + false), + current.get(0)); + + File[] metaFiles = listMetaFiles(basePath); + assertNotNull(metaFiles); + assertEquals(1, metaFiles.length); + } + + @Test + public void testRestore() throws Exception { + Set<String> basePaths = new HashSet<>(); + List<DataPartitionStatus> expectedDataPartitions = new ArrayList<>(); + + for (int i = 0; i < 2; ++i) { + File base = TEMP_FOLDER.newFolder(); + basePaths.add(base.getAbsolutePath() + "/"); + } + LocalShuffleMetaStore metaStore = new LocalShuffleMetaStore(basePaths); + + // Adds some partition meta + for (String basePath : basePaths) { + for (int i = 0; i < 4; ++i) { + DataPartitionMeta meta = randomMeta(basePath); + metaStore.onPartitionCreated(meta); + expectedDataPartitions.add( + new DataPartitionStatus( + meta.getJobID(), + new DataPartitionCoordinate( + meta.getDataSetID(), meta.getDataPartitionID()), + false)); + } + } + + // Now let's create a new meta store + LocalShuffleMetaStore restoredMetastore = new LocalShuffleMetaStore(basePaths); + List<DataPartitionStatus> dataPartitionStatus = restoredMetastore.listDataPartitions(); + assertThat(dataPartitionStatus, containsInAnyOrder(expectedDataPartitions.toArray())); + } + + @Test + public void testRestoreSkipNonExistMetaDir() throws Exception { + File base = TEMP_FOLDER.newFolder(); + String basePath = base.getAbsolutePath() + "/"; + LocalShuffleMetaStore metaStore = + new
LocalShuffleMetaStore(Collections.singleton(basePath)); + assertEquals(0, metaStore.listDataPartitions().size()); + } + + @Test + public void testRestoreRemoveSpoiledMetaFiles() throws Exception { + File base = TEMP_FOLDER.newFolder(); + String basePath = base.getAbsolutePath() + "/"; + LocalShuffleMetaStore metaStore = + new LocalShuffleMetaStore(Collections.singleton(basePath)); + metaStore.onPartitionCreated(randomMeta(basePath)); + + File[] metaFiles = listMetaFiles(basePath); + assertEquals(1, metaFiles.length); + try (FileChannel channel = new FileOutputStream(metaFiles[0]).getChannel()) { + channel.truncate(10); + } + + LocalShuffleMetaStore restoredMetaStore = + new LocalShuffleMetaStore(Collections.singleton(basePath)); + assertEquals(0, restoredMetaStore.listDataPartitions().size()); + assertEquals(0, listMetaFiles(basePath).length); + } + + @Test + public void testRemoveMetaFile() throws Exception { + File base = TEMP_FOLDER.newFolder(); + String basePath = base.getAbsolutePath() + "/"; + LocalShuffleMetaStore metaStore = + new LocalShuffleMetaStore(Collections.singleton(basePath)); + DataPartitionMeta meta = randomMeta(basePath); + metaStore.onPartitionCreated(meta); + + // Removes the meta + metaStore.onPartitionRemoved(meta); + + List<DataPartitionStatus> current = metaStore.listDataPartitions(); + assertEquals(1, current.size()); + assertEquals( + new DataPartitionStatus( + meta.getJobID(), + new DataPartitionCoordinate(meta.getDataSetID(), meta.getDataPartitionID()), + true), + current.get(0)); + // The meta file should have been removed + File[] metaFiles = listMetaFiles(basePath); + assertEquals(0, metaFiles.length); + } + + public static LocalFileMapPartitionMeta randomMeta(String storagePath) { + JobID jobId = RandomIDUtils.randomJobId(); + DataSetID dataSetId = RandomIDUtils.randomDataSetId(); + MapPartitionID dataPartitionId = RandomIDUtils.randomMapPartitionId(); + LocalMapPartitionFileMeta fileMeta = + new LocalMapPartitionFileMeta( + storagePath + "test", 10, LocalMapPartitionFile.LATEST_STORAGE_VERSION); + StorageMeta storageMeta = new StorageMeta(storagePath, StorageType.SSD); + return new LocalFileMapPartitionMeta( + jobId, dataSetId, dataPartitionId, fileMeta, storageMeta); + } + + private File[] listMetaFiles(String basePath) { + return new File(basePath, LocalShuffleMetaStore.META_DIR_NAME).listFiles(); + } +}
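(For orientation, the metastore contract these tests pin down, sketched using only calls that appear in this diff; the variable names are illustrative:

    LocalShuffleMetaStore store = new LocalShuffleMetaStore(Collections.singleton(basePath));
    store.onPartitionCreated(meta);   // persists one meta file under the META_DIR_NAME subdirectory
    store.onPartitionRemoved(meta);   // marks the partition released and deletes its meta file
    List<DataPartitionStatus> statuses = store.listDataPartitions();

On restart, a new LocalShuffleMetaStore over the same base paths re-reads the surviving meta files and drops any truncated or otherwise spoiled ones.)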
diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/zookeeper/ZooKeeperHaServicesTest.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/zookeeper/ZooKeeperHaServicesTest.java new file mode 100644 index 00000000..f967d9fe --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/zookeeper/ZooKeeperHaServicesTest.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.zookeeper; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.functions.ConsumerWithException; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaMode; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServiceUtils; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderElectionService; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderRetrievalService; +import com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper.ZooKeeperHaServices; +import com.alibaba.flink.shuffle.coordinator.leaderelection.TestingContender; +import com.alibaba.flink.shuffle.coordinator.leaderelection.TestingListener; +import com.alibaba.flink.shuffle.core.config.HighAvailabilityOptions; +import com.alibaba.flink.shuffle.core.utils.TestLogger; + +import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFramework; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFrameworkFactory; +import org.apache.flink.shaded.curator4.org.apache.curator.retry.RetryNTimes; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; + +import javax.annotation.Nonnull; + +import java.util.List; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; + +/** Tests for the {@link ZooKeeperHaServices}. */ +public class ZooKeeperHaServicesTest extends TestLogger { + + @ClassRule public static final ZooKeeperResource ZOO_KEEPER_RESOURCE = new ZooKeeperResource(); + + private static CuratorFramework client; + + @BeforeClass + public static void setupClass() { + client = startCuratorFramework(); + client.start(); + } + + @Before + public void setup() throws Exception { + final List<String> children = client.getChildren().forPath("/"); + + for (String child : children) { + if (!child.equals("zookeeper")) { + client.delete().deletingChildrenIfNeeded().forPath('/' + child); + } + } + } + + @AfterClass + public static void teardownClass() { + if (client != null) { + client.close(); + } + } + + /** Tests that a simple {@link ZooKeeperHaServices#close()} does not delete ZooKeeper paths. */ + @Test + public void testSimpleClose() throws Exception { + final String rootPath = "/foo/bar/flink"; + final Configuration configuration = createConfiguration(rootPath); + + runCleanupTest(configuration, ZooKeeperHaServices::close); + + final List<String> children = client.getChildren().forPath(rootPath); + assertThat(children, is(not(empty()))); + } + + /** + * Tests that the {@link ZooKeeperHaServices} cleans up all paths if it is closed via {@link + * ZooKeeperHaServices#closeAndCleanupAllData()}.
+ */ + @Test + public void testSimpleCloseAndCleanupAllData() throws Exception { + final Configuration configuration = createConfiguration("/foo/bar/flink"); + + final List<String> initialChildren = client.getChildren().forPath("/"); + + runCleanupTest(configuration, ZooKeeperHaServices::closeAndCleanupAllData); + + final List<String> children = client.getChildren().forPath("/"); + assertThat(children, is(equalTo(initialChildren))); + } + + /** Tests that we can only delete the parent znodes as long as they are empty. */ + @Test + public void testCloseAndCleanupAllDataWithUncle() throws Exception { + final String prefix = "/foo/bar"; + final String flinkPath = prefix + "/flink"; + final Configuration configuration = createConfiguration(flinkPath); + + final String unclePath = prefix + "/foobar"; + client.create().creatingParentContainersIfNeeded().forPath(unclePath); + + runCleanupTest(configuration, ZooKeeperHaServices::closeAndCleanupAllData); + + assertThat(client.checkExists().forPath(flinkPath), is(nullValue())); + assertThat(client.checkExists().forPath(unclePath), is(notNullValue())); + } + + private static CuratorFramework startCuratorFramework() { + return CuratorFrameworkFactory.builder() + .connectString(ZOO_KEEPER_RESOURCE.getConnectString()) + .retryPolicy(new RetryNTimes(50, 100)) + .build(); + } + + @Nonnull + private Configuration createConfiguration(String rootPath) { + final Configuration configuration = new Configuration(); + configuration.setString( + HighAvailabilityOptions.HA_ZOOKEEPER_QUORUM, + ZOO_KEEPER_RESOURCE.getConnectString()); + configuration.setString(HighAvailabilityOptions.HA_ZOOKEEPER_ROOT, rootPath); + return configuration; + } + + private void runCleanupTest( + Configuration configuration, + ConsumerWithException<ZooKeeperHaServices, Exception> zooKeeperHaServicesConsumer) + throws Exception { + configuration.setString(HighAvailabilityOptions.HA_MODE, HaMode.ZOOKEEPER.toString()); + try (ZooKeeperHaServices zooKeeperHaServices = + (ZooKeeperHaServices) + HaServiceUtils.createAvailableOrEmbeddedServices( + configuration, Runnable::run)) { + + // create some ZooKeeper services to trigger the generation of paths + final LeaderRetrievalService shuffleManagerLeaderRetriever = + zooKeeperHaServices.createLeaderRetrievalService( + HaServices.LeaderReceptor.SHUFFLE_WORKER); + final LeaderElectionService shuffleManagerLeaderElectionService = + zooKeeperHaServices.createLeaderElectionService(); + + final TestingListener listener = new TestingListener(); + shuffleManagerLeaderRetriever.start(listener); + shuffleManagerLeaderElectionService.start( + new TestingContender("foobar", shuffleManagerLeaderElectionService)); + + listener.waitForNewLeader(2000L); + + shuffleManagerLeaderRetriever.stop(); + shuffleManagerLeaderElectionService.stop(); + + zooKeeperHaServicesConsumer.accept(zooKeeperHaServices); + } + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/zookeeper/ZooKeeperResource.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/zookeeper/ZooKeeperResource.java new file mode 100644 index 00000000..3f4ac5cf --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/coordinator/zookeeper/ZooKeeperResource.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.coordinator.zookeeper; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; + +import org.apache.curator.test.TestingServer; +import org.junit.rules.ExternalResource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.io.IOException; + +/** {@link ExternalResource} which starts a {@link org.apache.zookeeper.server.ZooKeeperServer}. */ +public class ZooKeeperResource extends ExternalResource { + + private static final Logger LOG = LoggerFactory.getLogger(ZooKeeperResource.class); + + @Nullable private TestingServer zooKeeperServer; + + public String getConnectString() { + verifyIsRunning(); + return zooKeeperServer.getConnectString(); + } + + private void verifyIsRunning() { + CommonUtils.checkState(zooKeeperServer != null); + } + + @Override + protected void before() throws Throwable { + terminateZooKeeperServer(); + zooKeeperServer = new TestingServer(true); + } + + private void terminateZooKeeperServer() throws IOException { + if (zooKeeperServer != null) { + zooKeeperServer.stop(); + zooKeeperServer = null; + } + } + + @Override + protected void after() { + try { + terminateZooKeeperServer(); + } catch (IOException e) { + LOG.warn("Could not properly terminate the {}.", getClass().getSimpleName(), e); + } + } + + public void restart() throws Exception { + CommonUtils.checkNotNull(zooKeeperServer); + zooKeeperServer.restart(); + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/minicluster/ShuffleMiniClusterTest.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/minicluster/ShuffleMiniClusterTest.java new file mode 100644 index 00000000..6a246303 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/minicluster/ShuffleMiniClusterTest.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.minicluster; + +import com.alibaba.flink.shuffle.client.ShuffleManagerClient; +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleResource; +import com.alibaba.flink.shuffle.coordinator.utils.RandomIDUtils; +import com.alibaba.flink.shuffle.core.config.ManagerOptions; +import com.alibaba.flink.shuffle.core.config.StorageOptions; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; + +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +/** Tests the remote shuffle mini-cluster. */ +public class ShuffleMiniClusterTest { + + @ClassRule public static final TemporaryFolder TEMP_FOLDER = new TemporaryFolder(); + + private static final String partitionFactoryName = + "com.alibaba.flink.shuffle.storage.partition.LocalFileMapPartitionFactory"; + + @Test + public void testMiniClusterWithDedicatedRpcService() throws Exception { + ShuffleMiniClusterConfiguration configuration = + new ShuffleMiniClusterConfiguration.Builder() + .setNumShuffleWorkers(2) + .setConfiguration(initConfiguration()) + .build(); + + try (ShuffleMiniCluster shuffleMiniCluster = new ShuffleMiniCluster(configuration)) { + shuffleMiniCluster.start(); + + JobID jobId = RandomIDUtils.randomJobId(); + ShuffleManagerClient client = shuffleMiniCluster.createClient(jobId); + + DataSetID dataSetId = RandomIDUtils.randomDataSetId(); + MapPartitionID mapPartitionId = RandomIDUtils.randomMapPartitionId(); + CompletableFuture<ShuffleResource> shuffleResource = + client.requestShuffleResource( + dataSetId, mapPartitionId, 2, partitionFactoryName); + shuffleResource.get(60_000, TimeUnit.MILLISECONDS); + } + } + + private Configuration initConfiguration() throws IOException { + Configuration configuration = new Configuration(); + configuration.setString( + StorageOptions.STORAGE_LOCAL_DATA_DIRS.key(), + TEMP_FOLDER.newFolder().getAbsolutePath()); + String address = InetAddress.getLocalHost().getHostAddress(); + configuration.setString(ManagerOptions.RPC_ADDRESS, address); + configuration.setString(ManagerOptions.RPC_BIND_ADDRESS, address); + return configuration; + } +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/utils/QuadFunction.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/utils/QuadFunction.java new file mode 100644 index 00000000..139173e1 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/utils/QuadFunction.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.utils; + +/** Function which takes four arguments. */ +public interface QuadFunction<S, T, U, V, R> { + + /** Applies this function to the given arguments. */ + R apply(S s, T t, U u, V v); +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/utils/TriFunction.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/utils/TriFunction.java new file mode 100644 index 00000000..18bd9ab5 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/utils/TriFunction.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.utils; + +/** Function which takes three arguments. */ +public interface TriFunction<S, T, U, R> { + + /** Applies this function to the given arguments. */ + R apply(S s, T t, U u); +} diff --git a/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/utils/Tuple4.java b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/utils/Tuple4.java new file mode 100644 index 00000000..57c17cd3 --- /dev/null +++ b/shuffle-coordinator/src/test/java/com/alibaba/flink/shuffle/utils/Tuple4.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.utils; + +import java.io.Serializable; +import java.util.Objects; + +/** A tuple with 4 fields. */ +public class Tuple4<T0, T1, T2, T3> implements Serializable { + + private static final long serialVersionUID = 453491702625489108L; + + /** Field 0 of the tuple. */ + public T0 f0; + /** Field 1 of the tuple. */ + public T1 f1; + /** Field 2 of the tuple. */ + public T2 f2; + /** Field 3 of the tuple.
*/ + public T3 f3; + + public Tuple4(T0 f0, T1 f1, T2 f2, T3 f3) { + this.f0 = f0; + this.f1 = f1; + this.f2 = f2; + this.f3 = f3; + } + + @Override + public boolean equals(Object that) { + if (this == that) { + return true; + } + if (that == null || getClass() != that.getClass()) { + return false; + } + Tuple4<?, ?, ?, ?> tuple4 = (Tuple4<?, ?, ?, ?>) that; + return Objects.equals(f0, tuple4.f0) + && Objects.equals(f1, tuple4.f1) + && Objects.equals(f2, tuple4.f2) + && Objects.equals(f3, tuple4.f3); + } + + @Override + public int hashCode() { + return Objects.hash(f0, f1, f2, f3); + } + + public static <T0, T1, T2, T3> Tuple4<T0, T1, T2, T3> of(T0 f0, T1 f1, T2 f2, T3 f3) { + return new Tuple4<>(f0, f1, f2, f3); + } +} diff --git a/shuffle-coordinator/src/test/resources/log4j2-test.properties b/shuffle-coordinator/src/test/resources/log4j2-test.properties new file mode 100644 index 00000000..d7fcb327 --- /dev/null +++ b/shuffle-coordinator/src/test/resources/log4j2-test.properties @@ -0,0 +1,26 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +# Set root logger level to OFF to not flood build logs +# set manually to INFO for debugging purposes +rootLogger.level=OFF +rootLogger.appenderRef.test.ref=TestLogger +appender.testlogger.name=TestLogger +appender.testlogger.type=CONSOLE +appender.testlogger.target=SYSTEM_ERR +appender.testlogger.layout.type=PatternLayout +appender.testlogger.layout.pattern=%-4r [%t] %-5p %c %x - %m%n diff --git a/shuffle-core/pom.xml b/shuffle-core/pom.xml new file mode 100644 index 00000000..446a5927 --- /dev/null +++ b/shuffle-core/pom.xml @@ -0,0 +1,67 @@ (The XML markup of this file was lost in extraction; the tag structure below is reconstructed from the surviving element values and standard Maven POM layout.) +<?xml version="1.0" encoding="UTF-8"?> +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + + <parent> + <groupId>com.alibaba.flink.shuffle</groupId> + <artifactId>flink-shuffle-parent</artifactId> + <version>1.0-SNAPSHOT</version> + </parent> + <modelVersion>4.0.0</modelVersion> + + <artifactId>shuffle-core</artifactId> + + <dependencies> + <dependency> + <groupId>com.alibaba.flink.shuffle</groupId> + <artifactId>shuffle-common</artifactId> + <version>${project.version}</version> + </dependency> + + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-shaded-netty</artifactId> + <version>4.1.49.Final-${flink.shaded.version}</version> + <scope>provided</scope> + </dependency> + + <dependency> + <groupId>commons-cli</groupId> + <artifactId>commons-cli</artifactId> + <version>${commons.cli.version}</version> + </dependency> + </dependencies> + + <build> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-jar-plugin</artifactId> + <executions> + <execution> + <goals> + <goal>test-jar</goal> + </goals> + </execution> + </executions> + </plugin> + </plugins> + </build> +</project> diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/ClusterOptions.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/ClusterOptions.java new file mode 100644 index 00000000..aa40080a --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/ClusterOptions.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.config; + +import com.alibaba.flink.shuffle.common.config.ConfigOption; + +import java.time.Duration; + +/** Options which control the cluster behaviour. */ +public class ClusterOptions { + + /** The unique ID of the remote shuffle cluster used by HA. */ + public static final ConfigOption<String> REMOTE_SHUFFLE_CLUSTER_ID = + new ConfigOption<String>("remote-shuffle.cluster.id") + .defaultValue("/default-cluster") + .description( + "The unique ID of the remote shuffle cluster used by high-availability."); + + /** + * Defines the timeout for the shuffle worker or client registration to the shuffle manager. If + * the duration is exceeded without a successful registration, then the shuffle worker or + * client terminates. + */ + public static final ConfigOption<Duration> REGISTRATION_TIMEOUT = + new ConfigOption<Duration>("remote-shuffle.cluster.registration.timeout") + .defaultValue(Duration.ofMinutes(5)) + .description( + "Defines the timeout for the shuffle worker or client registration to " + + "the shuffle manager. If the duration is exceeded without a " + + "successful registration, then the shuffle worker or client " + + "terminates."); + + /** The pause made after a registration attempt caused an exception (other than timeout). */ + public static final ConfigOption<Duration> ERROR_REGISTRATION_DELAY = + new ConfigOption<Duration>("remote-shuffle.cluster.registration.error-delay") + .defaultValue(Duration.ofSeconds(10)) + .description( + "The pause made after a registration attempt caused an exception " + + "(other than timeout)."); + + /** The pause made after the registration attempt was refused. */ + public static final ConfigOption<Duration> REFUSED_REGISTRATION_DELAY = + new ConfigOption<Duration>("remote-shuffle.cluster.registration.refused-delay") + .defaultValue(Duration.ofSeconds(30)) + .description("The pause made after the registration attempt was refused."); + + // ------------------------------------------------------------------------ + + /** Not intended to be instantiated. */ + private ClusterOptions() {} +}
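(How these options are consumed, as a minimal sketch: the setDuration call is taken verbatim from ShuffleWorkerRunnerTest above; the getDuration accessor is an assumption about the Configuration API, which this diff does not show.

    Configuration conf = new Configuration();
    conf.setDuration(ClusterOptions.REGISTRATION_TIMEOUT, Duration.ofMinutes(1)); // setter as used in the tests
    Duration timeout = conf.getDuration(ClusterOptions.REGISTRATION_TIMEOUT);     // assumed accessor name

)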
diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/HeartbeatOptions.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/HeartbeatOptions.java new file mode 100644 index 00000000..742d9ea5 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/HeartbeatOptions.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.config; + +import com.alibaba.flink.shuffle.common.config.ConfigOption; + +import java.time.Duration; + +/** The configuration options for heartbeats. */ +public class HeartbeatOptions { + + /** Time interval for shuffle manager to request heartbeat from shuffle worker. */ + public static final ConfigOption<Duration> HEARTBEAT_WORKER_INTERVAL = + new ConfigOption<Duration>("remote-shuffle.worker.heartbeat.interval") + .defaultValue(Duration.ofSeconds(10)) + .description( + "Time interval for shuffle manager to request heartbeat from shuffle " + + "worker."); + + /** Timeout for shuffle manager and shuffle worker to request and receive heartbeat. */ + public static final ConfigOption<Duration> HEARTBEAT_WORKER_TIMEOUT = + new ConfigOption<Duration>("remote-shuffle.worker.heartbeat.timeout") + .defaultValue(Duration.ofSeconds(60)) + .description( + "Timeout for shuffle manager and shuffle worker to request and receive" + + " heartbeat."); + + /** Time interval for shuffle client to request heartbeat from shuffle manager. */ + public static final ConfigOption<Duration> HEARTBEAT_JOB_INTERVAL = + new ConfigOption<Duration>("remote-shuffle.client.heartbeat.interval") + .defaultValue(Duration.ofSeconds(10)) + .description( + "Time interval for shuffle client to request heartbeat from shuffle " + + "manager."); + + /** Timeout for shuffle client and shuffle manager to request and receive heartbeat. */ + public static final ConfigOption<Duration> HEARTBEAT_JOB_TIMEOUT = + new ConfigOption<Duration>("remote-shuffle.client.heartbeat.timeout") + .defaultValue(Duration.ofSeconds(120)) + .description( + "Timeout for shuffle client and shuffle manager to request and receive" + + " heartbeat."); + + // ------------------------------------------------------------------------ + + /** Not intended to be instantiated. */ + private HeartbeatOptions() {} +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/HighAvailabilityOptions.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/HighAvailabilityOptions.java new file mode 100644 index 00000000..12f57f1f --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/HighAvailabilityOptions.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package com.alibaba.flink.shuffle.core.config; + +import com.alibaba.flink.shuffle.common.config.ConfigOption; + +import java.time.Duration; + +/** The set of configuration options relating to high-availability settings. */ +public class HighAvailabilityOptions { + + /** + * Defines the high-availability mode used for the cluster execution. A value of "NONE" signals + * no highly available setup. To enable high-availability, set this mode to "ZOOKEEPER". It can + * also be set to the FQN of a high-availability factory class. + */ + public static final ConfigOption<String> HA_MODE = + new ConfigOption<String>("remote-shuffle.high-availability.mode") + .defaultValue("NONE") + .description( + "Defines high-availability mode used for the cluster execution." + + " To enable high-availability, set this mode to 'ZOOKEEPER' " + + "or specify FQN of factory class."); + + // ------------------------------------------------------------------------ + // ZooKeeper Options + // ------------------------------------------------------------------------ + + /** + * The ZooKeeper quorum to use when running the remote shuffle cluster in a high-availability + * mode with ZooKeeper. + */ + public static final ConfigOption<String> HA_ZOOKEEPER_QUORUM = + new ConfigOption<String>("remote-shuffle.ha.zookeeper.quorum") + .defaultValue(null) + .description( + "The ZooKeeper quorum to use when running the remote shuffle cluster " + + "in a high-availability mode with ZooKeeper."); + + /** + * The root path under which the remote shuffle cluster stores its entries in ZooKeeper. + * Different remote shuffle clusters will be distinguished by the cluster id. + */ + public static final ConfigOption<String> HA_ZOOKEEPER_ROOT = + new ConfigOption<String>("remote-shuffle.ha.zookeeper.root-path") + .defaultValue("flink-remote-shuffle") + .description( + "The root path in ZooKeeper under which the remote shuffle cluster " + + "stores its entries. Different remote shuffle clusters will " + + "be distinguished by the cluster id."); + + // ------------------------------------------------------------------------ + // ZooKeeper Client Settings + // ------------------------------------------------------------------------ + + /** Defines the session timeout for the ZooKeeper session. */ + public static final ConfigOption<Duration> ZOOKEEPER_SESSION_TIMEOUT = + new ConfigOption<Duration>("remote-shuffle.ha.zookeeper.session-timeout") + .defaultValue(Duration.ofSeconds(60)) + .description("Defines the session timeout for the ZooKeeper session."); + + /** Defines the connection timeout for the ZooKeeper client. */ + public static final ConfigOption<Duration> ZOOKEEPER_CONNECTION_TIMEOUT = + new ConfigOption<Duration>("remote-shuffle.ha.zookeeper.connection-timeout") + .defaultValue(Duration.ofSeconds(15)) + .description("Defines the connection timeout for the ZooKeeper client."); + + /** Defines the pause between consecutive connection retries. */ + public static final ConfigOption<Duration> ZOOKEEPER_RETRY_WAIT = + new ConfigOption<Duration>("remote-shuffle.ha.zookeeper.retry-wait") + .defaultValue(Duration.ofSeconds(5)) + .description("Defines the pause between consecutive connection retries."); + + /** Defines the number of connection retries before the client gives up. */ + public static final ConfigOption<Integer> ZOOKEEPER_MAX_RETRY_ATTEMPTS = + new ConfigOption<Integer>("remote-shuffle.ha.zookeeper.max-retry-attempts") + .defaultValue(3) + .description( + "Defines the number of connection retries before the client gives up."); + + // ------------------------------------------------------------------------ + + /** Not intended to be instantiated. */ + private HighAvailabilityOptions() {} +}
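(Putting the high-availability options together, a sketch restricted to calls already visible in ZooKeeperHaServicesTest above; the quorum address is a placeholder:

    Configuration conf = new Configuration();
    conf.setString(HighAvailabilityOptions.HA_MODE, HaMode.ZOOKEEPER.toString());
    conf.setString(HighAvailabilityOptions.HA_ZOOKEEPER_QUORUM, "zk1:2181,zk2:2181,zk3:2181");
    conf.setString(HighAvailabilityOptions.HA_ZOOKEEPER_ROOT, "/flink-remote-shuffle");

)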
*/ + private HighAvailabilityOptions() {} +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/KubernetesOptions.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/KubernetesOptions.java new file mode 100644 index 00000000..c89d2d66 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/KubernetesOptions.java @@ -0,0 +1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.config; + +import com.alibaba.flink.shuffle.common.config.ConfigOption; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** This class holds configuration constants used by the remote shuffle deployment. */ +public class KubernetesOptions { + + // -------------------------------------------------------------------------------------------- + // Common configurations. + // -------------------------------------------------------------------------------------------- + + /** Image to use for the remote shuffle manager and worker containers. */ + public static final ConfigOption CONTAINER_IMAGE = + new ConfigOption("remote-shuffle.kubernetes.container.image") + .defaultValue(null) + .description( + "Image to use for the remote shuffle manager and worker containers."); + + /** + * The Kubernetes container image pull policy (IfNotPresent or Always or Never). The default + * policy is IfNotPresent to avoid putting pressure to image repository. + */ + public static final ConfigOption CONTAINER_IMAGE_PULL_POLICY = + new ConfigOption("remote-shuffle.kubernetes.container.image.pull-policy") + .defaultValue("IfNotPresent") + .description( + "The Kubernetes container image pull policy (IfNotPresent or Always or" + + " Never). The default policy is IfNotPresent to avoid putting" + + " pressure to image repository."); + + /** Whether to enable host network for pod. Generally, host network is faster. */ + public static final ConfigOption POD_HOST_NETWORK_ENABLED = + new ConfigOption("remote-shuffle.kubernetes.host-network.enabled") + .defaultValue(true) + .description( + "Whether to enable host network for pod. Generally, host network is " + + "faster."); + + // -------------------------------------------------------------------------------------------- + // ShuffleManager configurations. + // -------------------------------------------------------------------------------------------- + + /** The number of cpu used by the shuffle manager. */ + public static final ConfigOption SHUFFLE_MANAGER_CPU = + new ConfigOption("remote-shuffle.kubernetes.manager.cpu") + .defaultValue(1.0) + .description("The number of cpu used by the shuffle manager."); + + /** + * Env vars for the shuffle manager. 
+     * Specified as key:value pairs separated by commas. For example, set the timezone as
+     * TZ:Asia/Shanghai.
+     */
+    public static final ConfigOption<Map<String, String>> SHUFFLE_MANAGER_ENV_VARS =
+            new ConfigOption<Map<String, String>>("remote-shuffle.kubernetes.manager.env-vars")
+                    .defaultValue(Collections.emptyMap())
+                    .description(
+                            "Env vars for the shuffle manager. Specified as key:value pairs "
+                                    + "separated by commas. For example, set the timezone as "
+                                    + "TZ:Asia/Shanghai.");
+
+    /**
+     * Specify the Kubernetes EmptyDir volumes that will be mounted into the shuffle manager
+     * container. The following attributes can be configured:
+     *
+     * <ul>
+     *   <li>'name', the name of the volume. For example, 'name:disk1' means the volume named
+     *       disk1.
+     *   <li>'sizeLimit', the limit size of the volume. For example, 'sizeLimit:5Gi' means the
+     *       volume size is limited to 5Gi.
+     *   <li>'mountPath', the mount path in the container. For example, 'mountPath:/opt/disk1'
+     *       means this volume will be mounted as the /opt/disk1 path.
+     * </ul>
+     */
+    public static final ConfigOption<List<Map<String, String>>> SHUFFLE_MANAGER_EMPTY_DIR_VOLUMES =
+            new ConfigOption<List<Map<String, String>>>(
+                            "remote-shuffle.kubernetes.manager.volume.empty-dirs")
+                    .defaultValue(Collections.emptyList())
+                    .description(
+                            "Specify the Kubernetes EmptyDir volumes that will be mounted into "
+                                    + "the shuffle manager container. The value should be in the "
+                                    + "form of name:disk1,sizeLimit:5Gi,mountPath:/opt/disk1;"
+                                    + "name:disk2,sizeLimit:5Gi,mountPath:/opt/disk2. More "
+                                    + "specifically, 'name' is the name of the volume, "
+                                    + "'sizeLimit' is the limit size of the volume and "
+                                    + "'mountPath' is the mount path in the container.");
+
+    /**
+     * Specify the Kubernetes HostPath volumes that will be mounted into the shuffle manager
+     * container. The following attributes can be configured:
+     *
+     * <ul>

+     *   <li>'name', the name of the volume. For example, 'name:disk1' means the volume named
+     *       disk1.
+     *   <li>'path', the directory location on the host. For example, 'path:/dump/1' means the
+     *       directory /dump/1 on the host will be mounted into the container.
+     *   <li>
'mountPath', the mount path in container. For example, 'mountPath:/opt/disk1' means this + * volume will be mounted as /opt/disk1 path. + */ + public static final ConfigOption>> SHUFFLE_MANAGER_HOST_PATH_VOLUMES = + new ConfigOption>>( + "remote-shuffle.kubernetes.manager.volume.host-paths") + .defaultValue(Collections.emptyList()) + .description( + "Specify the kubernetes HostPath volumes that will be mounted into " + + "shuffle manager container. The value should be in form of " + + "name:disk1,path:/dump/1,mountPath:/opt/disk1;name:disk2," + + "path:/dump/2,mountPath:/opt/disk2. More specifically, " + + "'name' is the name of the volume, 'path' is the directory " + + "location on host and 'mountPath' is the mount path in " + + "container."); + + /** + * The user-specified labels to be set for the shuffle manager pod. Specified as key:value pairs + * separated by commas. For example, version:alphav1,deploy:test. + */ + public static final ConfigOption> SHUFFLE_MANAGER_LABELS = + new ConfigOption>("remote-shuffle.kubernetes.manager.labels") + .defaultValue(Collections.emptyMap()) + .description( + "The user-specified labels to be set for the shuffle manager pod. " + + "Specified as key:value pairs separated by commas. For " + + "example, version:alphav1,deploy:test."); + + /** + * The user-specified node selector to be set for the shuffle manager pod. Specified as + * key:value pairs separated by commas. For example, environment:production,disk:ssd. + */ + public static final ConfigOption> SHUFFLE_MANAGER_NODE_SELECTOR = + new ConfigOption>("remote-shuffle.kubernetes.manager.node-selector") + .defaultValue(Collections.emptyMap()) + .description( + "The user-specified node selector to be set for the shuffle manager " + + "pod. Specified as key:value pairs separated by commas. For " + + "example, environment:production,disk:ssd."); + + /** + * The user-specified tolerations to be set to the shuffle manager pod. The value should be in + * the form of + * key:key1,operator:Equal,value:value1,effect:NoSchedule;key:key2,operator:Exists,effect:NoExecute,tolerationSeconds:6000. + */ + public static final ConfigOption>> SHUFFLE_MANAGER_TOLERATIONS = + new ConfigOption>>( + "remote-shuffle.kubernetes.manager.tolerations") + .defaultValue(Collections.emptyList()) + .description( + "The user-specified tolerations to be set to the shuffle manager pod. " + + "The value should be in the form of key:key1,operator:Equal," + + "value:value1,effect:NoSchedule;key:key2,operator:Exists," + + "effect:NoExecute,tolerationSeconds:6000."); + + // -------------------------------------------------------------------------------------------- + // ShuffleWorker configurations. + // -------------------------------------------------------------------------------------------- + + /** The number of cpu used by the shuffle worker. */ + public static final ConfigOption SHUFFLE_WORKER_CPU = + new ConfigOption("remote-shuffle.kubernetes.worker.cpu") + .defaultValue(1.0) + .description("The number of cpu used by the shuffle worker."); + + /** + * Env vars for the shuffle worker. Specified as key:value pairs separated by commas. For + * example, set timezone as TZ:Asia/Shanghai. + */ + public static final ConfigOption> SHUFFLE_WORKER_ENV_VARS = + new ConfigOption>("remote-shuffle.kubernetes.worker.env-vars") + .defaultValue(Collections.emptyMap()) + .description( + "Env vars for the shuffle worker. Specified as key:value pairs " + + "separated by commas. 
For example, set the timezone as "
+                                    + "TZ:Asia/Shanghai.");
+
+    /**
+     * Specify the Kubernetes EmptyDir volumes that will be mounted into the shuffle worker
+     * container. The following attributes can be configured:
+     *
+     * <ul>
+     *   <li>'name', the name of the volume. For example, 'name:disk1' means the volume named
+     *       disk1.
+     *   <li>'sizeLimit', the limit size of the volume. For example, 'sizeLimit:5Gi' means the
+     *       volume size is limited to 5Gi.
+     *   <li>
'mountPath', the mount path in container. For example, 'mountPath:/opt/disk1' means this + * volume will be mounted as /opt/disk1 path. + */ + public static final ConfigOption>> SHUFFLE_WORKER_EMPTY_DIR_VOLUMES = + new ConfigOption>>( + "remote-shuffle.kubernetes.worker.volume.empty-dirs") + .defaultValue(Collections.emptyList()) + .description( + "Specify the kubernetes empty dir volumes that will be mounted into " + + "shuffle worker container. The value should be in form of " + + "name:disk1,sizeLimit:5Gi,mountPath:/opt/disk1;name:disk2," + + "sizeLimit:5Gi,mountPath:/opt/disk2. More specifically, " + + "'name' is the name of the volume, 'sizeLimit', the limit " + + "size of the volume and 'mountPath' is the mount path in " + + "container."); + + /** + * Specify the kubernetes HostPath volumes that will be mounted into shuffle worker container. + * Following attribute can be configured: + * + *

+     * <ul>
+     *   <li>'name', the name of the volume. For example, 'name:disk1' means the volume named
+     *       disk1.
+     *   <li>'path', the directory location on the host. For example, 'path:/dump/1' means the
+     *       directory /dump/1 on the host will be mounted into the container.
+     *   <li>
'mountPath', the mount path in container. For example, 'mountPath:/opt/disk1' means this + * volume will be mounted as /opt/disk1 path. + */ + public static final ConfigOption>> SHUFFLE_WORKER_HOST_PATH_VOLUMES = + new ConfigOption>>( + "remote-shuffle.kubernetes.worker.volume.host-paths") + .defaultValue(Collections.emptyList()) + .description( + "Specify the kubernetes HostPath volumes that will be mounted into " + + "shuffle worker container. The value should be in form of " + + "name:disk1,path:/dump/1,mountPath:/opt/disk1;name:disk2," + + "path:/dump/2,mountPath:/opt/disk2. More specifically, 'name'" + + " is the name of the volume, 'path' is the directory location" + + " on host and 'mountPath' is the mount path in container."); + + /** + * The user-specified labels to be set for the shuffle worker pods. Specified as key:value pairs + * separated by commas. For example, version:alphav1,deploy:test. + */ + public static final ConfigOption> SHUFFLE_WORKER_LABELS = + new ConfigOption>("remote-shuffle.kubernetes.worker.labels") + .defaultValue(Collections.emptyMap()) + .description( + "The user-specified labels to be set for the shuffle worker pods. " + + "Specified as key:value pairs separated by commas. For " + + "example, version:alphav1,deploy:test."); + + /** + * The user-specified node selector to be set for the shuffle worker pods. Specified as + * key:value pairs separated by commas. For example, environment:production,disk:ssd. + */ + public static final ConfigOption> SHUFFLE_WORKER_NODE_SELECTOR = + new ConfigOption>("remote-shuffle.kubernetes.worker.node-selector") + .defaultValue(Collections.emptyMap()) + .description( + "The user-specified node selector to be set for the shuffle worker " + + "pods. Specified as key:value pairs separated by commas. For " + + "example, environment:production,disk:ssd."); + + /** + * The user-specified tolerations to be set to the shuffle worker pod. The value should be in + * the form of + * key:key1,operator:Equal,value:value1,effect:NoSchedule;key:key2,operator:Exists,effect:NoExecute,tolerationSeconds:6000. + */ + public static final ConfigOption>> SHUFFLE_WORKER_TOLERATIONS = + new ConfigOption>>( + "remote-shuffle.kubernetes.worker.tolerations") + .defaultValue(Collections.emptyList()) + .description( + "The user-specified tolerations to be set to the shuffle worker pods. " + + "The value should be in the form of key:key1,operator:Equal," + + "value:value1,effect:NoSchedule;key:key2,operator:Exists," + + "effect:NoExecute,tolerationSeconds:6000."); + + /** + * The prefix of Kubernetes resource limit factor. It should not be less than 1. The resource + * could be cpu, memory, ephemeral-storage and all other types supported by Kubernetes. + */ + public static final String SHUFFLE_MANAGER_RESOURCE_LIMIT_FACTOR_PREFIX = + "remote-shuffle.kubernetes.manager.limit-factor."; + + /** + * The prefix of Kubernetes resource limit factor. It should not be less than 1. The resource + * could be cpu, memory, ephemeral-storage and all other types supported by Kubernetes. + */ + public static final String SHUFFLE_WORKER_RESOURCE_LIMIT_FACTOR_PREFIX = + "remote-shuffle.kubernetes.worker.limit-factor."; + + // ------------------------------------------------------------------------ + + /** Not intended to be instantiated. 
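The four volume options above all share the same encoded value format. A sketch of a decoder (illustrative only, the project's actual parser may differ): ';' separates volumes, ',' separates attributes, and ':' separates an attribute name from its value, yielding the List<Map<String, String>> shape the options declare.

    import java.util.ArrayList;
    import java.util.LinkedHashMap;
    import java.util.List;
    import java.util.Map;

    public class VolumeSpecDecoder {
        /** Decodes e.g. "name:disk1,sizeLimit:5Gi,mountPath:/opt/disk1;name:disk2,...". */
        public static List<Map<String, String>> decode(String value) {
            List<Map<String, String>> volumes = new ArrayList<>();
            for (String volume : value.split(";")) {
                Map<String, String> attributes = new LinkedHashMap<>();
                for (String pair : volume.split(",")) {
                    // split only on the first ':' so values like '/opt/disk1' stay intact
                    int idx = pair.indexOf(':');
                    attributes.put(pair.substring(0, idx).trim(), pair.substring(idx + 1).trim());
                }
                volumes.add(attributes);
            }
            return volumes;
        }
    }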
*/ + private KubernetesOptions() {} +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/ManagerOptions.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/ManagerOptions.java new file mode 100644 index 00000000..36856817 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/ManagerOptions.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.config; + +import com.alibaba.flink.shuffle.common.config.ConfigOption; +import com.alibaba.flink.shuffle.common.config.MemorySize; + +/** The configuration for the shuffle manager. */ +public class ManagerOptions { + + // ------------------------------------------------------------------------ + // General ShuffleManager Options + // ------------------------------------------------------------------------ + + /** Defines the network address to connect to for communication with the shuffle manager. */ + public static final ConfigOption RPC_ADDRESS = + new ConfigOption("remote-shuffle.manager.rpc-address") + .defaultValue(null) + .description( + "Defines the network address to connect to for communication with the " + + "shuffle manager."); + + /** The local address of the network interface that the shuffle manager binds to. */ + public static final ConfigOption RPC_BIND_ADDRESS = + new ConfigOption("remote-shuffle.manager.rpc-bind-address") + .defaultValue(null) + .description( + "The local address of the network interface that the shuffle manager " + + "binds to."); + + /** Defines the network port to connect to for communication with the shuffle manager. */ + public static final ConfigOption RPC_PORT = + new ConfigOption("remote-shuffle.manager.rpc-port") + .defaultValue(23123) + .description( + "Defines the network port to connect to for communication with the " + + "shuffle manager."); + + /** The local network port that the shuffle manager binds to. */ + public static final ConfigOption RPC_BIND_PORT = + new ConfigOption("remote-shuffle.manager.rpc-bind-port") + .defaultValue(null) + .description( + "The local network port that the shuffle manager binds to. If not " + + "configured, the external port (configured by '" + + RPC_PORT.key() + + "') will be used."); + + // ------------------------------------------------------------------------ + // ShuffleManager Memory Options + // ------------------------------------------------------------------------ + + /** Heap memory size to be used by the shuffle manager. 
*/ + public static final ConfigOption FRAMEWORK_HEAP_MEMORY = + new ConfigOption("remote-shuffle.manager.memory.heap-size") + .defaultValue(MemorySize.parse("4g")) + .description("Heap memory size to be used by the shuffle manager."); + + /** Off-heap memory size to be used by the shuffle manager. */ + public static final ConfigOption FRAMEWORK_OFF_HEAP_MEMORY = + new ConfigOption("remote-shuffle.manager.memory.off-heap-size") + .defaultValue(MemorySize.parse("128m")) + .description("Off-heap memory size to be used by the shuffle manager."); + + /** JVM metaspace size to be used by the shuffle manager. */ + public static final ConfigOption JVM_METASPACE = + new ConfigOption("remote-shuffle.manager.memory.jvm-metaspace-size") + .defaultValue(MemorySize.parse("128m")) + .description("JVM metaspace size to be used by the shuffle manager."); + + /** JVM overhead size for the shuffle manager java process. */ + public static final ConfigOption JVM_OVERHEAD = + new ConfigOption("remote-shuffle.manager.memory.jvm-overhead-size") + .defaultValue(MemorySize.parse("128m")) + .description("JVM overhead size for the shuffle manager java process."); + + /** Java options to start the JVM of the shuffle manager with. */ + public static final ConfigOption JVM_OPTIONS = + new ConfigOption("remote-shuffle.manager.jvm-opts") + .defaultValue("") + .description("Java options to start the JVM of the shuffle manager with."); + + // ------------------------------------------------------------------------ + + /** Not intended to be instantiated. */ + private ManagerOptions() {} +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/MemoryOptions.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/MemoryOptions.java new file mode 100644 index 00000000..f6ab9957 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/MemoryOptions.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.config; + +import com.alibaba.flink.shuffle.common.config.ConfigOption; +import com.alibaba.flink.shuffle.common.config.MemorySize; + +/** Config options for memory. */ +public class MemoryOptions { + + /** + * Minimum valid size of memory can be configured for data writing and reading. The target + * configuration options include {@link #MEMORY_SIZE_FOR_DATA_WRITING} and {@link + * #MEMORY_SIZE_FOR_DATA_READING} + */ + public static final MemorySize MIN_VALID_MEMORY_SIZE = MemorySize.parse("64m"); + + /** + * Size of the buffer to be allocated. Those allocated buffers will be used by both network and + * storage for data transmission, data writing and data reading. 
+ */ + public static final ConfigOption MEMORY_BUFFER_SIZE = + new ConfigOption("remote-shuffle.memory.buffer-size") + .defaultValue(MemorySize.parse("32k")) + .description( + "Size of the buffer to be allocated. Those allocated buffers will be " + + "used by both network and storage for data transmission, data" + + " writing and data reading."); + + /** + * Size of memory to be allocated for data writing. Larger value means more direct memory + * consumption which may lead to better performance. The configured value must be no smaller + * than {@link #MIN_VALID_MEMORY_SIZE} and the buffer size configured by {@link + * #MEMORY_BUFFER_SIZE}, otherwise an exception will be thrown. + */ + public static final ConfigOption MEMORY_SIZE_FOR_DATA_WRITING = + new ConfigOption("remote-shuffle.memory.data-writing-size") + .defaultValue(MemorySize.parse("4g")) + .description( + String.format( + "Size of memory to be allocated for data writing. Larger value " + + "means more direct memory consumption which may lead " + + "to better performance. The configured value must be " + + "no smaller than %s and the buffer size configured by" + + " %s, otherwise an exception will be thrown.", + MIN_VALID_MEMORY_SIZE.toHumanReadableString(), + MEMORY_BUFFER_SIZE.key())); + + /** + * Size of memory to be allocated for data reading. Larger value means more direct memory + * consumption which may lead to better performance. The configured value must be no smaller + * than {@link #MIN_VALID_MEMORY_SIZE} and the buffer size configured by {@link + * #MEMORY_BUFFER_SIZE}, otherwise an exception will be thrown. + */ + public static final ConfigOption MEMORY_SIZE_FOR_DATA_READING = + new ConfigOption("remote-shuffle.memory.data-reading-size") + .defaultValue(MemorySize.parse("4g")) + .description( + String.format( + "Size of memory to be allocated for data reading. Larger value " + + "means more direct memory consumption which may lead " + + "to better performance. The configured value must be " + + "no smaller than %s and the buffer size configured by" + + " %s, otherwise an exception will be thrown.", + MIN_VALID_MEMORY_SIZE.toHumanReadableString(), + MEMORY_BUFFER_SIZE.key())); + + // ------------------------------------------------------------------------ + + /** Not intended to be instantiated. */ + private MemoryOptions() {} +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/MetricOptions.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/MetricOptions.java new file mode 100644 index 00000000..0c165cfc --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/MetricOptions.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
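A sketch of the validation the MemoryOptions javadoc above promises: the configured pool must be no smaller than MIN_VALID_MEMORY_SIZE (64m) and no smaller than one buffer (32k by default), otherwise an exception is thrown. Configuration#getMemorySize is the accessor this patch itself uses in the process-spec classes; with the defaults, the writing pool holds 4g / 32k = 131072 buffers.

    import com.alibaba.flink.shuffle.common.config.Configuration;
    import com.alibaba.flink.shuffle.common.config.MemorySize;
    import com.alibaba.flink.shuffle.core.config.MemoryOptions;

    public class MemoryConfigCheck {
        public static void check(Configuration conf) {
            MemorySize writing = conf.getMemorySize(MemoryOptions.MEMORY_SIZE_FOR_DATA_WRITING);
            MemorySize buffer = conf.getMemorySize(MemoryOptions.MEMORY_BUFFER_SIZE);
            if (writing.getBytes() < MemoryOptions.MIN_VALID_MEMORY_SIZE.getBytes()
                    || writing.getBytes() < buffer.getBytes()) {
                throw new IllegalArgumentException(
                        MemoryOptions.MEMORY_SIZE_FOR_DATA_WRITING.key() + " is too small.");
            }
        }
    }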
+ */ + +package com.alibaba.flink.shuffle.core.config; + +import com.alibaba.flink.shuffle.common.config.ConfigOption; + +/** Config options for metrics. */ +public class MetricOptions { + + /** Whether the http server for reading metrics is enabled. */ + public static final ConfigOption METRICS_HTTP_SERVER_ENABLE = + new ConfigOption("remote-shuffle.metrics.enabled-http-server") + .defaultValue(true) + .description("Whether the http server for requesting metrics is enabled."); + + /** The local address of the network interface that the http metric server binds to. */ + public static final ConfigOption METRICS_BIND_HOST = + new ConfigOption("remote-shuffle.metrics.bind-host") + .defaultValue("0.0.0.0") + .description( + "The local address of the network interface that the http metric server" + + " binds to."); + + /** Shuffle manager http metric server bind port. */ + public static final ConfigOption METRICS_SHUFFLE_MANAGER_HTTP_BIND_PORT = + new ConfigOption("remote-shuffle.metrics.manager.bind-port") + .defaultValue(23101) + .description("Shuffle manager http metric server bind port."); + + /** Shuffle worker http metric server bind port. */ + public static final ConfigOption METRICS_SHUFFLE_WORKER_HTTP_BIND_PORT = + new ConfigOption("remote-shuffle.metrics.worker.bind-port") + .defaultValue(23103) + .description("Shuffle worker http metric server bind port."); + + /** + * Specify the implementation classes of metrics reporter. Separate by ';' if there are multiple + * class names. Each class name needs a package name prefix, e.g. a.b.c.Factory1;a.b.c.Factory2. + */ + public static final ConfigOption METRICS_REPORTER_CLASSES = + new ConfigOption("remote-shuffle.metrics.reporter.factories") + .defaultValue(null) + .description( + "Specify the implementation classes of metrics reporter. Separate by " + + "';' if there are multiple class names. Each class name needs" + + " a package name prefix, e.g. a.b.c.Factory1;a.b.c.Factory2."); + + // ------------------------------------------------------------------------ + + /** Not intended to be instantiated. */ + private MetricOptions() {} +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/RpcOptions.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/RpcOptions.java new file mode 100644 index 00000000..0bf5075e --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/RpcOptions.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.config; + +import com.alibaba.flink.shuffle.common.config.ConfigOption; + +import java.time.Duration; + +/** Options for RPC service. 
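A sketch of loading the ';'-separated reporter factory list described by METRICS_REPORTER_CLASSES above; reflective zero-argument construction is an assumption made for illustration only:

    import java.util.ArrayList;
    import java.util.List;

    public class ReporterFactoryLoader {
        public static List<Object> load(String factoryClassNames) throws Exception {
            List<Object> factories = new ArrayList<>();
            for (String className : factoryClassNames.split(";")) {
                // each entry must carry its package prefix, e.g. a.b.c.Factory1
                factories.add(
                        Class.forName(className.trim()).getDeclaredConstructor().newInstance());
            }
            return factories;
        }
    }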
*/ +public class RpcOptions { + + /** Timeout for client <-> manager and worker <-> manager rpc calls. */ + public static final ConfigOption RPC_TIMEOUT = + new ConfigOption("remote-shuffle.rpc.timeout") + .defaultValue(Duration.ofSeconds(30)) + .description( + "Timeout for client <-> manager and worker <-> manager rpc calls."); + + /** Maximum size of messages can be sent through rpc calls. */ + public static final ConfigOption AKKA_FRAME_SIZE = + new ConfigOption("remote-shuffle.rpc.akka-frame-size") + .defaultValue("10485760b") + .description("Maximum size of messages can be sent through rpc calls."); + + // ------------------------------------------------------------------------ + + /** Not intended to be instantiated. */ + private RpcOptions() {} +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/StorageOptions.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/StorageOptions.java new file mode 100644 index 00000000..1f1a3f44 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/StorageOptions.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.config; + +import com.alibaba.flink.shuffle.common.config.ConfigOption; +import com.alibaba.flink.shuffle.common.config.MemorySize; +import com.alibaba.flink.shuffle.core.storage.StorageType; + +/** Config options for storage. */ +public class StorageOptions { + + /** Minimum size of memory to be used for data partition writing and reading. */ + public static final MemorySize MIN_WRITING_READING_MEMORY_SIZE = MemorySize.parse("16m"); + + /** Whether to enable data checksum for data integrity verification or not. */ + public static final ConfigOption STORAGE_ENABLE_DATA_CHECKSUM = + new ConfigOption("remote-shuffle.storage.enable-data-checksum") + .defaultValue(false) + .description( + "Whether to enable data checksum for data integrity verification or not."); + + /** + * Maximum number of tolerable failures before marking a data partition as corrupted, which will + * trigger the reproduction of the corresponding data. + */ + public static final ConfigOption STORAGE_FILE_TOLERABLE_FAILURES = + new ConfigOption("remote-shuffle.storage.file-tolerable-failures") + .defaultValue(Integer.MAX_VALUE) + .description( + "Maximum number of tolerable failures before marking a data partition " + + "as corrupted, which will trigger the reproduction of the " + + "corresponding data."); + + /** + * Local file system directories to persist partitioned data to. Multiple directories can be + * configured and these directories should be separated by comma (,). Each configured directory + * can be attached with an optional label which indicates the disk type. 
The valid disk types + * include 'SSD' and 'HDD'. If no label is offered, the default type would be 'HDD'. Here is a + * simple valid configuration example: [SSD]/dir1/,[HDD]/dir2/,/dir3/. This option must + * be configured and the configured dirs must exist. + */ + public static final ConfigOption STORAGE_LOCAL_DATA_DIRS = + new ConfigOption("remote-shuffle.storage.local-data-dirs") + .defaultValue(null) + .description( + "Local file system directories to persist partitioned data to. Multiple" + + " directories can be configured and these directories should " + + "be separated by comma (,). Each configured directory can be " + + "attached with an optional label which indicates the disk " + + "type. The valid disk types include 'SSD' and 'HDD'. If no " + + "label is offered, the default type would be 'HDD'. Here is a" + + " simple valid configuration example: '[SSD]/dir1/,[HDD]/dir2" + + "/,/dir3/'. This option must be configured and the configured" + + " directories must exist."); + + /** + * Preferred disk type to use for data storage. The valid types include 'SSD' and 'HDD'. If + * there are disks of the preferred type, only those disks will be used. However, this is not a + * strict restriction, which means if there is no disk of the preferred type, disks of other + * types will be also used. + */ + public static final ConfigOption STORAGE_PREFERRED_TYPE = + new ConfigOption("remote-shuffle.storage.preferred-disk-type") + .defaultValue(StorageType.SSD.name()) + .description( + "Preferred disk type to use for data storage. The valid types include " + + "'SSD' and 'HDD'. If there are disks of the preferred type, " + + "only those disks will be used. However, this is not a strict " + + "restriction, which means if there is no disk of the preferred" + + " type, disks of other types will be also used."); + + /** + * Number of threads to be used by data store for data partition processing of each HDD. The + * actual number of threads per disk will be min[configured value, 4 * (number of processors)]. + */ + public static final ConfigOption STORAGE_NUM_THREADS_PER_HDD = + new ConfigOption("remote-shuffle.storage.hdd.num-executor-threads") + .defaultValue(8) + .description( + "Number of threads to be used by data store for data partition processing" + + " of each HDD. The actual number of threads per disk will be " + + "min[configured value, 4 * (number of processors)]."); + + /** + * Number of threads to be used by data store for data partition processing of each SSD. The + * actual number of threads per disk will be min[configured value, 4 * (number of processors)]. + */ + public static final ConfigOption STORAGE_SSD_NUM_EXECUTOR_THREADS = + new ConfigOption("remote-shuffle.storage.ssd.num-executor-threads") + .defaultValue(Integer.MAX_VALUE) + .description( + "Number of threads to be used by data store for data partition processing" + + " of each SSD. The actual number of threads per disk will be " + + "min[configured value, 4 * (number of processors)]."); + + /** + * Number of threads to be used by data store for in-memory data partition processing. The + * actual number of threads used will be min[configured value, 4 * (number of processors)]. + */ + public static final ConfigOption STORAGE_MEMORY_NUM_EXECUTOR_THREADS = + new ConfigOption("remote-shuffle.storage.memory.num-executor-threads") + .defaultValue(Integer.MAX_VALUE) + .description( + "Number of threads to be used by data store for in-memory data partition" + + " processing. 
The actual number of threads used will be min[" + + "configured value, 4 * (number of processors)]."); + + /** + * Maximum memory size to use for the data writing of each data partition. Note that if the + * configured value is smaller than {@link #MIN_WRITING_READING_MEMORY_SIZE}, the minimum {@link + * #MIN_WRITING_READING_MEMORY_SIZE} will be used. + */ + public static final ConfigOption STORAGE_MAX_PARTITION_WRITING_MEMORY = + new ConfigOption("remote-shuffle.storage.partition.max-writing-memory") + .defaultValue(MemorySize.parse("128m")) + .description( + String.format( + "Maximum memory size to use for the data writing of each data " + + "partition. Note that if the configured value is " + + "smaller than %s, the minimum %s will be used.", + MIN_WRITING_READING_MEMORY_SIZE.toHumanReadableString(), + MIN_WRITING_READING_MEMORY_SIZE.toHumanReadableString())); + + /** + * Maximum memory size to use for the data reading of each data partition. Note that if the + * configured value is smaller than {@link #MIN_WRITING_READING_MEMORY_SIZE}, the minimum {@link + * #MIN_WRITING_READING_MEMORY_SIZE} will be used. + */ + public static final ConfigOption STORAGE_MAX_PARTITION_READING_MEMORY = + new ConfigOption("remote-shuffle.storage.partition.max-reading-memory") + .defaultValue(MemorySize.parse("128m")) + .description( + String.format( + "Maximum memory size to use for the data reading of each data " + + "partition. Note that if the configured value is " + + "smaller than %s, the minimum %s will be used.", + MIN_WRITING_READING_MEMORY_SIZE.toHumanReadableString(), + MIN_WRITING_READING_MEMORY_SIZE.toHumanReadableString())); + + // ------------------------------------------------------------------------ + + /** Not intended to be instantiated. */ + private StorageOptions() {} +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/TransferOptions.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/TransferOptions.java new file mode 100644 index 00000000..11a0d488 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/TransferOptions.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.config; + +import com.alibaba.flink.shuffle.common.config.ConfigOption; +import com.alibaba.flink.shuffle.common.config.MemorySize; + +import java.time.Duration; + +/** The set of configuration options relating to network stack. */ +public class TransferOptions { + + /** Data port to write shuffle data to and read shuffle data from shuffle workers. 
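A sketch of decoding the STORAGE_LOCAL_DATA_DIRS format described above (comma-separated directories, each optionally labeled '[SSD]' or '[HDD]', with 'HDD' as the fallback type); illustrative only, the project's actual parsing may differ:

    import java.util.LinkedHashMap;
    import java.util.Map;

    public class DataDirDecoder {
        /** Decodes e.g. "[SSD]/dir1/,[HDD]/dir2/,/dir3/" into path -> disk type. */
        public static Map<String, String> decode(String dirs) {
            Map<String, String> dirToType = new LinkedHashMap<>();
            for (String entry : dirs.split(",")) {
                String type = "HDD"; // the documented default when no label is given
                String path = entry.trim();
                if (path.startsWith("[SSD]") || path.startsWith("[HDD]")) {
                    type = path.substring(1, 4);
                    path = path.substring(5);
                }
                dirToType.put(path, type);
            }
            return dirToType;
        }
    }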
*/ + public static final ConfigOption SERVER_DATA_PORT = + new ConfigOption("remote-shuffle.transfer.server.data-port") + .defaultValue(10086) + .description( + "Data port to write shuffle data to and read shuffle data from shuffle" + + " workers."); + + /** The number of Netty threads at the server (shuffle worker) side. */ + public static final ConfigOption NUM_THREADS_SERVER = + new ConfigOption("remote-shuffle.transfer.server.num-threads") + .defaultValue(Math.max(32, Runtime.getRuntime().availableProcessors())) + .description( + "The number of Netty threads at the server (shuffle worker) side."); + + /** + * The maximum TCP connection backlog of the Netty server. The default 0 means that the Netty's + * default value will be used. + */ + public static final ConfigOption CONNECT_BACKLOG = + new ConfigOption("remote-shuffle.transfer.server.backlog") + .defaultValue(0) // default: 0 => Netty's default + .description( + "The maximum TCP connection backlog of the Netty server. The default " + + "'0' means that the Netty's default value will be used."); + + /** + * The number of Netty threads at the client (flink job) side. The default '-1' means that 2 * + * (the number of slots) will be used. + */ + public static final ConfigOption NUM_THREADS_CLIENT = + new ConfigOption("remote-shuffle.transfer.client.num-threads") + .defaultValue(-1) // default: -1 => Not specified and let upper layer deduce. + .description( + "The number of Netty threads at the client (flink job) side. The " + + "default '-1' means that 2 * (the number of slots) will be " + + "used."); + + /** The TCP connection setup timeout of the Netty client. */ + public static final ConfigOption CLIENT_CONNECT_TIMEOUT = + new ConfigOption("remote-shuffle.transfer.client.connect-timeout") + .defaultValue(Duration.ofMinutes(2)) + .description("The TCP connection setup timeout of the Netty client."); + + /** Number of retries when failed to connect to the remote shuffle worker. */ + public static final ConfigOption CONNECTION_RETRIES = + new ConfigOption("remote-shuffle.transfer.client.connect-retries") + .defaultValue(3) + .description( + "Number of retries when failed to connect to the remote shuffle worker."); + + /** Time to wait between two consecutive connection retries. */ + public static final ConfigOption CONNECTION_RETRY_WAIT = + new ConfigOption("remote-shuffle.transfer.client.connect-retry-wait") + .defaultValue(Duration.ofSeconds(3)) + .description("Time to wait between two consecutive connection retries."); + + /** + * The Netty transport type, either 'nio' or 'epoll'. The 'auto' means selecting the proper mode + * automatically based on the platform. Note that the 'epoll' mode can get better performance, + * less GC and have more advanced features which are only available on modern Linux. + */ + public static final ConfigOption TRANSPORT_TYPE = + new ConfigOption("remote-shuffle.transfer.transport-type") + .defaultValue("auto") + .description( + "The Netty transport type, either 'nio' or 'epoll'. The 'auto' means " + + "selecting the proper mode automatically based on the platform." + + " Note that the 'epoll' mode can get better performance, less" + + " GC and have more advanced features which are only available" + + " on modern Linux."); + /** + * The Netty send and receive buffer size. The default '0' means the system buffer size (cat + * /proc/sys/net/ipv4/tcp_[rw]mem) and is 4 MiB in modern Linux. 
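The connect-retries / connect-retry-wait pair above implies a retry loop of roughly the following shape. This is a generic sketch with the actual Netty connection step abstracted as a Runnable; the defaults in this patch are 3 retries with 3s between consecutive attempts:

    import java.time.Duration;

    public class RetryingConnector {
        public static void connectWithRetries(Runnable connect, int retries, Duration wait)
                throws InterruptedException {
            for (int attempt = 0; ; attempt++) {
                try {
                    connect.run();
                    return;
                } catch (RuntimeException e) {
                    if (attempt >= retries) {
                        throw e; // all retries exhausted
                    }
                    Thread.sleep(wait.toMillis());
                }
            }
        }
    }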
+ */ + public static final ConfigOption SEND_RECEIVE_BUFFER_SIZE = + new ConfigOption("remote-shuffle.transfer.send-receive-buffer-size") + .defaultValue(MemorySize.ZERO) // default: 0 => Netty's default + .description( + "The Netty send and receive buffer size. The default '0b' means the " + + "system buffer size (cat /proc/sys/net/ipv4/tcp_[rw]mem) and" + + " is 4 MiB in modern Linux."); + + /** The time interval to send heartbeat between the Netty server and Netty client. */ + public static final ConfigOption HEARTBEAT_INTERVAL = + new ConfigOption("remote-shuffle.transfer.heartbeat.interval") + .defaultValue(Duration.ofMinutes(1)) + .description( + "The time interval to send heartbeat between the Netty server and Netty" + + " client."); + + /** Heartbeat timeout used to detect broken TCP connections. */ + public static final ConfigOption HEARTBEAT_TIMEOUT = + new ConfigOption("remote-shuffle.transfer.heartbeat.timeout") + .defaultValue(Duration.ofMinutes(5)) + .description("Heartbeat timeout used to detect broken Netty connections."); + + // ------------------------------------------------------------------------ + + /** Not intended to be instantiated. */ + private TransferOptions() {} +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/WorkerOptions.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/WorkerOptions.java new file mode 100644 index 00000000..659ea5de --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/WorkerOptions.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.config; + +import com.alibaba.flink.shuffle.common.config.ConfigOption; +import com.alibaba.flink.shuffle.common.config.MemorySize; + +import java.time.Duration; + +/** The options for the shuffle worker. */ +public class WorkerOptions { + + // ------------------------------------------------------------------------ + // General ShuffleWorker Options + // ------------------------------------------------------------------------ + + /** + * The external address of the network interface where the shuffle worker is exposed. If not + * set, it will be determined automatically. Note: Different workers may need different values + * for this usually it can be specified in a non-shared shuffle worker specific configuration + * file. + */ + public static final ConfigOption HOST = + new ConfigOption("remote-shuffle.worker.host") + .defaultValue(null) + .description( + "The external address of the network interface where the shuffle worker" + + " is exposed. If not set, it will be determined automatically." 
+ + " Note: Different workers may need different values for this " + + "option, usually it can be specified in a non-shared shuffle" + + " worker specific configuration file."); + + /** + * The automatic address binding policy used by the shuffle worker if {@link #HOST} is not set. + * The valid types include 'name' and 'ip': 'name' means using hostname as binding address, 'ip' + * means using host's ip address as binding address. + */ + public static final ConfigOption HOST_BIND_POLICY = + new ConfigOption("remote-shuffle.worker.bind-policy") + .defaultValue("ip") + .description( + String.format( + "The automatic address binding policy used by the shuffle" + + " worker if '%s' is not set. The valid types include" + + " 'name' and 'ip': 'name' means using hostname as " + + "binding address, 'ip' means using host's ip address " + + "as binding address.", + HOST.key())); + + /** The local address of the network interface that the shuffle worker binds to. */ + public static final ConfigOption BIND_HOST = + new ConfigOption("remote-shuffle.worker.bind-host") + .defaultValue("0.0.0.0") + .description( + "The local address of the network interface that the shuffle worker " + + "binds to."); + + /** + * Defines network port range the shuffle worker expects incoming RPC connections. Accepts a + * list of ports (”50100,50101”), ranges (“50100-50200”) or a combination of both. The default + * '0' means that the shuffle worker will search for a free port itself. + */ + public static final ConfigOption RPC_PORT = + new ConfigOption("remote-shuffle.worker.rpc-port") + .defaultValue("0") + .description( + "Defines network port range the shuffle worker expects incoming RPC " + + "connections. Accepts a list of ports (”50100,50101”), ranges" + + " (“50100-50200”) or a combination of both. The default '0' " + + "means that the shuffle worker will search for a free port " + + "itself."); + + /** The local network port that the shuffle worker binds to. */ + public static final ConfigOption RPC_BIND_PORT = + new ConfigOption("remote-shuffle.worker.rpc-bind-port") + .defaultValue(null) + .description( + "The local network port that the shuffle worker binds to. If not " + + "configured, the external port (configured by '" + + RPC_PORT.key() + + "') will be used."); + + /** + * Maximum time to wait before reproducing the data stored in the lost worker (heartbeat + * timeout). The lost worker may become available again in this timeout. + */ + public static final ConfigOption MAX_WORKER_RECOVER_TIME = + new ConfigOption("remote-shuffle.worker.max-recovery-time") + .defaultValue(Duration.ofMinutes(3)) + .description( + "Maximum time to wait before reproducing the data stored in the lost " + + "worker (heartbeat timeout). The lost worker may become " + + "available again in this timeout."); + + // ------------------------------------------------------------------------ + // ShuffleWorker Memory Options + // ------------------------------------------------------------------------ + + /** Heap memory size to be used by the shuffle worker. */ + public static final ConfigOption FRAMEWORK_HEAP_MEMORY = + new ConfigOption("remote-shuffle.worker.memory.heap-size") + .defaultValue(MemorySize.parse("1g")) + .description("Heap memory size to be used by the shuffle worker."); + + /** Off-heap memory size to be used by the shuffle worker. 
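A sketch of expanding the port syntax documented for RPC_PORT above, a list ("50100,50101"), a range ("50100-50200"), or a combination, into candidate ports; "0" stays a single entry meaning any free port:

    import java.util.ArrayList;
    import java.util.List;

    public class PortRangeDecoder {
        public static List<Integer> decode(String spec) {
            List<Integer> ports = new ArrayList<>();
            for (String part : spec.split(",")) {
                part = part.trim();
                int dash = part.indexOf('-');
                if (dash > 0) {
                    // a range such as "50100-50200" expands to every port in between
                    int from = Integer.parseInt(part.substring(0, dash).trim());
                    int to = Integer.parseInt(part.substring(dash + 1).trim());
                    for (int p = from; p <= to; p++) {
                        ports.add(p);
                    }
                } else {
                    ports.add(Integer.parseInt(part));
                }
            }
            return ports;
        }
    }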
*/ + public static final ConfigOption FRAMEWORK_OFF_HEAP_MEMORY = + new ConfigOption("remote-shuffle.worker.memory.off-heap-size") + .defaultValue(MemorySize.parse("128m")) + .description("Off-heap memory size to be used by the shuffle worker."); + + /** JVM metaspace size to be used by the shuffle worker. */ + public static final ConfigOption JVM_METASPACE = + new ConfigOption("remote-shuffle.worker.memory.jvm-metaspace-size") + .defaultValue(MemorySize.parse("128m")) + .description("JVM metaspace size to be used by the shuffle worker."); + + /** JVM overhead size for the shuffle worker java process. */ + public static final ConfigOption JVM_OVERHEAD = + new ConfigOption("remote-shuffle.worker.memory.jvm-overhead-size") + .defaultValue(MemorySize.parse("128m")) + .description("JVM overhead size for the shuffle worker java process."); + + /** Java options to start the JVM of the shuffle worker with. */ + public static final ConfigOption JVM_OPTIONS = + new ConfigOption("remote-shuffle.worker.jvm-opts") + .defaultValue("") + .description("Java options to start the JVM of the shuffle worker with."); + + /** Not intended to be instantiated. */ + private WorkerOptions() {} +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/memory/AbstractCommonProcessSpec.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/memory/AbstractCommonProcessSpec.java new file mode 100644 index 00000000..66c1b9e7 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/memory/AbstractCommonProcessSpec.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.config.memory; + +import com.alibaba.flink.shuffle.common.config.MemorySize; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** + * Common memory components of shuffle manager and worker processes. + * + *

The process memory consists of the following components. + * + *

+ * <ul>
+ *   <li>JVM Heap
+ *   <li>JVM Direct (Off-Heap)
+ *   <li>JVM Metaspace
+ *   <li>JVM Overhead
+ * </ul>
+ *
+ * <p>The relationships of the process memory components are shown below. The memory of the
+ * Kubernetes container will be set to the Total Process Memory.
+ *

+ * <pre>
+ *               ┌ ─ ─ Total Process Memory  ─ ─ ┐
+ *               │ ┌───────────────────────────┐ │
+ *               │ │         JVM Heap          │ │
+ *               │ └───────────────────────────┘ │
+ *               │ ┌───────────────────────────┐ │
+ *               │ │   JVM Direct (Off-Heap)   │ │
+ *               │ └───────────────────────────┘ │
+ *               │ ┌───────────────────────────┐ │
+ *               │ │       JVM Metaspace       │ │
+ *               │ └───────────────────────────┘ │
+ *               │ ┌───────────────────────────┐ │
+ *               │ │       JVM Overhead        │ │
+ *               │ └───────────────────────────┘ │
+ *               └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘
+ * </pre>
+ */ +public abstract class AbstractCommonProcessSpec implements ProcessMemorySpec { + + private static final long serialVersionUID = -8835426046927573310L; + + private final MemorySize metaspace; + private final MemorySize overhead; + + public AbstractCommonProcessSpec(MemorySize metaspace, MemorySize overhead) { + this.metaspace = checkNotNull(metaspace); + this.overhead = checkNotNull(overhead); + } + + @Override + public MemorySize getJvmMetaspaceSize() { + return metaspace; + } + + @Override + public MemorySize getJvmOverheadSize() { + return overhead; + } + + @Override + public MemorySize getTotalProcessMemorySize() { + return getJvmHeapMemorySize() + .add(getJvmDirectMemorySize()) + .add(getJvmMetaspaceSize()) + .add(getJvmOverheadSize()); + } +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/memory/ProcessMemorySpec.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/memory/ProcessMemorySpec.java new file mode 100644 index 00000000..5a6a2fc2 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/memory/ProcessMemorySpec.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.config.memory; + +import com.alibaba.flink.shuffle.common.config.MemorySize; + +import java.io.Serializable; + +/** Common interface for shuffle manager and worker JVM process memory components. */ +public interface ProcessMemorySpec extends Serializable { + MemorySize getJvmHeapMemorySize(); + + MemorySize getJvmDirectMemorySize(); + + MemorySize getJvmMetaspaceSize(); + + MemorySize getJvmOverheadSize(); + + MemorySize getTotalProcessMemorySize(); +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/memory/ShuffleManagerProcessSpec.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/memory/ShuffleManagerProcessSpec.java new file mode 100644 index 00000000..d26770f9 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/memory/ShuffleManagerProcessSpec.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.config.memory; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.config.MemorySize; +import com.alibaba.flink.shuffle.core.config.ManagerOptions; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** Specification of ShuffleManager memory. */ +public class ShuffleManagerProcessSpec extends AbstractCommonProcessSpec { + + private static final long serialVersionUID = -1434517095656544445L; + + private final MemorySize frameworkHeap; + private final MemorySize frameworkOffHeap; + + public ShuffleManagerProcessSpec(Configuration memConfig) { + super( + memConfig.getMemorySize(ManagerOptions.JVM_METASPACE), + memConfig.getMemorySize(ManagerOptions.JVM_OVERHEAD)); + this.frameworkHeap = + checkNotNull(memConfig.getMemorySize(ManagerOptions.FRAMEWORK_HEAP_MEMORY)); + this.frameworkOffHeap = + checkNotNull(memConfig.getMemorySize(ManagerOptions.FRAMEWORK_OFF_HEAP_MEMORY)); + } + + @Override + public MemorySize getJvmHeapMemorySize() { + return frameworkHeap; + } + + @Override + public MemorySize getJvmDirectMemorySize() { + return frameworkOffHeap; + } +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/memory/ShuffleWorkerProcessSpec.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/memory/ShuffleWorkerProcessSpec.java new file mode 100644 index 00000000..618bc383 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/memory/ShuffleWorkerProcessSpec.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.config.memory; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.config.MemorySize; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.config.MemoryOptions; +import com.alibaba.flink.shuffle.core.config.WorkerOptions; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** Specification of ShuffleWorker memory. 
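A worked example of what the spec below resolves to with this patch's defaults: the worker's direct memory budget is the 128m framework off-heap plus the 4g data-reading and 4g data-writing pools, while the 1g heap, 128m metaspace and 128m overhead complete the total process memory. The sketch uses only MemorySize operations already present in this patch:

    import com.alibaba.flink.shuffle.common.config.MemorySize;

    public class WorkerMemoryDefaults {
        public static MemorySize defaultDirectMemory() {
            return MemorySize.parse("128m") // remote-shuffle.worker.memory.off-heap-size
                    .add(MemorySize.parse("4g")) // remote-shuffle.memory.data-reading-size
                    .add(MemorySize.parse("4g")); // remote-shuffle.memory.data-writing-size
        }
    }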
*/ +public class ShuffleWorkerProcessSpec extends AbstractCommonProcessSpec { + + private static final long serialVersionUID = -7719025949639324784L; + + private final MemorySize frameworkHeap; + private final MemorySize frameworkOffHeap; + private final MemorySize networkOffHeap; + + public ShuffleWorkerProcessSpec(Configuration memConfig) { + super( + memConfig.getMemorySize(WorkerOptions.JVM_METASPACE), + memConfig.getMemorySize(WorkerOptions.JVM_OVERHEAD)); + this.frameworkHeap = + checkNotNull(memConfig.getMemorySize(WorkerOptions.FRAMEWORK_HEAP_MEMORY)); + this.frameworkOffHeap = + checkNotNull(memConfig.getMemorySize(WorkerOptions.FRAMEWORK_OFF_HEAP_MEMORY)); + + MemorySize readingMemory = + CommonUtils.checkNotNull( + memConfig.getMemorySize(MemoryOptions.MEMORY_SIZE_FOR_DATA_READING)); + MemorySize writingMemory = + CommonUtils.checkNotNull( + memConfig.getMemorySize(MemoryOptions.MEMORY_SIZE_FOR_DATA_WRITING)); + + this.networkOffHeap = readingMemory.add(writingMemory); + } + + @Override + public MemorySize getJvmHeapMemorySize() { + return frameworkHeap; + } + + @Override + public MemorySize getJvmDirectMemorySize() { + return frameworkOffHeap.add(networkOffHeap); + } + + public MemorySize getFrameworkOffHeap() { + return frameworkOffHeap; + } + + public MemorySize getNetworkOffHeap() { + return networkOffHeap; + } +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/memory/util/ProcessMemoryUtils.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/memory/util/ProcessMemoryUtils.java new file mode 100644 index 00000000..c1ca03a9 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/config/memory/util/ProcessMemoryUtils.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.config.memory.util; + +import com.alibaba.flink.shuffle.core.config.memory.ProcessMemorySpec; + +import java.util.ArrayList; +import java.util.List; + +/** Utils for calculating JVM args. 
*/ +public class ProcessMemoryUtils { + + public static String generateJvmArgsStr(ProcessMemorySpec memorySpec, String jvmExtraOptions) { + List<String> commandStrings = new ArrayList<>(); + // jvm mem opts + commandStrings.add(generateJvmMemArgsStr(memorySpec)); + // jvm extra opts + commandStrings.add(jvmExtraOptions); + + return String.join(" ", commandStrings); + } + + private static String generateJvmMemArgsStr(ProcessMemorySpec memorySpec) { + final List<String> jvmArgs = new ArrayList<>(); + + jvmArgs.add("-Xmx" + memorySpec.getJvmHeapMemorySize().getBytes()); + jvmArgs.add("-Xms" + memorySpec.getJvmHeapMemorySize().getBytes()); + jvmArgs.add("-XX:MaxDirectMemorySize=" + memorySpec.getJvmDirectMemorySize().getBytes()); + jvmArgs.add("-XX:MaxMetaspaceSize=" + memorySpec.getJvmMetaspaceSize().getBytes()); + + return String.join(" ", jvmArgs); + } +}
diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/exception/DuplicatedPartitionException.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/exception/DuplicatedPartitionException.java new file mode 100644 index 00000000..3ce2c1b2 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/exception/DuplicatedPartitionException.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.exception; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; + +/** + * Exception to be thrown when the newly added data partition (identified by JobID, DataSetID and + * DataPartitionID) already exists. + */ +public class DuplicatedPartitionException extends ShuffleException { + + private static final long serialVersionUID = 2308889211057032738L; + + public DuplicatedPartitionException(String message) { + super(message); + } +}
diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/exception/PartitionNotFoundException.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/exception/PartitionNotFoundException.java new file mode 100644 index 00000000..02575ef5 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/exception/PartitionNotFoundException.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
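
A usage sketch for ProcessMemoryUtils above, assuming memConfig is a Configuration already populated with the ShuffleWorker memory options (the extra -XX:+UseG1GC flag is only an example):

    ShuffleWorkerProcessSpec spec = new ShuffleWorkerProcessSpec(memConfig);
    String jvmOpts = ProcessMemoryUtils.generateJvmArgsStr(spec, "-XX:+UseG1GC");
    // yields "-Xmx<heapBytes> -Xms<heapBytes> -XX:MaxDirectMemorySize=<directBytes>
    //         -XX:MaxMetaspaceSize=<metaspaceBytes> -XX:+UseG1GC"
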
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.exception; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; + +/** + * Exception to be thrown when the target partition cannot be consumed, which is a hint for the + * data producer to reproduce the corresponding data partition. + */ +public class PartitionNotFoundException extends ShuffleException { + + private static final long serialVersionUID = -4217817087530222073L; + + public PartitionNotFoundException( + DataSetID dataSetID, DataPartitionID dataPartitionID, String message) { + super( + String.format( + "Data partition with %s and %s is not found: %s.", + dataSetID, dataPartitionID, message)); + } +}
diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/executor/ExecutorThreadFactory.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/executor/ExecutorThreadFactory.java new file mode 100644 index 00000000..9ef56ce8 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/executor/ExecutorThreadFactory.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.executor; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.utils.FatalExitExceptionHandler; + +import javax.annotation.Nullable; + +import java.lang.Thread.UncaughtExceptionHandler; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * A thread factory intended for use by critical thread pools. Critical thread pools here mean + * thread pools that support remote shuffle service core coordination and processing work, and which + * must not simply cause unnoticed errors. + * + *
<p>The thread factory can be given an {@link UncaughtExceptionHandler} for the threads. If no + * handler is explicitly given, the default handler for uncaught exceptions will log the exceptions + * and kill the process afterwards. That guarantees that critical exceptions are not accidentally + * lost and leave the system running in an inconsistent state. + * + *
<p>Threads created by this factory are all called '(pool-name)-thread-n', where + * (pool-name) is configurable, and n is an incrementing number. + * + *
<p>All threads created by this factory are daemon threads and have the default (normal) priority. + * + *
<p>This class is copied from Apache Flink (org.apache.flink.runtime.util.ExecutorThreadFactory). + */ +public class ExecutorThreadFactory implements ThreadFactory { + + /** The thread pool name used when no explicit pool name has been specified. */ + private static final String DEFAULT_POOL_NAME = "shuffle-executor-pool"; + + private final AtomicInteger threadNumber = new AtomicInteger(1); + + private final ThreadGroup group; + + private final String namePrefix; + + private final int threadPriority; + + @Nullable private final UncaughtExceptionHandler exceptionHandler; + + // ------------------------------------------------------------------------ + + /** + * Creates a new thread factory using the default thread pool name ('shuffle-executor-pool') and + * the default uncaught exception handler (log exception and kill process). + */ + public ExecutorThreadFactory() { + this(DEFAULT_POOL_NAME); + } + + /** + * Creates a new thread factory using the given thread pool name and the default uncaught + * exception handler (log exception and kill process). + * + * @param poolName The pool name, used as the threads' name prefix + */ + public ExecutorThreadFactory(String poolName) { + this(poolName, FatalExitExceptionHandler.INSTANCE); + } + + /** + * Creates a new thread factory using the given thread pool name and the given uncaught + * exception handler. + * + * @param poolName The pool name, used as the threads' name prefix + * @param exceptionHandler The uncaught exception handler for the threads + */ + public ExecutorThreadFactory(String poolName, UncaughtExceptionHandler exceptionHandler) { + this(poolName, Thread.NORM_PRIORITY, exceptionHandler); + } + + ExecutorThreadFactory( + final String poolName, + final int threadPriority, + @Nullable final UncaughtExceptionHandler exceptionHandler) { + this.namePrefix = CommonUtils.checkNotNull(poolName) + "-thread-"; + this.threadPriority = threadPriority; + this.exceptionHandler = exceptionHandler; + + SecurityManager securityManager = System.getSecurityManager(); + this.group = + (securityManager != null) + ? securityManager.getThreadGroup() + : Thread.currentThread().getThreadGroup(); + } + + // ------------------------------------------------------------------------ + + @Override + public Thread newThread(Runnable runnable) { + Thread t = new Thread(group, runnable, namePrefix + threadNumber.getAndIncrement()); + t.setDaemon(true); + + t.setPriority(threadPriority); + + // optional handler for uncaught exceptions + if (exceptionHandler != null) { + t.setUncaughtExceptionHandler(exceptionHandler); + } + + return t; + } + + // -------------------------------------------------------------------------------------------- + + /** Builder for {@link ExecutorThreadFactory}.
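
A usage sketch for the factory (its Builder follows below); the pool name here is made up and imports are elided. Handing the factory to a JDK pool routes any uncaught exception to FatalExitExceptionHandler instead of losing it:

    ThreadFactory threadFactory =
            new ExecutorThreadFactory.Builder().setPoolName("shuffle-io").build();
    ExecutorService executorService = Executors.newFixedThreadPool(4, threadFactory);
    // worker threads are named "shuffle-io-thread-1", "shuffle-io-thread-2", ...
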
*/ + public static final class Builder { + private String poolName; + private int priority = Thread.NORM_PRIORITY; + private UncaughtExceptionHandler exceptionHandler = FatalExitExceptionHandler.INSTANCE; + + public Builder setPoolName(final String poolName) { + this.poolName = poolName; + return this; + } + + public Builder setThreadPriority(final int priority) { + this.priority = priority; + return this; + } + + public Builder setExceptionHandler(final UncaughtExceptionHandler exceptionHandler) { + this.exceptionHandler = exceptionHandler; + return this; + } + + public ExecutorThreadFactory build() { + return new ExecutorThreadFactory(poolName, priority, exceptionHandler); + } + } +}
diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/executor/SimpleSingleThreadExecutorPool.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/executor/SimpleSingleThreadExecutorPool.java new file mode 100644 index 00000000..58beffd9 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/executor/SimpleSingleThreadExecutorPool.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.executor; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; + +import java.util.ArrayDeque; +import java.util.Queue; + +/** + * A simple {@link SingleThreadExecutorPool} implementation which assigns all existing {@link + * SingleThreadExecutor}s in a round-robin way. More complicated policies like work stealing can be + * implemented in the future. + */ +public class SimpleSingleThreadExecutorPool implements SingleThreadExecutorPool { + + /** All available {@link SingleThreadExecutor}s. */ + private final Queue<SingleThreadExecutor> singleThreadExecutors = new ArrayDeque<>(); + + /** Total number of {@link SingleThreadExecutor}s. */ + private final int numExecutors; + + /** Whether this {@link SingleThreadExecutorPool} has been destroyed or not.
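
To make the round-robin handout of the pool above concrete, a sketch (the thread-name prefix is made up):

    SingleThreadExecutorPool pool = new SimpleSingleThreadExecutorPool(2, "data-executor");
    SingleThreadExecutor first = pool.getSingleThreadExecutor();  // "data-executor-0"
    SingleThreadExecutor second = pool.getSingleThreadExecutor(); // "data-executor-1"
    SingleThreadExecutor third = pool.getSingleThreadExecutor();  // "data-executor-0" again
    pool.destroy();
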
*/ + private boolean isDestroyed; + + public SimpleSingleThreadExecutorPool(int numExecutors, String threadName) { + CommonUtils.checkArgument(numExecutors > 0, "Must be positive."); + CommonUtils.checkArgument(threadName != null, "Must be not null."); + + this.numExecutors = numExecutors; + for (int i = 0; i < numExecutors; ++i) { + this.singleThreadExecutors.add(new SingleThreadExecutor(threadName + "-" + i)); + } + } + + @Override + public SingleThreadExecutor getSingleThreadExecutor() { + synchronized (singleThreadExecutors) { + if (isDestroyed) { + throw new ShuffleException("The executor pool has been destroyed."); + } + + SingleThreadExecutor executor = CommonUtils.checkNotNull(singleThreadExecutors.poll()); + singleThreadExecutors.add(executor); + return executor; + } + } + + @Override + public int getNumExecutors() { + return numExecutors; + } + + @Override + public void destroy() { + synchronized (singleThreadExecutors) { + isDestroyed = true; + + for (SingleThreadExecutor singleThreadExecutor : singleThreadExecutors) { + CommonUtils.runQuietly(singleThreadExecutor::shutDown); + } + singleThreadExecutors.clear(); + } + } +}
diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/executor/SingleThreadExecutor.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/executor/SingleThreadExecutor.java new file mode 100644 index 00000000..712a5b8c --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/executor/SingleThreadExecutor.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.executor; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; + +import javax.annotation.Nonnull; +import javax.annotation.concurrent.GuardedBy; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.Executor; +import java.util.concurrent.RejectedExecutionException; + +/** + * A single-thread {@link Executor} implementation intended to avoid potential race conditions and + * to simplify the multi-threaded logic design. + */ +public class SingleThreadExecutor implements Executor { + + /** Lock to protect shared structures and avoid potential race conditions. */ + private final Object lock = new Object(); + + /** The executor thread responsible for processing all the pending tasks. */ + private final ExecutorThread executorThread; + + /** All pending {@link Runnable} tasks waiting to be processed by this {@link Executor}. */ + @GuardedBy("lock") + private final Queue<Runnable> tasks = new ArrayDeque<>(); + + /** Whether this {@link Executor} has been shut down or not.
*/ + @GuardedBy("lock") + private boolean isShutDown; + + public SingleThreadExecutor(String threadName) { + CommonUtils.checkArgument(threadName != null, "Must be not null."); + this.executorThread = new ExecutorThread(threadName); + this.executorThread.start(); + } + + @Override + public void execute(@Nonnull Runnable command) { + synchronized (lock) { + if (isShutDown) { + throw new RejectedExecutionException("Executor has been shut down."); + } + + try { + boolean triggerProcessing = tasks.isEmpty(); + tasks.add(command); + + // notify the executor if there is no task available + // for processing except for this newly added one + if (triggerProcessing) { + lock.notify(); + } + } catch (Throwable throwable) { + throw new RejectedExecutionException("Failed to add new task.", throwable); + } + } + } + + /** Returns true if the program is running in the main executor thread. */ + public boolean inExecutorThread() { + return Thread.currentThread() == executorThread; + } + + /** + * Shuts down this {@link Executor}, which releases all resources. After that, no task can be + * processed any more. + */ + public void shutDown() { + synchronized (lock) { + isShutDown = true; + executorThread.interrupt(); + } + } + + public boolean isShutDown() { + synchronized (lock) { + return isShutDown; + } + } + + /** + * The executor thread which polls {@link Runnable} tasks from the task queue and executes the + * polled tasks. + */ + private class ExecutorThread extends Thread { + + private ExecutorThread(String threadName) { + super(threadName); + } + + @Override + public void run() { + do { + List<Runnable> pendingTasks; + synchronized (lock) { + // by design, only shut down or new tasks can wake up the + // executor thread, this while loop is added for safety + while (!isShutDown && tasks.isEmpty()) { + CommonUtils.runQuietly(lock::wait); + } + + // exit only when this task executor has been shut down + if (isShutDown) { + tasks.clear(); + break; + } + + pendingTasks = new ArrayList<>(tasks); + tasks.clear(); + } + + // run all pending tasks one by one in FIFO order + for (Runnable task : pendingTasks) { + CommonUtils.runQuietly(task::run); + } + } while (true); + } + } + + // --------------------------------------------------------------------------------------------- + // For test + // --------------------------------------------------------------------------------------------- + + Thread getExecutorThread() { + return executorThread; + } +}
diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/executor/SingleThreadExecutorPool.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/executor/SingleThreadExecutorPool.java new file mode 100644 index 00000000..f45e21a9 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/executor/SingleThreadExecutorPool.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
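
A sketch of the contract implemented above: tasks submitted from any thread run in FIFO order on the single owned thread, and execute() is rejected once the executor has been shut down.

    SingleThreadExecutor executor = new SingleThreadExecutor("partition-worker");
    executor.execute(() -> System.out.println("first"));
    executor.execute(() -> System.out.println("second")); // always printed after "first"
    executor.shutDown();
    // executor.execute(...) would now throw RejectedExecutionException
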
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.executor; + +/** A pool from which to allocate {@link SingleThreadExecutor}s. */ +public interface SingleThreadExecutorPool { + + /** Gets a {@link SingleThreadExecutor} from this single-thread executor pool. */ + SingleThreadExecutor getSingleThreadExecutor(); + + /** Returns the number of {@link SingleThreadExecutor}s in this executor pool. */ + int getNumExecutors(); + + /** Destroys this executor pool and releases all resources. */ + void destroy(); +}
diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/BaseID.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/BaseID.java new file mode 100644 index 00000000..a4be4655 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/BaseID.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.ids; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; + +import java.io.DataOutput; +import java.io.IOException; +import java.io.Serializable; +import java.util.Arrays; + +/** An abstract ID implementation based on a byte array. */ +public abstract class BaseID implements Serializable { + + private static final long serialVersionUID = -2348171244792228610L; + + /** ID represented by a byte array. */ + protected final byte[] id; + + /** Pre-calculated hash-code for acceleration. */ + protected final int hashCode; + + public BaseID(byte[] id) { + CommonUtils.checkArgument(id != null, "Must be not null."); + + this.id = id; + this.hashCode = Arrays.hashCode(id); + } + + public BaseID(int length) { + CommonUtils.checkArgument(length > 0, "Must be positive."); + + this.id = CommonUtils.randomBytes(length); + this.hashCode = Arrays.hashCode(id); + } + + public byte[] getId() { + return id; + } + + /** Returns the number of bytes taken if this ID is serialized. */ + public int getFootprint() { + return 4 + id.length; + } + + /** Serializes this ID to the target {@link ByteBuf}. */ + public void writeTo(ByteBuf byteBuf) { + byteBuf.writeInt(id.length); + byteBuf.writeBytes(id); + } + + /** Persists this ID to the target {@link DataOutput}.
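
Every byte-array ID serializes as a length prefix plus the raw bytes, so a round trip through DataOutput/DataInput looks like the sketch below (DataSetID is one of the concrete subclasses defined later; exception handling and imports are elided):

    DataSetID id = new DataSetID(CommonUtils.randomBytes(16));
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    id.writeTo(new DataOutputStream(bytes)); // writeInt(id.length) + write(id)
    DataSetID copy = DataSetID.readFrom(
            new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));
    // copy.equals(id) holds, and both share the same precomputed hash code
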
*/ + public void writeTo(DataOutput dataOutput) throws IOException { + dataOutput.writeInt(id.length); + dataOutput.write(id); + } + + @Override + public boolean equals(Object that) { + if (this == that) { + return true; + } + + if (that == null || getClass() != that.getClass()) { + return false; + } + + BaseID thatID = (BaseID) that; + return hashCode == thatID.hashCode && Arrays.equals(id, thatID.id); + } + + @Override + public int hashCode() { + return hashCode; + } + + @Override + public String toString() { + return "BaseID{" + "ID=" + CommonUtils.bytesToHexString(id) + '}'; + } +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/ChannelID.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/ChannelID.java new file mode 100644 index 00000000..c59ecd0c --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/ChannelID.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.ids; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; + +/** + * A {@link BaseID} to identify a shuffle-read or shuffle-write transaction. Note that this is not + * to identify a physical connection. + */ +public class ChannelID extends BaseID { + + private static final long serialVersionUID = -7984936977579866045L; + + public ChannelID() { + super(16); + } + + public ChannelID(byte[] id) { + super(id); + } + + public static ChannelID readFrom(ByteBuf byteBuf) { + byte[] bytes = new byte[byteBuf.readInt()]; + byteBuf.readBytes(bytes); + return new ChannelID(bytes); + } + + @Override + public String toString() { + return String.format("ChannelID{ID=%s}", CommonUtils.bytesToHexString(id)); + } +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/DataPartitionID.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/DataPartitionID.java new file mode 100644 index 00000000..986341b1 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/DataPartitionID.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.ids; + +import com.alibaba.flink.shuffle.core.storage.DataPartition; + +/** Base ID of the data partition. */ +public abstract class DataPartitionID extends BaseID { + + private static final long serialVersionUID = -3448851707447828689L; + + public DataPartitionID(byte[] id) { + super(id); + } + + public abstract DataPartition.DataPartitionType getPartitionType(); +}
diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/DataSetID.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/DataSetID.java new file mode 100644 index 00000000..61fee979 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/DataSetID.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.ids; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; + +import java.io.DataInput; +import java.io.IOException; + +/** ID of the data partition collection. */ +public class DataSetID extends BaseID { + + private static final long serialVersionUID = -4348308268446349812L; + + public DataSetID(byte[] id) { + super(id); + } + + /** Deserializes and creates a {@link DataSetID} from the given {@link ByteBuf}. */ + public static DataSetID readFrom(ByteBuf byteBuf) { + byte[] bytes = new byte[byteBuf.readInt()]; + byteBuf.readBytes(bytes); + return new DataSetID(bytes); + } + + /** Deserializes and creates a {@link DataSetID} from the given {@link DataInput}. */ + public static DataSetID readFrom(DataInput dataInput) throws IOException { + byte[] bytes = new byte[dataInput.readInt()]; + dataInput.readFully(bytes); + return new DataSetID(bytes); + } + + @Override + public String toString() { + return "DataSetID{" + "ID=" + CommonUtils.bytesToHexString(id) + '}'; + } +}
diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/InstanceID.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/InstanceID.java new file mode 100644 index 00000000..0bc5c996 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/InstanceID.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.ids; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; + +import java.nio.charset.StandardCharsets; + +/** ID of an instance of a remote shuffle component, such as a shuffle client. */ +public class InstanceID extends BaseID { + + private static final long serialVersionUID = 4698892147639170213L; + + public InstanceID(byte[] id) { + super(id); + } + + public InstanceID(String id) { + super(id.getBytes(StandardCharsets.UTF_8)); + } + + public InstanceID() { + super(16); + } + + @Override + public String toString() { + return "InstanceID{" + "ID=" + CommonUtils.bytesToHexString(id) + '}'; + } +}
diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/JobID.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/JobID.java new file mode 100644 index 00000000..568144a3 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/JobID.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.ids; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; + +import java.io.DataInput; +import java.io.IOException; + +/** ID of the data producer. */ +public class JobID extends BaseID { + + private static final long serialVersionUID = 9161717086378913090L; + + public JobID(byte[] id) { + super(id); + } + + /** Deserializes and creates a {@link JobID} from the given {@link ByteBuf}. */ + public static JobID readFrom(ByteBuf byteBuf) { + int length = byteBuf.readInt(); + byte[] bytes = new byte[length]; + byteBuf.readBytes(bytes); + return new JobID(bytes); + } + + /** Deserializes and creates a {@link JobID} from the given {@link DataInput}.
*/ + public static JobID readFrom(DataInput dataInput) throws IOException { + byte[] bytes = new byte[dataInput.readInt()]; + dataInput.readFully(bytes); + return new JobID(bytes); + } + + @Override + public String toString() { + return "JobID{" + "ID=" + CommonUtils.bytesToHexString(id) + '}'; + } +}
diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/MapPartitionID.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/MapPartitionID.java new file mode 100644 index 00000000..9d31a8b2 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/MapPartitionID.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.ids; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.storage.DataPartition; +import com.alibaba.flink.shuffle.core.storage.MapPartition; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; + +import java.io.DataInput; +import java.io.IOException; + +/** ID of the {@link MapPartition}. */ +public class MapPartitionID extends DataPartitionID { + + private static final long serialVersionUID = -7215025255975348400L; + + public MapPartitionID(byte[] resultPartitionID) { + super(resultPartitionID); + } + + @Override + public DataPartition.DataPartitionType getPartitionType() { + return DataPartition.DataPartitionType.MAP_PARTITION; + } + + /** Deserializes and creates a {@link MapPartitionID} from the given {@link ByteBuf}. */ + public static MapPartitionID readFrom(ByteBuf byteBuf) { + byte[] bytes = new byte[byteBuf.readInt()]; + byteBuf.readBytes(bytes); + + return new MapPartitionID(bytes); + } + + /** Deserializes and creates a {@link MapPartitionID} from the given {@link DataInput}. */ + public static MapPartitionID readFrom(DataInput dataInput) throws IOException { + byte[] bytes = new byte[dataInput.readInt()]; + dataInput.readFully(bytes); + return new MapPartitionID(bytes); + } + + @Override + public String toString() { + return "MapPartitionID{" + "ID=" + CommonUtils.bytesToHexString(id) + '}'; + } +}
diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/ReducePartitionID.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/ReducePartitionID.java new file mode 100644 index 00000000..79c7d101 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/ReducePartitionID.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.ids; + +import com.alibaba.flink.shuffle.core.storage.DataPartition; +import com.alibaba.flink.shuffle.core.storage.ReducePartition; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/** ID of the {@link ReducePartition}. */ +public class ReducePartitionID extends DataPartitionID { + + private static final long serialVersionUID = 5443963157496120768L; + + private final int partitionIndex; + + @Override + public DataPartition.DataPartitionType getPartitionType() { + return DataPartition.DataPartitionType.REDUCE_PARTITION; + } + + public ReducePartitionID(int partitionIndex) { + super(getBytes(partitionIndex)); + this.partitionIndex = partitionIndex; + } + + private static byte[] getBytes(int value) { + byte[] bytes = new byte[4]; + ByteBuffer.wrap(bytes).order(ByteOrder.BIG_ENDIAN).putInt(value); + return bytes; + } + + /** Deserializes and creates a {@link ReducePartitionID} from the given {@link DataInput}. */ + public static ReducePartitionID readFrom(DataInput dataInput) throws IOException { + return new ReducePartitionID(dataInput.readInt()); + } + + public int getPartitionIndex() { + return partitionIndex; + } + + @Override + public int getFootprint() { + return 4; + } + + @Override + public void writeTo(ByteBuf byteBuf) { + byteBuf.writeInt(partitionIndex); + } + + @Override + public void writeTo(DataOutput dataOutput) throws IOException { + dataOutput.writeInt(partitionIndex); + } + + @Override + public String toString() { + return "ReducePartitionID{" + "PartitionIndex=" + partitionIndex + '}'; + } +}
diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/RegistrationID.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/RegistrationID.java new file mode 100644 index 00000000..6e9bd433 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/ids/RegistrationID.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
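
Unlike the random byte-array IDs, ReducePartitionID above overrides serialization to a single big-endian int, so its wire footprint is a fixed 4 bytes with no length prefix. A quick sketch:

    ReducePartitionID reduceID = new ReducePartitionID(7);
    // 4 bytes on the wire: just the partition index
    assert reduceID.getFootprint() == 4;
    assert reduceID.getPartitionIndex() == 7;
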
+ */ + +package com.alibaba.flink.shuffle.core.ids; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; + +import java.nio.charset.StandardCharsets; + +/** ID identifying a registration of remote shuffle components, for example a shuffle worker registration. */ +public class RegistrationID extends BaseID { + + private static final long serialVersionUID = 390970375272146036L; + + public RegistrationID(byte[] id) { + super(id); + } + + public RegistrationID(String id) { + super(id.getBytes(StandardCharsets.UTF_8)); + } + + public RegistrationID() { + super(16); + } + + @Override + public String toString() { + return "RegistrationID{" + "ID=" + CommonUtils.bytesToHexString(id) + '}'; + } +}
diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/listener/BacklogListener.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/listener/BacklogListener.java new file mode 100644 index 00000000..4252f3f0 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/listener/BacklogListener.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.listener; + +/** Listener to be notified when there is any backlog available. */ +public interface BacklogListener { + + void notifyBacklog(int backlog); +}
diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/listener/BufferListener.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/listener/BufferListener.java new file mode 100644 index 00000000..9b524261 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/listener/BufferListener.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.listener; + +import java.nio.ByteBuffer; +import java.util.List; + +/** Listener to be notified when the allocated {@link ByteBuffer}s are available. */ +public interface BufferListener { + + /** Notifies the allocated {@link ByteBuffer}s to the listener.
*/ + void notifyBuffers(List<ByteBuffer> allocatedBuffers, Throwable throwable); +}
diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/listener/DataCommitListener.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/listener/DataCommitListener.java new file mode 100644 index 00000000..03f93f22 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/listener/DataCommitListener.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.listener; + +/** Listener to be notified when data writing succeeds. */ +public interface DataCommitListener { + + /** Notifies the listener that data writing has been finished successfully. */ + void notifyDataCommitted(); +}
diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/listener/DataListener.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/listener/DataListener.java new file mode 100644 index 00000000..9378be2d --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/listener/DataListener.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.listener; + +/** Listener to be notified when data is available. */ +public interface DataListener { + + /** Notifies the listener that there is data available for reading. */ + void notifyDataAvailable(); +}
diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/listener/DataRegionCreditListener.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/listener/DataRegionCreditListener.java new file mode 100644 index 00000000..30b431cd --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/listener/DataRegionCreditListener.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
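
All of these listener interfaces are single-method, so call sites can pass lambdas. A sketch with made-up handler bodies:

    BufferListener bufferListener =
            (buffers, failure) -> {
                if (failure != null) {
                    System.err.println("buffer allocation failed: " + failure);
                } else {
                    System.out.println("received " + buffers.size() + " buffers");
                }
            };
    DataCommitListener commitListener = () -> System.out.println("data committed");
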
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.listener; + +import com.alibaba.flink.shuffle.core.memory.Buffer; + +/** + * Listener to be notified when any credit ({@link Buffer}) is available for the target data region. + */ +public interface DataRegionCreditListener { + + /** Notifies the available credits of the corresponding data region to the listener. */ + void notifyCredits(int availableCredits, int dataRegionIndex); +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/listener/FailureListener.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/listener/FailureListener.java new file mode 100644 index 00000000..36de89bc --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/listener/FailureListener.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.listener; + +import javax.annotation.Nullable; + +/** Listener to be notified when failure occurs. */ +public interface FailureListener { + + /** Notifies the encountered failure to the listener. */ + void notifyFailure(@Nullable Throwable throwable); +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/listener/PartitionStateListener.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/listener/PartitionStateListener.java new file mode 100644 index 00000000..981c05b9 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/listener/PartitionStateListener.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.listener; + +import com.alibaba.flink.shuffle.core.storage.DataPartition; +import com.alibaba.flink.shuffle.core.storage.DataPartitionMeta; + +/** Listener to be notified when state of {@link DataPartition} changes. */ +public interface PartitionStateListener { + + /** Notifies the created data partition to this listener. */ + void onPartitionCreated(DataPartitionMeta partitionMeta) throws Exception; + + /** Notifies the removed data partition to this listener. */ + void onPartitionRemoved(DataPartitionMeta partitionMeta); +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/memory/Buffer.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/memory/Buffer.java new file mode 100644 index 00000000..bd8216f5 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/memory/Buffer.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.memory; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; + +import org.apache.flink.shaded.netty4.io.netty.buffer.UnpooledByteBufAllocator; +import org.apache.flink.shaded.netty4.io.netty.buffer.UnpooledDirectByteBuf; + +import java.nio.ByteBuffer; + +/** + * We use {@link UnpooledDirectByteBuf} directly to reduce one copy from {@link Buffer} to Netty's + * {@link org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf} when transmitting data over the + * network. 
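
Given that design, a lifecycle sketch for the Buffer class defined just below: releasing the last reference does not free the memory but hands the underlying ByteBuffer back to its recycler.

    ByteBuffer memory = ByteBuffer.allocateDirect(32 * 1024);
    BufferRecycler recycler = b -> { /* hypothetical: return b to the owning pool */ };
    Buffer buffer = new Buffer(memory, recycler, 0); // no readable bytes yet
    buffer.release(); // refCnt drops to 0, deallocate() recycles 'memory'
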
+ */ +public class Buffer extends UnpooledDirectByteBuf { + + private final ByteBuffer buffer; + + private final BufferRecycler recycler; + + public Buffer(ByteBuffer buffer, BufferRecycler recycler, int readableBytes) { + super(UnpooledByteBufAllocator.DEFAULT, buffer, buffer.capacity()); + + CommonUtils.checkArgument(recycler != null, "Must be not null."); + CommonUtils.checkArgument(buffer.position() == 0, "Position must be 0."); + + this.buffer = buffer; + this.recycler = recycler; + writerIndex(readableBytes); + } + + @Override + protected void deallocate() { + buffer.clear(); + recycler.recycle(buffer); + } +}
diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/memory/BufferDispatcher.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/memory/BufferDispatcher.java new file mode 100644 index 00000000..fa216839 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/memory/BufferDispatcher.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.memory; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.listener.BufferListener; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.concurrent.GuardedBy; + +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.List; +import java.util.Queue; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument; + +/** Wrapper to allocate buffers from {@link ByteBufferPool}. */ +public class BufferDispatcher { + + private static final Logger LOG = LoggerFactory.getLogger(BufferDispatcher.class); + + private final ByteBufferPool bufferPool; + + @GuardedBy("lock") + private final Queue<BufferRequirement> bufferRequirements; + + private final int totalBuffers; + + @GuardedBy("lock") + private boolean destroyed; + + private final Thread dispatcherThread; + + private final Object lock; + + /** + * @param name Name of the underlying buffer pool. + * @param numBuffers Total number of available buffers at startup. + * @param bufferSize Size of a single buffer.
+ */ + public BufferDispatcher(String name, int numBuffers, int bufferSize) { + this.bufferPool = new ByteBufferPool(name, numBuffers, bufferSize); + this.totalBuffers = numBuffers; + this.bufferRequirements = new ArrayDeque<>(); + this.destroyed = false; + this.lock = new Object(); + this.dispatcherThread = new Thread(new Dispatcher()); + dispatcherThread.setName("Buffer Dispatcher"); + dispatcherThread.setDaemon(true); + dispatcherThread.start(); + } + + /** + * Requests buffer(s) asynchronously. Note that buffers will be allocated by priority and + * handed back through the given {@link BufferListener}. + */ + public void requestBuffer( + JobID jobID, + DataSetID dataSetID, + DataPartitionID dataPartitionID, + int min, + int max, + BufferListener listener) { + synchronized (lock) { + bufferRequirements.add(new BufferRequirement(min, max, listener)); + lock.notifyAll(); + } + } + + /** Destroys this {@link BufferDispatcher} and the underlying {@link ByteBufferPool}. */ + public void destroy() { + synchronized (lock) { + bufferPool.destroy(); + destroyed = true; + dispatcherThread.interrupt(); + lock.notifyAll(); + } + } + + private class Dispatcher implements Runnable { + + private static final int REQUEST_TIMEOUT = 60 * 60; // in seconds. + + @Override + public void run() { + while (true) { + List<ByteBuffer> buffers = new ArrayList<>(); + BufferRequirement bufferRequirement = null; + try { + synchronized (lock) { + if (destroyed) { + break; + } + while (bufferRequirements.isEmpty() && !destroyed) { + lock.wait(); + } + bufferRequirement = bufferRequirements.poll(); + if (bufferRequirement == null) { + CommonUtils.checkState( + destroyed, + "Polled a bufferRequirement as null, but BufferDispatcher not destroyed yet."); + break; + } + } + while (buffers.size() < bufferRequirement.min) { + ByteBuffer buffer = bufferPool.requestBlocking(REQUEST_TIMEOUT); + if (buffer == null) { + throw new ShuffleException( + "Memory shortage in " + bufferPool.getName()); + } + buffers.add(buffer); + } + ByteBuffer buffer = null; + while (buffers.size() < bufferRequirement.max + && (buffer = bufferPool.requestBuffer()) != null) { + buffers.add(buffer); + } + bufferRequirement.bufferListener.notifyBuffers(buffers, null); + } catch (Exception e) { + LOG.error("Exception when fulfilling buffer requirement.", e); + buffers.forEach(bufferPool::recycle); + if (bufferRequirement != null) { + bufferRequirement.bufferListener.notifyBuffers(null, e); + } + } + } + } + } + + /** Gets the total number of buffers. */ + public int numTotalBuffers() { + return totalBuffers; + } + + /** Gets the number of available buffers. */ + public int numAvailableBuffers() { + return bufferPool.numAvailableBuffers(); + } + + /** Recycles a buffer back for further allocation.
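+ * + * <p>A request/recycle round trip might look like this; the IDs, the {@code listener} and the returned {@code buffer} are illustrative: + * + * <pre>{@code + * dispatcher.requestBuffer(jobID, dataSetID, dataPartitionID, 1, 8, listener); + * // the listener later receives 1 to 8 buffers via BufferListener#notifyBuffers + * dispatcher.recycleBuffer(buffer, jobID, dataSetID, dataPartitionID); // when done + * }</pre>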
*/ + public void recycleBuffer( + ByteBuffer buffer, JobID jobID, DataSetID dataSetID, DataPartitionID dataPartitionID) { + bufferPool.recycle(buffer); + } + + private static final class BufferRequirement { + + private final int min; + private final int max; + private final BufferListener bufferListener; + + BufferRequirement(int min, int max, BufferListener bufferListener) { + checkArgument( + min > 0 && max > 0 && max >= min, + String.format("Invalid min=%d, max=%d.", min, max)); + this.min = min; + this.max = max; + this.bufferListener = bufferListener; + } + } +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/memory/BufferRecycler.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/memory/BufferRecycler.java new file mode 100644 index 00000000..95b14353 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/memory/BufferRecycler.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.memory; + +import java.nio.ByteBuffer; + +/** Recycler which can recycle {@link ByteBuffer}s for reuse. */ +public interface BufferRecycler { + + /** Recycles the target {@link ByteBuffer}. */ + void recycle(ByteBuffer buffer); +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/memory/BufferSupplier.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/memory/BufferSupplier.java new file mode 100644 index 00000000..1d3f6048 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/memory/BufferSupplier.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.memory; + +/** A supplier from which to poll {@link Buffer}s. */ +public interface BufferSupplier { + + /** Polls a {@link Buffer} from this supplier.
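+ * + * <p>A typical polling loop, where {@code supplier} is any {@link BufferSupplier}; filling and handover of the buffers are elided: + * + * <pre>{@code + * Buffer buffer; + * while ((buffer = supplier.pollBuffer()) != null) { + * // fill the buffer and pass it on; null means no buffer is available right now + * } + * }</pre>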
*/ + Buffer pollBuffer(); +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/memory/ByteBufferPool.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/memory/ByteBufferPool.java new file mode 100644 index 00000000..d36be90d --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/memory/ByteBufferPool.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.memory; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; + +/** Buffer pool to provide {@link ByteBuffer}s on shuffle workers. */ +public class ByteBufferPool implements BufferRecycler { + + private static final Logger LOG = LoggerFactory.getLogger(ByteBufferPool.class); + + private final String name; + + private final int numBuffers; + + private final int bufferSize; + + private final BlockingQueue<ByteBuffer> availableBuffers; + + private boolean isDestroyed; + + /** + * @param name Name of this buffer pool. + * @param numBuffers Total number of available buffers at start. + * @param bufferSize Size of a single buffer. + */ + public ByteBufferPool(String name, int numBuffers, int bufferSize) { + LOG.info("Creating buffer pool, numBuffers={}, bufferSize={}.", numBuffers, bufferSize); + this.name = name; + this.numBuffers = numBuffers; + this.bufferSize = bufferSize; + this.availableBuffers = new LinkedBlockingDeque<>(); + for (int i = 0; i < numBuffers; i++) { + ByteBuffer byteBuffer = CommonUtils.allocateDirectByteBuffer(bufferSize); + availableBuffers.add(byteBuffer); + } + this.isDestroyed = false; + } + + /** Name of this buffer pool. */ + public String getName() { + return name; + } + + /** Requests a buffer in non-blocking mode; returns null if none is available. */ + public ByteBuffer requestBuffer() { + return availableBuffers.poll(); + } + + /** Requests a buffer in blocking mode; returns null on timeout. */ + public ByteBuffer requestBlocking(int timeoutInSeconds) throws InterruptedException { + return availableBuffers.poll(timeoutInSeconds, TimeUnit.SECONDS); + } + + /** Destroys this buffer pool. */ + public void destroy() { + LOG.info("Destroying buffer pool, numBuffers={}, bufferSize={}.", numBuffers, bufferSize); + availableBuffers.clear(); + isDestroyed = true; + } + + /** Whether this buffer pool is already destroyed. */ + public boolean isDestroyed() { + return isDestroyed; + } + + /** Gets the number of available buffers at the moment.
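+ * + * <p>For example, with an illustrative pool of 16 buffers of 32 KB each: + * + * <pre>{@code + * ByteBufferPool pool = new ByteBufferPool("demo-pool", 16, 32 * 1024); + * ByteBuffer buffer = pool.requestBuffer(); // non-blocking, may return null + * pool.recycle(buffer); // afterwards numAvailableBuffers() is 16 again + * }</pre>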
*/ + public int numAvailableBuffers() { + return availableBuffers.size(); + } + + /** Returns a {@link ByteBuffer} back to this buffer pool. */ + @Override + public void recycle(ByteBuffer buffer) { + buffer.clear(); + availableBuffers.add(buffer); + if (availableBuffers.size() > numBuffers) { + LOG.error( + "BUG: buffer pool {} got {} buffers, more than the expected {}.", + name, + availableBuffers.size(), + numBuffers); + } + } +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/BufferQueue.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/BufferQueue.java new file mode 100644 index 00000000..c2d19c8c --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/BufferQueue.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.storage; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; + +import javax.annotation.Nullable; +import javax.annotation.concurrent.NotThreadSafe; + +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Queue; + +/** A buffer queue implementation enhanced with a release state. */ +@NotThreadSafe +public class BufferQueue { + + public static final BufferQueue RELEASED_EMPTY_BUFFER_QUEUE = new BufferQueue(); + + /** All available buffers in this buffer queue. */ + private final Queue<ByteBuffer> buffers; + + /** Whether this buffer queue is released or not. */ + private boolean isReleased; + + public BufferQueue(List<ByteBuffer> buffers) { + CommonUtils.checkArgument(buffers != null, "Must be not null."); + this.buffers = new ArrayDeque<>(buffers); + } + + private BufferQueue() { + this.isReleased = true; + this.buffers = new ArrayDeque<>(); + } + + /** Returns the number of available buffers in this buffer queue. */ + public int size() { + return buffers.size(); + } + + /** + * Returns an available buffer from this buffer queue or returns null if no buffer is available + * currently. + */ + @Nullable + public ByteBuffer poll() { + return buffers.poll(); + } + + /** + * Adds an available buffer to this buffer queue and throws an exception if this buffer queue + * has been released. + */ + public void add(ByteBuffer availableBuffer) { + CommonUtils.checkArgument(availableBuffer != null, "Must be not null."); + CommonUtils.checkState(!isReleased, "Buffer queue has been released."); + + buffers.add(availableBuffer); + } + + /** + * Adds a collection of available buffers to this buffer queue and throws an exception if this + * buffer queue has been released.
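+ * + * <p>A minimal usage sketch; {@code byteBuffers} is an illustrative {@code Collection<ByteBuffer>}: + * + * <pre>{@code + * BufferQueue queue = new BufferQueue(new ArrayList<>()); + * queue.add(byteBuffers); + * ByteBuffer next = queue.poll(); // null once the queue is drained + * List<ByteBuffer> remaining = queue.release(); // hand leftovers back to the pool + * }</pre>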
+ */ + public void add(Collection<ByteBuffer> availableBuffers) { + CommonUtils.checkArgument(availableBuffers != null, "Must be not null."); + CommonUtils.checkState(!isReleased, "Buffer queue has been released."); + + buffers.addAll(availableBuffers); + } + + /** + * Releases this buffer queue and returns all available buffers. After release, no buffer can + * be added to or polled from this buffer queue. + */ + public List<ByteBuffer> release() { + isReleased = true; + + List<ByteBuffer> released = new ArrayList<>(buffers); + buffers.clear(); + return released; + } + + /** Returns true if this buffer queue has been released. */ + public boolean isReleased() { + return isReleased; + } +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/BufferWithBacklog.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/BufferWithBacklog.java new file mode 100644 index 00000000..6b7c5fd3 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/BufferWithBacklog.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.storage; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.memory.Buffer; + +/** Data buffer and the number of remaining data buffers returned to the data consumer. */ +public class BufferWithBacklog { + + /** Data buffer read from the target {@link DataPartition}. */ + private final Buffer buffer; + + /** Number of remaining data buffers already read in the {@link DataPartitionReader}. */ + private final long backlog; + + public BufferWithBacklog(Buffer buffer, long backlog) { + CommonUtils.checkArgument(buffer != null, "Must be not null."); + CommonUtils.checkArgument(backlog >= 0, "Must be non-negative."); + + this.buffer = buffer; + this.backlog = backlog; + } + + public Buffer getBuffer() { + return buffer; + } + + public long getBacklog() { + return backlog; + } +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/DataPartition.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/DataPartition.java new file mode 100644 index 00000000..cd89f60c --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/DataPartition.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.storage; + +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.ids.ReducePartitionID; +import com.alibaba.flink.shuffle.core.listener.BacklogListener; +import com.alibaba.flink.shuffle.core.listener.DataListener; +import com.alibaba.flink.shuffle.core.listener.DataRegionCreditListener; +import com.alibaba.flink.shuffle.core.listener.FailureListener; + +import javax.annotation.Nullable; + +import java.util.concurrent.CompletableFuture; + +/** + * {@link DataPartition} is a collection of partitioned data and is the basic data unit managed by + * remote shuffle workers. One or multiple data producers can write the produced data to a {@link + * DataPartition} and one or multiple data consumers can read data from it. + * + *

<p>In distributed computation, large datasets are divided into smaller data partitions to be + * processed by parallel tasks. Between adjacent computation vertices in a computation graph, the + * producer task produces a collection of data which is divided into multiple partitions and + * consumed by the corresponding consumer task. The well-known map-reduce naming style is adopted: + * the data collection produced by the producer task is called {@link MapPartition} and the data + * collection consumed by the consumer task is called {@link ReducePartition}. Logically, each piece + * of data belongs to both a {@link MapPartition} and a {@link ReducePartition}; physically, it can + * be stored in either way. + */ +public interface DataPartition { + + /** Returns the {@link DataPartitionMeta} of this data partition. */ + DataPartitionMeta getPartitionMeta(); + + /** Returns the {@link DataPartitionType} of this data partition. */ + DataPartitionType getPartitionType(); + + /** + * Creates and returns a {@link DataPartitionWriter} with the target {@link MapPartitionID}. + * This method must release all allocated resources itself if any exception occurs. + * + * @param mapPartitionID ID of the {@link MapPartition} that the data belongs to. + * @param dataRegionCreditListener Listener to be notified when credits become available. + * @param failureListener Listener to be notified when any failure occurs. + * @return The target {@link DataPartitionWriter} used to write data to this data partition. + */ + DataPartitionWriter createPartitionWriter( + MapPartitionID mapPartitionID, + DataRegionCreditListener dataRegionCreditListener, + FailureListener failureListener) + throws Exception; + + /** + * Creates and returns a {@link DataPartitionReader} for the target range of reduce partitions. + * This method must release all allocated resources itself if any exception occurs. + * + * @param startPartitionIndex Index of the first logic {@link ReducePartition} to read. + * @param endPartitionIndex Index of the last logic {@link ReducePartition} to read (inclusive). + * @param dataListener Listener to be notified when any data is available. + * @param backlogListener Listener to be notified when any backlog is available. + * @param failureListener Listener to be notified when any failure occurs. + * @return The target {@link DataPartitionReader} used to read data from this data partition. + */ + DataPartitionReader createPartitionReader( + int startPartitionIndex, + int endPartitionIndex, + DataListener dataListener, + BacklogListener backlogListener, + FailureListener failureListener) + throws Exception; + + /** + * Releases this data partition, which releases all resources including the data. This method + * can be called multiple times so must be reentrant. + * + * @param releaseCause Cause which leads to the release, null means released after consumption. + */ + CompletableFuture<?> releasePartition(@Nullable Throwable releaseCause); + + /** Returns a boolean flag indicating whether this data partition is consumable or not. */ + boolean isConsumable(); + + /** + * Type definition of {@link DataPartition}. All {@link DataPartition}s must be either of type + * {@link #MAP_PARTITION} or type {@link #REDUCE_PARTITION}. + */ + enum DataPartitionType { + + /** + * MAP_PARTITION is a type of {@link DataPartition} in which all data must be of the + * same {@link MapPartitionID} but can have different {@link ReducePartitionID}s. For + * example, data produced by a map task in the map-reduce computation model can be stored + * as a map partition.
+ */ + MAP_PARTITION, + + /** + * REDUCE_PARTITION is a type of {@link DataPartition} in which all data must be of + * the same {@link ReducePartitionID} but can have different {@link MapPartitionID}s. For + * example, data consumed by a reduce task in the map-reduce computation model can be + * stored as a reduce partition. + */ + REDUCE_PARTITION + } +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/DataPartitionFactory.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/DataPartitionFactory.java new file mode 100644 index 00000000..f91fdb29 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/DataPartitionFactory.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.storage; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; + +import java.io.DataInput; + +/** + * Factory to create new {@link DataPartition}s. Different types of {@link DataPartition}s need + * different {@link DataPartitionFactory}s. + */ +public interface DataPartitionFactory { + + /** Initializes this partition factory with the given {@link Configuration}. */ + void initialize(Configuration configuration) throws Exception; + + /** + * Creates and returns a {@link DataPartition} of the corresponding type to write data to and + * read data from. This method should never allocate any resources, to avoid resource leaks in + * case of a worker crash. + * + * @param dataStore {@link PartitionedDataStore} which stores the created {@link DataPartition}. + * @param jobID ID of the job which is trying to write data to the {@link PartitionedDataStore}. + * @param dataSetID ID of the dataset to which the {@link DataPartition} belongs. + * @param dataPartitionID ID of the {@link DataPartition} to be created and written. + * @param numReducePartitions Number of logic {@link ReducePartition}s of the {@link DataSet}. + * @return A new {@link DataPartition} to write data to and read data from. + */ + DataPartition createDataPartition( + PartitionedDataStore dataStore, + JobID jobID, + DataSetID dataSetID, + DataPartitionID dataPartitionID, + int numReducePartitions) + throws Exception; + + /** + * Creates and returns a {@link DataPartition} from the given {@link DataPartitionMeta}. This + * method can be used to recover lost {@link DataPartition}s after failure. + * + * @param dataStore {@link PartitionedDataStore} which stores the created {@link DataPartition}. + * @param partitionMeta {@link DataPartitionMeta} used to construct the {@link DataPartition}.
+ * @return A {@link DataPartition} constructed from the given {@link DataPartitionMeta}. + */ + DataPartition createDataPartition( + PartitionedDataStore dataStore, DataPartitionMeta partitionMeta) throws Exception; + + /** + * Recovers and returns a {@link DataPartitionMeta} instance from the given {@link DataInput}. + * The created {@link DataPartitionMeta} can be used to construct lost {@link DataPartition}s. + */ + DataPartitionMeta recoverDataPartitionMeta(DataInput dataInput) throws Exception; + + /** Returns the data partition type this partition factory is going to create. */ + DataPartition.DataPartitionType getDataPartitionType(); + + static DataPartition.DataPartitionType getDataPartitionType(String dataPartitionFactoryName) + throws ClassNotFoundException, InstantiationException, IllegalAccessException { + return ((DataPartitionFactory) Class.forName(dataPartitionFactoryName).newInstance()) + .getDataPartitionType(); + } +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/DataPartitionMeta.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/DataPartitionMeta.java new file mode 100644 index 00000000..accd7238 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/DataPartitionMeta.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.storage; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; + +import java.io.DataOutput; +import java.io.Serializable; +import java.util.List; + +/** + * Meta information of a {@link DataPartition}. A lost {@link DataPartition} is supposed to be + * reconstructable from the corresponding {@link DataPartitionMeta}.
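+ * + * <p>A recovery round trip might look like this; {@code factory}, {@code dataStore} and the data streams are illustrative: + * + * <pre>{@code + * partitionMeta.writeTo(dataOutput); // persist the meta + * // ... after a worker restart: + * DataPartitionMeta recovered = factory.recoverDataPartitionMeta(dataInput); + * DataPartition partition = factory.createDataPartition(dataStore, recovered); + * }</pre>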
+ */ +public abstract class DataPartitionMeta implements Serializable { + + private static final long serialVersionUID = -5045608362993147450L; + + protected final StorageMeta storageMeta; + + protected final JobID jobID; + + protected final DataSetID dataSetID; + + public DataPartitionMeta(JobID jobID, DataSetID dataSetID, StorageMeta storageMeta) { + CommonUtils.checkArgument(jobID != null, "Must be not null."); + CommonUtils.checkArgument(dataSetID != null, "Must be not null."); + CommonUtils.checkArgument(storageMeta != null, "Must be not null."); + + this.jobID = jobID; + this.dataSetID = dataSetID; + this.storageMeta = storageMeta; + } + + public JobID getJobID() { + return jobID; + } + + public DataSetID getDataSetID() { + return dataSetID; + } + + public abstract DataPartitionID getDataPartitionID(); + + public StorageMeta getStorageMeta() { + return storageMeta; + } + + /** + * Returns the factory class name which can be used to create the corresponding {@link + * DataPartition}. + */ + public abstract String getPartitionFactoryClassName(); + + /** + * Returns all {@link MapPartitionID}s of the data stored in the corresponding {@link + * DataPartition}. + */ + public abstract List<MapPartitionID> getMapPartitionIDs(); + + /** + * Serializes this {@link DataPartitionMeta} to the {@link DataOutput} which can be used to + * reconstruct the lost {@link DataPartition}. + */ + public abstract void writeTo(DataOutput dataOutput) throws Exception; +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/DataPartitionReader.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/DataPartitionReader.java new file mode 100644 index 00000000..3831d472 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/DataPartitionReader.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.storage; + +import com.alibaba.flink.shuffle.core.memory.BufferRecycler; + +/** + * Data reader for {@link DataPartition}. Each {@link DataPartitionReader} can read data from one or + * multiple consecutive {@link ReducePartition}s. + */ +public interface DataPartitionReader extends Comparable<DataPartitionReader> { + + /** + * Opens this partition reader for data reading. A partition reader must be opened first before + * reading any data. + */ + void open() throws Exception; + + /** + * Reads data from the corresponding {@link DataPartition}. It can allocate buffers from the + * given {@link BufferQueue} if any memory is needed and the allocated buffers must be recycled + * by the given {@link BufferRecycler}.
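+ * + * <p>A simplified reading step; {@code buffers} and {@code recycler} are illustrative, and the scheduling around the calls is elided: + * + * <pre>{@code + * reader.open(); // a reader must be opened before the first read + * reader.readData(buffers, recycler); // may allocate from the queue for the data read + * BufferWithBacklog next = reader.nextBuffer(); // data that has been read so far + * }</pre>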
+ */ + boolean readData(BufferQueue buffers, BufferRecycler recycler) throws Exception; + + /** + * Returns a {@link BufferWithBacklog} instance containing one data buffer read and the number + * of remaining data buffers already read in this partition reader. + */ + BufferWithBacklog nextBuffer() throws Exception; + + /** + * Releases this partition reader and all of its resources, for example when any exception + * occurs. + */ + void release(Throwable throwable) throws Exception; + + /** + * Returns true if this partition reader has finished reading the target {@link DataPartition} + * and all the data read has been consumed by the corresponding data consumer. + */ + boolean isFinished(); + + /** + * Returns the data reading priority of this partition reader. A lower value means a higher + * priority and multiple partition readers of the same {@link DataPartition} will read data in + * the order of this priority. + */ + long getPriority(); + + /** + * Notifies the failure to this partition reader when any exception occurs at the corresponding + * data consumer side. + */ + void onError(Throwable throwable); + + /** + * Whether this partition reader has been opened or not. A partition reader must be opened first + * before reading any data. + */ + boolean isOpened(); + + @Override + default int compareTo(DataPartitionReader that) { + return Long.compare(getPriority(), that.getPriority()); + } +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/DataPartitionReadingView.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/DataPartitionReadingView.java new file mode 100644 index 00000000..9801f44a --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/DataPartitionReadingView.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.storage; + +import javax.annotation.Nullable; + +/** + * When trying to read data from a {@link DataPartition}, the data consumer needs to call {@link + * PartitionedDataStore#createDataPartitionReadingView} which will return an instance of this + * interface, then the consumer can read data from the target {@link DataPartition} through the + * {@link #nextBuffer} method. + */ +public interface DataPartitionReadingView { + + /** Reads a buffer from this reading view. Returns null if no buffer is available. */ + @Nullable + BufferWithBacklog nextBuffer() throws Exception; + + /** + * Notifies an error to the {@link DataPartitionReader} if the data consumer encounters any + * unrecoverable failure. + */ + void onError(Throwable throwable); + + /** Returns true if all target data has been consumed successfully.
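+ * + * <p>A typical consumption loop on the consumer side; {@code readingView} is illustrative: + * + * <pre>{@code + * BufferWithBacklog next; + * while ((next = readingView.nextBuffer()) != null) { + * // consume next.getBuffer(); next.getBacklog() hints how many buffers follow + * } + * boolean done = readingView.isFinished(); + * }</pre>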
*/ + boolean isFinished(); +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/DataPartitionWriter.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/DataPartitionWriter.java new file mode 100644 index 00000000..e9fe602c --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/DataPartitionWriter.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.storage; + +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.ids.ReducePartitionID; +import com.alibaba.flink.shuffle.core.listener.DataCommitListener; +import com.alibaba.flink.shuffle.core.memory.Buffer; +import com.alibaba.flink.shuffle.core.memory.BufferRecycler; +import com.alibaba.flink.shuffle.core.memory.BufferSupplier; + +/** + * Data writer for {@link DataPartition}. Each {@link DataPartitionWriter} can write data of a + * single {@link MapPartitionID} to the target {@link DataPartition}, that is, all the data must + * come from the same data producer. + */ +public interface DataPartitionWriter extends BufferSupplier { + + /** + * Returns the corresponding {@link MapPartitionID} of the data to be written by this partition + * writer. + */ + MapPartitionID getMapPartitionID(); + + /** + * Writes the input data to the target {@link DataPartition} and returns true if all data has + * been written successfully. + */ + boolean writeData() throws Exception; + + /** + * Adds a data {@link Buffer} of the given {@link MapPartitionID} and {@link ReducePartitionID} + * to this partition writer. + */ + void addBuffer(ReducePartitionID reducePartitionID, Buffer buffer); + + /** + * Starts a new data region and announces the number of credits required by the data region. + * + * @param dataRegionIndex Index of the new data region to be written. + * @param isBroadcastRegion Whether to broadcast data to all reduce partitions in this region. + */ + void startRegion(int dataRegionIndex, boolean isBroadcastRegion); + + /** + * Finishes the current data region, after which the current data region is completed and ready + * to be processed. + */ + void finishRegion(); + + /** + * Finishes the data input, which means no data can be added to this partition writer any more. + * + * @param commitListener Listener to be notified after all data of the input is committed. + */ + void finishDataInput(DataCommitListener commitListener); + + /** + * Assigns credits to this partition writer to be used to receive data from the corresponding + * data producer. Returns true if this partition writer still needs more credits (buffers) for + * data receiving.
+ */ + boolean assignCredits(BufferQueue credits, BufferRecycler recycler); + + /** + * Notifies the failure to this partition writer when any exception occurs at the corresponding + * data producer side. + */ + void onError(Throwable throwable); + + /** + * Releases this partition writer and all of its resources, for example when any exception + * occurs. + */ + void release(Throwable throwable) throws Exception; +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/DataPartitionWritingView.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/DataPartitionWritingView.java new file mode 100644 index 00000000..cab8049f --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/DataPartitionWritingView.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.storage; + +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.ids.ReducePartitionID; +import com.alibaba.flink.shuffle.core.listener.DataCommitListener; +import com.alibaba.flink.shuffle.core.memory.Buffer; +import com.alibaba.flink.shuffle.core.memory.BufferSupplier; + +import javax.annotation.Nullable; + +/** + * When trying to write data to a {@link DataPartition}, the data producer needs to call {@link + * PartitionedDataStore#createDataPartitionWritingView} which will return an instance of this + * interface, then the producer can write data to the target {@link DataPartition} through the + * {@link #regionStarted}, {@link #onBuffer}, {@link #regionFinished} and {@link #finish} methods. + * + *

<p>A data {@link Buffer} is the smallest piece of data that can be written, and a data region is + * the basic data unit that can be written, which can contain multiple data {@link Buffer}s. Data + * must be divided into one or multiple data regions before writing. Each data region is a piece of + * data that can be consumed independently, which means a data region cannot contain any partial + * records and data compression should never span multiple data regions. As a result, the data + * regions consumed by the same data consumer can be rearranged freely. + * + *
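<p>For example, a region-by-region write through this view might look like this; {@code buffer}, {@code reducePartitionID} and {@code commitListener} are illustrative: + * + * <pre>{@code + * writingView.regionStarted(0, false); // region 0, not a broadcast region + * writingView.onBuffer(buffer, reducePartitionID); + * writingView.regionFinished(); + * writingView.finish(commitListener); // called back once all data is committed + * }</pre> + * + *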

<p>Before writing a data region, the {@link #regionStarted} method must be called, after which + * data {@link Buffer}s can be written through the {@link #onBuffer} method. Then the data region + * can be marked as finished by the {@link #regionFinished} method. After all target data has been + * written, the {@link #finish} method can be called to finish data writing. + */ +public interface DataPartitionWritingView { + + /** + * Writes a {@link Buffer} of the given {@link MapPartitionID} and {@link ReducePartitionID} to + * the corresponding {@link DataPartition} through this partition writing view. + */ + void onBuffer(Buffer buffer, ReducePartitionID reducePartitionID); + + /** + * Marks the starting of a new data region and announces the number of credits required by the + * new data region. + * + * @param dataRegionIndex Index of the new data region to be written. + * @param isBroadcastRegion Whether to broadcast data to all reduce partitions in this region. + */ + void regionStarted(int dataRegionIndex, boolean isBroadcastRegion); + + /** + * Marks the current data region as finished, after which no data of the same region will be + * written any more and the current data region is completed and ready to be processed. + */ + void regionFinished(); + + /** + * Finishes the data input, which means no data will be written through this partition writing + * view any more. + * + * @param commitListener Listener to be notified after all data of the input is committed. + */ + void finish(DataCommitListener commitListener); + + /** + * Notifies an error to the {@link DataPartitionWriter} if the data producer encounters any + * unrecoverable failure. + */ + void onError(@Nullable Throwable throwable); + + /** + * Returns a {@link BufferSupplier} instance from which free {@link Buffer}s can be allocated + * for data receiving of the network stack. + */ + BufferSupplier getBufferSupplier(); +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/DataSet.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/DataSet.java new file mode 100644 index 00000000..83aef92c --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/DataSet.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package com.alibaba.flink.shuffle.core.storage; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.exception.DuplicatedPartitionException; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; + +import javax.annotation.concurrent.NotThreadSafe; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * {@link DataSet} is a collection of {@link DataPartition}s which have the same {@link DataSetID}. + * For example, {@link DataPartition}s produced by different parallel tasks of the same computation + * vertex can have the same {@link DataSetID} and belong to the same {@link DataSet}. + */ +@NotThreadSafe +public class DataSet { + + /** All {@link DataPartition}s in this dataset. */ + private final HashMap<DataPartitionID, DataPartition> dataPartitions; + + /** ID of the job which produced this dataset. */ + private final JobID jobID; + + /** ID of this dataset. */ + private final DataSetID dataSetID; + + public DataSet(JobID jobID, DataSetID dataSetID) { + CommonUtils.checkArgument(jobID != null, "Must be not null."); + CommonUtils.checkArgument(dataSetID != null, "Must be not null."); + + this.jobID = jobID; + this.dataSetID = dataSetID; + this.dataPartitions = new HashMap<>(); + } + + /** Returns the {@link DataPartition} with the given {@link DataPartitionID}. */ + public DataPartition getDataPartition(DataPartitionID partitionID) { + CommonUtils.checkArgument(partitionID != null, "Must be not null."); + + return dataPartitions.get(partitionID); + } + + /** + * Adds the given {@link DataPartition} to this dataset. An exception will be thrown if a + * {@link DataPartition} with the same {@link DataPartitionID} already exists in this dataset. + */ + public void addDataPartition(DataPartition dataPartition) { + CommonUtils.checkArgument(dataPartition != null, "Must be not null."); + + DataPartitionID partitionID = dataPartition.getPartitionMeta().getDataPartitionID(); + if (dataPartitions.containsKey(partitionID)) { + throw new DuplicatedPartitionException(dataPartition.getPartitionMeta().toString()); + } + dataPartitions.put(partitionID, dataPartition); + } + + public JobID getJobID() { + return jobID; + } + + public DataSetID getDataSetID() { + return dataSetID; + } + + /** Returns the number of {@link DataPartition}s in this dataset. */ + public int getNumDataPartitions() { + return dataPartitions.size(); + } + + /** + * Returns true if this dataset contains the target {@link DataPartition} of the given {@link + * DataPartitionID}. + */ + public boolean containsDataPartition(DataPartitionID dataPartitionID) { + return dataPartitions.containsKey(dataPartitionID); + } + + /** + * Removes the {@link DataPartition} with the given {@link DataPartitionID} from this dataset. + */ + public DataPartition removeDataPartition(DataPartitionID partitionID) { + CommonUtils.checkArgument(partitionID != null, "Must be not null."); + + return dataPartitions.remove(partitionID); + } + + /** + * Removes all {@link DataPartition}s in this dataset and returns a list of the cleared {@link + * DataPartition}s. + */ + public List<DataPartition> clearDataPartitions() { + List<DataPartition> partitions = new ArrayList<>(dataPartitions.values()); + dataPartitions.clear(); + return partitions; + } + + /** Returns all {@link DataPartition}s belonging to this dataset.
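+ * + * <p>A small usage sketch; {@code jobID}, {@code dataSetID} and {@code partition} are illustrative: + * + * <pre>{@code + * DataSet dataSet = new DataSet(jobID, dataSetID); + * dataSet.addDataPartition(partition); // throws DuplicatedPartitionException on duplicates + * List<DataPartition> partitions = dataSet.getDataPartitions(); + * }</pre>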
*/ + public ArrayList<DataPartition> getDataPartitions() { + return new ArrayList<>(dataPartitions.values()); + } + + @Override + public String toString() { + return "DataSet{" + + "JobID=" + + jobID + + ", DataSetID=" + + dataSetID + + ", NumDataPartitions=" + + dataPartitions.size() + + '}'; + } + + // --------------------------------------------------------------------------------------------- + // For test + // --------------------------------------------------------------------------------------------- + + public Set<DataPartitionID> getDataPartitionIDs() { + return new HashSet<>(dataPartitions.keySet()); + } +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/MapPartition.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/MapPartition.java new file mode 100644 index 00000000..2e15b9b0 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/MapPartition.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.storage; + +/** Definition of {@link DataPartitionType#MAP_PARTITION}. */ +public interface MapPartition extends DataPartition { + + @Override + MapPartitionMeta getPartitionMeta(); + + @Override + default DataPartitionType getPartitionType() { + return DataPartitionType.MAP_PARTITION; + } +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/MapPartitionMeta.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/MapPartitionMeta.java new file mode 100644 index 00000000..84de0a00 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/MapPartitionMeta.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package com.alibaba.flink.shuffle.core.storage; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; + +import java.util.Collections; +import java.util.List; + +/** {@link DataPartitionMeta} of {@link MapPartition}. */ +public abstract class MapPartitionMeta extends DataPartitionMeta { + + private static final long serialVersionUID = 1476185767576064848L; + + protected final MapPartitionID partitionID; + + public MapPartitionMeta( + JobID jobID, DataSetID dataSetID, MapPartitionID partitionID, StorageMeta storageMeta) { + super(jobID, dataSetID, storageMeta); + + CommonUtils.checkArgument(partitionID != null, "Must be not null."); + this.partitionID = partitionID; + } + + @Override + public MapPartitionID getDataPartitionID() { + return partitionID; + } + + @Override + public List<MapPartitionID> getMapPartitionIDs() { + return Collections.singletonList(getDataPartitionID()); + } +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/PartitionedDataStore.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/PartitionedDataStore.java new file mode 100644 index 00000000..68e8acf1 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/PartitionedDataStore.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.storage; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.core.executor.SingleThreadExecutor; +import com.alibaba.flink.shuffle.core.executor.SingleThreadExecutorPool; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.memory.BufferDispatcher; + +import javax.annotation.Nullable; + +/** + * {@link PartitionedDataStore} is the storage of {@link DataPartition}s. Different types of {@link + * DataPartition}s can be added to and removed from this data store. + */ +public interface PartitionedDataStore { + + /** + * Creates a {@link DataPartitionWritingView} instance as the channel to write data to. This + * method must be called before writing any data to this data store and for each logic {@link + * MapPartition}, a new exclusive writing view will be created and returned. + * + * @return A {@link DataPartitionWritingView} instance as the channel to write data to.
+ */ + DataPartitionWritingView createDataPartitionWritingView(WritingViewContext context) + throws Exception; + + /** + * Creates a {@link DataPartitionReadingView} instance as the channel to read data from. This + * method must be called before reading any data from this data store and for each logic {@link + * ReducePartition} being read, a new exclusive reading view will be created and returned. + * + * @return A {@link DataPartitionReadingView} instance as the channel to read data from. + */ + DataPartitionReadingView createDataPartitionReadingView(ReadingViewContext context) + throws Exception; + + /** Returns a boolean flag indicating whether the target {@link DataPartition} is consumable. */ + boolean isDataPartitionConsumable(DataPartitionMeta partitionMeta); + + /** + * Adds a new {@link DataPartition} to this data store. This happens when adding an external + * {@link DataPartition} or restarting from failure. An exception will be thrown if the target + * partition (identified by the {@link DataPartitionID}, {@link DataSetID} and {@link JobID}) + * already exists in this data store. + */ + void addDataPartition(DataPartitionMeta partitionMeta) throws Exception; + + /** + * Removes the {@link DataPartition} identified by the given {@link DataPartitionMeta} from this + * data store. Different from {@link #releaseDataPartition}, this method does not release the + * corresponding {@link DataPartition}. + */ + void removeDataPartition(DataPartitionMeta partitionMeta); + + /** + * Releases and removes the {@link DataPartition} identified by the given {@link DataSetID} and + * {@link DataPartitionID} from this data store. + * + *

<p>Note: This method works asynchronously and does not release the target partition immediately. + */ + void releaseDataPartition( + DataSetID dataSetID, DataPartitionID partitionID, @Nullable Throwable throwable); + + /** + * Releases all the {@link DataPartition}s belonging to the target {@link DataSet} identified by + * the given {@link DataSetID} from this data store. + * + *

<p>Note: This method works asynchronously and does not release the target partitions immediately. + */ + void releaseDataSet(DataSetID dataSetID, @Nullable Throwable throwable); + + /** + * Releases all the {@link DataPartition}s produced by the corresponding job identified by the + * given {@link JobID} from this data store. + * + *

<p>Note: This method works asynchronously and does not release the target partitions immediately. + */ + void releaseDataByJobID(JobID jobID, @Nullable Throwable throwable); + + /** + * Shuts down this data store and releases the resources. + * + * @param releaseData Whether to also release all data or not. + */ + void shutDown(boolean releaseData); + + /** Returns true if this data store has been shut down. */ + boolean isShutDown(); + + /** Returns the cluster {@link Configuration} to read the configured values. */ + Configuration getConfiguration(); + + /** + * Returns the {@link BufferDispatcher} to allocate {@link java.nio.ByteBuffer}s for data + * writing. + */ + BufferDispatcher getWritingBufferDispatcher(); + + /** + * Returns the {@link BufferDispatcher} to allocate {@link java.nio.ByteBuffer}s for data + * reading. + */ + BufferDispatcher getReadingBufferDispatcher(); + + /** + * Returns the {@link SingleThreadExecutorPool} to allocate {@link SingleThreadExecutor}s for + * {@link DataPartition} processing. + */ + SingleThreadExecutorPool getExecutorPool(StorageMeta storageMeta); +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/ReadingViewContext.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/ReadingViewContext.java new file mode 100644 index 00000000..ca2fa0bc --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/ReadingViewContext.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.storage; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.ReducePartitionID; +import com.alibaba.flink.shuffle.core.listener.BacklogListener; +import com.alibaba.flink.shuffle.core.listener.DataListener; +import com.alibaba.flink.shuffle.core.listener.FailureListener; + +/** Context used to create {@link DataPartitionReadingView}. */ +public class ReadingViewContext { + + /** ID of the {@link DataPartition} to read data from. */ + private final DataPartitionID partitionID; + + /** ID of the {@link DataSet} to which the {@link DataPartition} belongs. */ + private final DataSetID dataSetID; + + /** Index of the first logic {@link ReducePartition} to be read (inclusive). */ + private final int startPartitionIndex; + + /** Index of the last logic {@link ReducePartition} to be read (inclusive). */ + private final int endPartitionIndex; + + /** Listener to be notified when there is any data available for reading. */ + private final DataListener dataListener; + + /** Listener to be notified when there is any backlog available in the reading view.
*/ + private final BacklogListener backlogListener; + + /** Listener to be notified when any internal exception occurs. */ + private final FailureListener failureListener; + + public ReadingViewContext( + DataSetID dataSetID, + DataPartitionID partitionID, + int startPartitionIndex, + int endPartitionIndex, + DataListener dataListener, + BacklogListener backlogListener, + FailureListener failureListener) { + CommonUtils.checkArgument(dataSetID != null, "Must be not null."); + CommonUtils.checkArgument(partitionID != null, "Must be not null."); + CommonUtils.checkArgument(startPartitionIndex >= 0, "Must be non-negative."); + CommonUtils.checkArgument(endPartitionIndex >= startPartitionIndex, "Illegal index range."); + CommonUtils.checkArgument(dataListener != null, "Must be not null."); + CommonUtils.checkArgument(backlogListener != null, "Must be not null."); + CommonUtils.checkArgument(failureListener != null, "Must be not null."); + + if (partitionID.getPartitionType() == DataPartition.DataPartitionType.REDUCE_PARTITION) { + ReducePartitionID reducePartitionID = (ReducePartitionID) partitionID; + CommonUtils.checkArgument( + reducePartitionID.getPartitionIndex() == endPartitionIndex + && reducePartitionID.getPartitionIndex() == startPartitionIndex, + "Illegal reduce partition index range."); + } + + this.partitionID = partitionID; + this.dataSetID = dataSetID; + this.startPartitionIndex = startPartitionIndex; + this.endPartitionIndex = endPartitionIndex; + this.dataListener = dataListener; + this.backlogListener = backlogListener; + this.failureListener = failureListener; + } + + public DataPartitionID getPartitionID() { + return partitionID; + } + + public DataSetID getDataSetID() { + return dataSetID; + } + + public int getStartPartitionIndex() { + return startPartitionIndex; + } + + public int getEndPartitionIndex() { + return endPartitionIndex; + } + + public DataListener getDataListener() { + return dataListener; + } + + public BacklogListener getBacklogListener() { + return backlogListener; + } + + public FailureListener getFailureListener() { + return failureListener; + } +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/ReducePartition.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/ReducePartition.java new file mode 100644 index 00000000..c1c95327 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/ReducePartition.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.storage; + +/** Definition of {@link DataPartitionType#REDUCE_PARTITION}. 
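For a REDUCE_PARTITION, the constructor above forces the index range to collapse to the reduce partition's own index. A minimal construction sketch, assuming the listener interfaces are single-method (lambda-friendly) and that the IDs are obtained elsewhere:

```java
// Sketch: read exactly one logical reduce partition.
ReadingViewContext newReadingContext(DataSetID dataSetID, ReducePartitionID partitionID) {
    int index = partitionID.getPartitionIndex();
    return new ReadingViewContext(
            dataSetID,
            partitionID,
            index, // startPartitionIndex (inclusive)
            index, // endPartitionIndex must equal the start index for REDUCE_PARTITION
            () -> System.out.println("data available"),          // DataListener (assumed shape)
            backlog -> System.out.println("backlog " + backlog), // BacklogListener (assumed shape)
            error -> error.printStackTrace());                   // FailureListener (assumed shape)
}
```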
*/
+public interface ReducePartition extends DataPartition {
+
+    @Override
+    ReducePartitionMeta getPartitionMeta();
+
+    @Override
+    default DataPartitionType getPartitionType() {
+        return DataPartitionType.REDUCE_PARTITION;
+    }
+}
diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/ReducePartitionMeta.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/ReducePartitionMeta.java
new file mode 100644
index 00000000..40d65ffb
--- /dev/null
+++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/ReducePartitionMeta.java
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.core.storage;
+
+import com.alibaba.flink.shuffle.common.utils.CommonUtils;
+import com.alibaba.flink.shuffle.core.ids.DataSetID;
+import com.alibaba.flink.shuffle.core.ids.JobID;
+import com.alibaba.flink.shuffle.core.ids.MapPartitionID;
+import com.alibaba.flink.shuffle.core.ids.ReducePartitionID;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+/** {@link DataPartitionMeta} of {@link ReducePartition}. */
+public abstract class ReducePartitionMeta extends DataPartitionMeta {
+
+    private static final long serialVersionUID = -7617298646112796354L;
+
+    protected final ReducePartitionID partitionID;
+
+    protected final Set<MapPartitionID> mapPartitionIDS = new HashSet<>();
+
+    public ReducePartitionMeta(
+            JobID jobID,
+            DataSetID dataSetID,
+            ReducePartitionID partitionID,
+            StorageMeta storageMeta) {
+        super(jobID, dataSetID, storageMeta);
+
+        CommonUtils.checkArgument(partitionID != null, "Must be not null.");
+        this.partitionID = partitionID;
+    }
+
+    @Override
+    public ReducePartitionID getDataPartitionID() {
+        return partitionID;
+    }
+
+    @Override
+    public List<MapPartitionID> getMapPartitionIDs() {
+        return new ArrayList<>(mapPartitionIDS);
+    }
+}
diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/StorageMeta.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/StorageMeta.java
new file mode 100644
index 00000000..7eb46aa3
--- /dev/null
+++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/StorageMeta.java
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.storage; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.io.Serializable; +import java.util.Objects; + +/** Meta information of the data storage for {@link DataPartition}. */ +public class StorageMeta implements Serializable { + + private static final long serialVersionUID = 7636731224603174535L; + + private final String storagePath; + + private final StorageType storageType; + + public StorageMeta(String storagePath, StorageType storageType) { + CommonUtils.checkArgument(storagePath != null, "Must be not null."); + CommonUtils.checkArgument(storageType != null, "Must be not null."); + + this.storagePath = storagePath; + this.storageType = storageType; + } + + public String getStoragePath() { + return storagePath; + } + + public StorageType getStorageType() { + return storageType; + } + + public void writeTo(DataOutput dataOutput) throws IOException { + dataOutput.writeUTF(storageType.name()); + dataOutput.writeUTF(storagePath); + } + + public static StorageMeta readFrom(DataInput dataInput) throws IOException { + StorageType storageType = StorageType.valueOf(dataInput.readUTF()); + String storagePath = dataInput.readUTF(); + return new StorageMeta(storagePath, storageType); + } + + @Override + public boolean equals(Object that) { + if (this == that) { + return true; + } + + if (that == null || getClass() != that.getClass()) { + return false; + } + + StorageMeta thatMeta = (StorageMeta) that; + return Objects.equals(storagePath, thatMeta.storagePath) + && storageType == thatMeta.storageType; + } + + @Override + public int hashCode() { + return Objects.hash(storagePath, storageType); + } + + @Override + public String toString() { + return "StorageMeta{" + + "StoragePath='" + + storagePath + + ", StorageType=" + + storageType + + '}'; + } +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/StorageType.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/StorageType.java new file mode 100644 index 00000000..87d513e7 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/StorageType.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
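`writeTo`/`readFrom` above give `StorageMeta` a hand-rolled binary form (type name first, then path). A small round-trip sketch using standard `java.io` streams; the path and type values are illustrative only:

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

static void roundTripStorageMeta() throws IOException {
    StorageMeta meta = new StorageMeta("/data/shuffle1", StorageType.SSD);

    // Serialize: writeTo emits the enum name, then the path, as UTF strings.
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    meta.writeTo(new DataOutputStream(bytes));

    // Deserialize and compare; equals() is defined on path + type, as shown above.
    StorageMeta copy =
            StorageMeta.readFrom(
                    new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));
    assert meta.equals(copy) && meta.hashCode() == copy.hashCode();
}
```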
+ */
+
+package com.alibaba.flink.shuffle.core.storage;
+
+/** Definition of storage type for {@link DataPartition}. */
+public enum StorageType {
+    SSD,
+    HDD,
+    MEMORY
+}
diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/WritingViewContext.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/WritingViewContext.java
new file mode 100644
index 00000000..c5dca80c
--- /dev/null
+++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/storage/WritingViewContext.java
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.core.storage;
+
+import com.alibaba.flink.shuffle.common.utils.CommonUtils;
+import com.alibaba.flink.shuffle.core.ids.DataPartitionID;
+import com.alibaba.flink.shuffle.core.ids.DataSetID;
+import com.alibaba.flink.shuffle.core.ids.JobID;
+import com.alibaba.flink.shuffle.core.ids.MapPartitionID;
+import com.alibaba.flink.shuffle.core.listener.DataRegionCreditListener;
+import com.alibaba.flink.shuffle.core.listener.FailureListener;
+
+/** Context used to create {@link DataPartitionWritingView}. */
+public class WritingViewContext {
+
+    /** ID of the job which is trying to write data. */
+    private final JobID jobID;
+
+    /** ID of the {@link DataSet} that the written data belongs to. */
+    private final DataSetID dataSetID;
+
+    /** ID of the {@link DataPartition} that the written data belongs to. */
+    private final DataPartitionID dataPartitionID;
+
+    /** ID of the logic {@link MapPartition} that the written data belongs to. */
+    private final MapPartitionID mapPartitionID;
+
+    /** Number of the {@link ReducePartition}s of the whole {@link DataSet}. */
+    private final int numReducePartitions;
+
+    /** Factory class name used to create the target {@link DataPartition}. */
+    private final String partitionFactoryClassName;
+
+    /** Listener to be notified when there are any new credits available. */
+    private final DataRegionCreditListener dataRegionCreditListener;
+
+    /** Listener to be notified when any internal exception occurs.
*/ + private final FailureListener failureListener; + + public WritingViewContext( + JobID jobID, + DataSetID dataSetID, + DataPartitionID dataPartitionID, + MapPartitionID mapPartitionID, + int numReducePartitions, + String partitionFactoryClassName, + DataRegionCreditListener dataRegionCreditListener, + FailureListener failureListener) { + CommonUtils.checkArgument(jobID != null, "Must be not null."); + CommonUtils.checkArgument(dataSetID != null, "Must be not null."); + CommonUtils.checkArgument(dataPartitionID != null, "Must be not null."); + CommonUtils.checkArgument(mapPartitionID != null, "Must be not null."); + CommonUtils.checkArgument(numReducePartitions > 0, "Must be positive."); + CommonUtils.checkArgument(partitionFactoryClassName != null, "Must be not null."); + CommonUtils.checkArgument(dataRegionCreditListener != null, "Must be not null."); + CommonUtils.checkArgument(failureListener != null, "Must be not null."); + + this.jobID = jobID; + this.dataSetID = dataSetID; + this.dataPartitionID = dataPartitionID; + this.mapPartitionID = mapPartitionID; + this.numReducePartitions = numReducePartitions; + this.partitionFactoryClassName = partitionFactoryClassName; + this.dataRegionCreditListener = dataRegionCreditListener; + this.failureListener = failureListener; + } + + public JobID getJobID() { + return jobID; + } + + public DataSetID getDataSetID() { + return dataSetID; + } + + public DataPartitionID getDataPartitionID() { + return dataPartitionID; + } + + public MapPartitionID getMapPartitionID() { + return mapPartitionID; + } + + public int getNumReducePartitions() { + return numReducePartitions; + } + + public String getPartitionFactoryClassName() { + return partitionFactoryClassName; + } + + public DataRegionCreditListener getDataRegionCreditListener() { + return dataRegionCreditListener; + } + + public FailureListener getFailureListener() { + return failureListener; + } +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/utils/BashJavaUtils.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/utils/BashJavaUtils.java new file mode 100644 index 00000000..377c2dfe --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/utils/BashJavaUtils.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package com.alibaba.flink.shuffle.core.utils;
+
+import com.alibaba.flink.shuffle.common.config.Configuration;
+import com.alibaba.flink.shuffle.core.config.ManagerOptions;
+import com.alibaba.flink.shuffle.core.config.WorkerOptions;
+import com.alibaba.flink.shuffle.core.config.memory.ShuffleManagerProcessSpec;
+import com.alibaba.flink.shuffle.core.config.memory.ShuffleWorkerProcessSpec;
+import com.alibaba.flink.shuffle.core.config.memory.util.ProcessMemoryUtils;
+
+import org.apache.commons.cli.Options;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument;
+import static com.alibaba.flink.shuffle.core.utils.ConfigurationParserUtils.CONFIG_DIR_OPTION;
+import static com.alibaba.flink.shuffle.core.utils.ConfigurationParserUtils.DYNAMIC_PROPERTY_OPTION;
+
+/** Utility class for using Java utilities in bash scripts. */
+public class BashJavaUtils {
+    private static final Logger LOG = LoggerFactory.getLogger(BashJavaUtils.class);
+
+    public static final String EXECUTION_PREFIX = "BASH_JAVA_UTILS_EXEC_RESULT:";
+
+    private BashJavaUtils() {}
+
+    public static void main(String[] args) throws Exception {
+        checkArgument(args.length > 0, "Command not specified.");
+
+        Command command = Command.valueOf(args[0]);
+        String[] commandArgs = Arrays.copyOfRange(args, 1, args.length);
+        String outputLine = runCommand(command, commandArgs);
+        System.out.println(EXECUTION_PREFIX + outputLine);
+    }
+
+    private static String runCommand(Command command, String[] commandArgs) throws Exception {
+        Configuration configuration =
+                ConfigurationParserUtils.loadConfiguration(filterCmdArgs(commandArgs));
+        switch (command) {
+            case GET_SHUFFLE_WORKER_JVM_PARAMS:
+                return getShuffleWorkerJvmParams(configuration);
+            case GET_SHUFFLE_MANAGER_JVM_PARAMS:
+                return getShuffleManagerJvmParams(configuration);
+            default:
+                // unexpected, Command#valueOf should fail if an unknown command is passed in
+                throw new RuntimeException("Unexpected, something is wrong.");
+        }
+    }
+
+    private static String[] filterCmdArgs(String[] args) {
+        final List<String> filteredArgs = new ArrayList<>();
+        final Iterator<String> iter = Arrays.asList(args).iterator();
+        final Options cmdOptions = getCmdOptions();
+
+        while (iter.hasNext()) {
+            String token = iter.next();
+            if (cmdOptions.hasOption(token)) {
+                filteredArgs.add(token);
+                if (cmdOptions.getOption(token).hasArg() && iter.hasNext()) {
+                    filteredArgs.add(iter.next());
+                }
+            } else if (token.startsWith("-D")) {
+                // "-Dkey=value"
+                filteredArgs.add(token);
+            }
+        }
+
+        return filteredArgs.toArray(new String[0]);
+    }
+
+    private static Options getCmdOptions() {
+        final Options cmdOptions = new Options();
+        cmdOptions.addOption(CONFIG_DIR_OPTION);
+        cmdOptions.addOption(DYNAMIC_PROPERTY_OPTION);
+        return cmdOptions;
+    }
+
+    /** Generate and print JVM parameters of Shuffle Worker resources as one line. */
+    private static String getShuffleWorkerJvmParams(Configuration configuration) {
+
+        ShuffleWorkerProcessSpec shuffleWorkerProcessSpec =
+                new ShuffleWorkerProcessSpec(configuration);
+
+        logShuffleWorkerMemoryConfiguration(shuffleWorkerProcessSpec);
+
+        return ProcessMemoryUtils.generateJvmArgsStr(
+                shuffleWorkerProcessSpec, configuration.getString(WorkerOptions.JVM_OPTIONS));
+    }
+
+    /** Generate and print JVM parameters of Shuffle Manager resources as one line.
*/ + private static String getShuffleManagerJvmParams(Configuration configuration) { + ShuffleManagerProcessSpec shuffleManagerProcessSpec = + new ShuffleManagerProcessSpec(configuration); + + logShuffleManagerMemoryConfiguration(shuffleManagerProcessSpec); + + return ProcessMemoryUtils.generateJvmArgsStr( + shuffleManagerProcessSpec, configuration.getString(ManagerOptions.JVM_OPTIONS)); + } + + private static void logShuffleManagerMemoryConfiguration(ShuffleManagerProcessSpec spec) { + LOG.info("ShuffleManager Memory configuration:"); + LOG.info( + " Total Process Memory: {}", + spec.getTotalProcessMemorySize().toHumanReadableString()); + LOG.info( + " JVM Heap Memory: {}", + spec.getJvmHeapMemorySize().toHumanReadableString()); + LOG.info( + " JVM Direct Memory: {}", + spec.getJvmDirectMemorySize().toHumanReadableString()); + LOG.info( + " JVM Metaspace: {}", + spec.getJvmMetaspaceSize().toHumanReadableString()); + LOG.info( + " JVM Overhead: {}", + spec.getJvmOverheadSize().toHumanReadableString()); + } + + private static void logShuffleWorkerMemoryConfiguration(ShuffleWorkerProcessSpec spec) { + LOG.info("ShuffleWorker Memory configuration:"); + LOG.info( + " Total Process Memory: {}", + spec.getTotalProcessMemorySize().toHumanReadableString()); + LOG.info( + " JVM Heap Memory: {}", + spec.getJvmHeapMemorySize().toHumanReadableString()); + LOG.info( + " Total JVM Direct Memory: {}", + spec.getJvmDirectMemorySize().toHumanReadableString()); + LOG.info( + " Framework: {}", + spec.getFrameworkOffHeap().toHumanReadableString()); + LOG.info( + " Network: {}", + spec.getNetworkOffHeap().toHumanReadableString()); + LOG.info( + " JVM Metaspace: {}", + spec.getJvmMetaspaceSize().toHumanReadableString()); + LOG.info( + " JVM Overhead: {}", + spec.getJvmOverheadSize().toHumanReadableString()); + } + + /** Commands that BashJavaUtils supports. */ + public enum Command { + /** Get JVM parameters of shuffle worker. */ + GET_SHUFFLE_WORKER_JVM_PARAMS, + + /** Get JVM parameters of shuffle manager. */ + GET_SHUFFLE_MANAGER_JVM_PARAMS, + } +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/utils/BufferUtils.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/utils/BufferUtils.java new file mode 100644 index 00000000..c14ffa66 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/utils/BufferUtils.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
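The contract of `BashJavaUtils` is script-facing: the first argument selects the command, `-D` properties override the loaded configuration, and the result is printed on a single line tagged with `EXECUTION_PREFIX` so the calling shell script can filter it out of other output. A hedged sketch of that handoff (the property key is hypothetical, and a resolvable conf directory is assumed):

```java
// Sketch: how a launcher might drive BashJavaUtils programmatically.
public static void main(String[] args) throws Exception {
    BashJavaUtils.main(
            new String[] {
                "GET_SHUFFLE_MANAGER_JVM_PARAMS",
                "-Dexample.memory.key=1024m" // hypothetical dynamic property
            });
    // stdout then carries one line of the form:
    // BASH_JAVA_UTILS_EXEC_RESULT:-Xmx... -Xms... -XX:MaxDirectMemorySize=... ...
}
```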
+ */
+
+package com.alibaba.flink.shuffle.core.utils;
+
+import com.alibaba.flink.shuffle.common.utils.CommonUtils;
+import com.alibaba.flink.shuffle.core.ids.DataPartitionID;
+import com.alibaba.flink.shuffle.core.ids.DataSetID;
+import com.alibaba.flink.shuffle.core.ids.JobID;
+import com.alibaba.flink.shuffle.core.memory.Buffer;
+import com.alibaba.flink.shuffle.core.memory.BufferDispatcher;
+import com.alibaba.flink.shuffle.core.memory.BufferRecycler;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nullable;
+
+import java.nio.ByteBuffer;
+import java.util.Collection;
+
+/** Utility methods to manipulate buffers. */
+public class BufferUtils {
+
+    private static final Logger LOG = LoggerFactory.getLogger(BufferUtils.class);
+
+    /** Recycles the given {@link Buffer} and logs the error if any exception occurs. */
+    public static void recycleBuffer(@Nullable Buffer buffer) {
+        if (buffer == null) {
+            return;
+        }
+
+        try {
+            buffer.release();
+        } catch (Throwable throwable) {
+            LOG.error("Fatal: failed to release the target buffer.", throwable);
+        }
+    }
+
+    /** Recycles the given {@link Buffer}s and logs the error if any exception occurs. */
+    public static void recycleBuffers(@Nullable Collection<Buffer> buffers) {
+        if (buffers == null) {
+            return;
+        }
+
+        for (Buffer buffer : buffers) {
+            recycleBuffer(buffer);
+        }
+        // clear method is not supported by all collections
+        CommonUtils.runQuietly(buffers::clear);
+    }
+
+    /**
+     * Recycles the given {@link ByteBuffer} to the target {@link BufferDispatcher} and logs the
+     * error if any exception occurs.
+     */
+    public static void recycleBuffer(
+            @Nullable ByteBuffer buffer,
+            BufferDispatcher bufferDispatcher,
+            JobID jobID,
+            DataSetID dataSetID,
+            DataPartitionID partitionID) {
+        if (buffer == null) {
+            return;
+        }
+
+        try {
+            CommonUtils.checkArgument(bufferDispatcher != null, "Must be not null.");
+
+            bufferDispatcher.recycleBuffer(buffer, jobID, dataSetID, partitionID);
+        } catch (Throwable throwable) {
+            LOG.error("Fatal: failed to release the target buffer.", throwable);
+        }
+    }
+
+    /**
+     * Recycles the given {@link ByteBuffer}s to the target {@link BufferDispatcher} and logs the
+     * error if any exception occurs.
+     */
+    public static void recycleBuffers(
+            @Nullable Collection<ByteBuffer> buffers,
+            BufferDispatcher bufferDispatcher,
+            JobID jobID,
+            DataSetID dataSetID,
+            DataPartitionID partitionID) {
+        if (buffers == null) {
+            return;
+        }
+
+        for (ByteBuffer buffer : buffers) {
+            recycleBuffer(buffer, bufferDispatcher, jobID, dataSetID, partitionID);
+        }
+        // clear method is not supported by all collections
+        CommonUtils.runQuietly(buffers::clear);
+    }
+
+    /**
+     * Recycles the given {@link ByteBuffer} with the target {@link BufferRecycler} and logs the
+     * error if any exception occurs.
+     */
+    public static void recycleBuffer(@Nullable ByteBuffer buffer, BufferRecycler recycler) {
+        if (buffer == null) {
+            return;
+        }
+
+        try {
+            CommonUtils.checkArgument(recycler != null, "Must be not null.");
+
+            recycler.recycle(buffer);
+        } catch (Throwable throwable) {
+            LOG.error("Fatal: failed to release the target buffer.", throwable);
+        }
+    }
+}
diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/utils/ConfigurationParserUtils.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/utils/ConfigurationParserUtils.java
new file mode 100644
index 00000000..62eec65c
--- /dev/null
+++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/utils/ConfigurationParserUtils.java
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.core.utils;
+
+import com.alibaba.flink.shuffle.common.config.Configuration;
+import com.alibaba.flink.shuffle.common.exception.ShuffleException;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.DefaultParser;
+import org.apache.commons.cli.Option;
+import org.apache.commons.cli.Options;
+import org.apache.commons.cli.ParseException;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Properties;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull;
+
+/** Utility class to load and parse {@link Configuration} from args and config file. */
+public class ConfigurationParserUtils {
+
+    /** The default shuffle config directory name.
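The two options defined below (`-c`/`--configDir` and `-D`) make `loadConfiguration` usable from both scripts and tests. A quick sketch; the directory and property values are placeholders:

```java
// Sketch: point the parser at an explicit conf directory and add one override.
String[] args = {
    "--configDir", "/opt/remote-shuffle/conf", // placeholder path
    "-D", "example.key=value"                  // placeholder property
};
Configuration conf = ConfigurationParserUtils.loadConfiguration(args);
```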
*/ + public static final String DEFAULT_SHUFFLE_CONF_DIR = "conf"; + + public static Configuration loadConfiguration(String[] args) throws IOException { + final DefaultParser parser = new DefaultParser(); + final Options options = new Options(); + options.addOption(DYNAMIC_PROPERTY_OPTION); + options.addOption(CONFIG_DIR_OPTION); + + final CommandLine commandLine; + try { + commandLine = parser.parse(options, args, true); + } catch (ParseException e) { + throw new ShuffleException("Failed to parse the command line arguments.", e); + } + + final Properties dynamicProperties = + commandLine.getOptionProperties(DYNAMIC_PROPERTY_OPTION.getOpt()); + if (commandLine.hasOption(CONFIG_DIR_OPTION.getOpt())) { + return new Configuration( + commandLine.getOptionValue(CONFIG_DIR_OPTION.getOpt()), dynamicProperties); + } else { + return new Configuration( + deriveShuffleConfDirectoryFromLibDirectory(), dynamicProperties); + } + } + + private static String deriveShuffleConfDirectoryFromLibDirectory() { + final String libJar = + ConfigurationParserUtils.class + .getProtectionDomain() + .getCodeSource() + .getLocation() + .getPath(); + + final File libDirectory = checkNotNull(new File(libJar).getParentFile()); + final File homeDirectory = checkNotNull(libDirectory.getParentFile()); + final File confDirectory = new File(homeDirectory, DEFAULT_SHUFFLE_CONF_DIR); + + return confDirectory.getAbsolutePath(); + } + + public static final Option CONFIG_DIR_OPTION = + Option.builder("c") + .longOpt("configDir") + .required(false) + .hasArg(true) + .argName("configuration directory") + .desc( + "Directory which contains the configuration file " + + Configuration.REMOTE_SHUFFLE_CONF_FILENAME + + ".") + .build(); + + public static final Option DYNAMIC_PROPERTY_OPTION = + Option.builder("D") + .argName("property=value") + .numberOfArgs(2) + .valueSeparator('=') + .desc("use value for given property") + .build(); +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/utils/FatalExitExceptionHandler.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/utils/FatalExitExceptionHandler.java new file mode 100644 index 00000000..0ddae0a1 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/utils/FatalExitExceptionHandler.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.utils; + +import com.alibaba.flink.shuffle.common.utils.FatalErrorExitUtils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Handler for uncaught exceptions that will log the exception and kill the process afterwards. + * + *

<p>This guarantees that critical exceptions are not accidentally lost and leave the system
+ * running in an inconsistent state.
+ *
+ * <p>
This class is copied from Apache Flink + * (org.apache.flink.runtime.util.FatalExitExceptionHandler). + */ +public final class FatalExitExceptionHandler implements Thread.UncaughtExceptionHandler { + + private static final Logger LOG = LoggerFactory.getLogger(FatalExitExceptionHandler.class); + + public static final FatalExitExceptionHandler INSTANCE = new FatalExitExceptionHandler(); + public static final int EXIT_CODE = -17; + + @Override + @SuppressWarnings("finally") + public void uncaughtException(Thread t, Throwable e) { + try { + LOG.error( + "FATAL: Thread '{}' produced an uncaught exception. Stopping the process...", + t.getName(), + e); + } finally { + FatalErrorExitUtils.exitProcessIfNeeded(EXIT_CODE); + } + } +} diff --git a/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/utils/ListenerUtils.java b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/utils/ListenerUtils.java new file mode 100644 index 00000000..f74a28a8 --- /dev/null +++ b/shuffle-core/src/main/java/com/alibaba/flink/shuffle/core/utils/ListenerUtils.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.utils; + +import com.alibaba.flink.shuffle.core.listener.BacklogListener; +import com.alibaba.flink.shuffle.core.listener.DataCommitListener; +import com.alibaba.flink.shuffle.core.listener.DataListener; +import com.alibaba.flink.shuffle.core.listener.DataRegionCreditListener; +import com.alibaba.flink.shuffle.core.listener.FailureListener; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +/** Utility methods to manipulate listeners. */ +public class ListenerUtils { + + private static final Logger LOG = LoggerFactory.getLogger(ListenerUtils.class); + + /** + * Notifies the available credits to the given {@link DataRegionCreditListener} and logs errors + * if any exception occurs. + */ + public static void notifyAvailableCredits( + int availableCredits, + int dataRegionIndex, + @Nullable DataRegionCreditListener dataRegionCreditListener) { + if (availableCredits <= 0 || dataRegionCreditListener == null) { + return; + } + + try { + dataRegionCreditListener.notifyCredits(availableCredits, dataRegionIndex); + } catch (Throwable throwable) { + LOG.error("Fatal: failed to notify available credit to listener.", throwable); + } + } + + /** + * Notifies the target {@link DataListener} of available data and logs errors if any exception + * occurs. 
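Looking back at `FatalExitExceptionHandler` for a moment: the usual installation points are the JVM-wide default handler or an individual long-lived service thread, e.g. (thread name illustrative):

```java
// Install as the fallback for all threads that have no dedicated handler.
Thread.setDefaultUncaughtExceptionHandler(FatalExitExceptionHandler.INSTANCE);

// Or attach it to a specific thread.
Thread worker = new Thread(() -> { /* serve shuffle data */ }, "data-partition-executor");
worker.setUncaughtExceptionHandler(FatalExitExceptionHandler.INSTANCE);
worker.start();
```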
+     */
+    public static void notifyAvailableData(@Nullable DataListener listener) {
+        if (listener == null) {
+            return;
+        }
+
+        try {
+            listener.notifyDataAvailable();
+        } catch (Throwable throwable) {
+            LOG.error("Fatal: failed to notify available data to listener.", throwable);
+        }
+    }
+
+    /**
+     * Notifies the target {@link BacklogListener} of available backlog and logs errors if any
+     * exception occurs.
+     */
+    public static void notifyBacklog(@Nullable BacklogListener listener, int backlog) {
+        if (listener == null) {
+            return;
+        }
+
+        try {
+            listener.notifyBacklog(backlog);
+        } catch (Throwable throwable) {
+            LOG.error("Fatal: failed to notify backlog to listener.", throwable);
+        }
+    }
+
+    /**
+     * Notifies the target {@link FailureListener} of the encountered failure and logs errors if
+     * any exception occurs.
+     */
+    public static void notifyFailure(
+            @Nullable FailureListener listener, @Nullable Throwable error) {
+        if (listener == null) {
+            return;
+        }
+
+        try {
+            listener.notifyFailure(error);
+        } catch (Throwable throwable) {
+            LOG.error("Fatal: failed to notify failure to listener.", throwable);
+        }
+    }
+
+    /**
+     * Notifies the target {@link DataCommitListener} that all the data has been committed and logs
+     * errors if any exception occurs.
+     */
+    public static void notifyDataCommitted(@Nullable DataCommitListener listener) {
+        if (listener == null) {
+            return;
+        }
+
+        try {
+            listener.notifyDataCommitted();
+        } catch (Throwable throwable) {
+            LOG.error("Fatal: failed to notify data committed event to listener.", throwable);
+        }
+    }
+}
diff --git a/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/config/memory/ProcessMemoryUtilsTest.java b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/config/memory/ProcessMemoryUtilsTest.java
new file mode 100644
index 00000000..46781456
--- /dev/null
+++ b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/config/memory/ProcessMemoryUtilsTest.java
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.core.config.memory;
+
+import com.alibaba.flink.shuffle.common.config.MemorySize;
+import com.alibaba.flink.shuffle.core.config.memory.util.ProcessMemoryUtils;
+
+import org.junit.Test;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+/** Test for {@link ProcessMemoryUtils}.
*/ +public class ProcessMemoryUtilsTest { + + @Test + public void testGenerateJavaStartCommand() { + ProcessMemorySpec memorySpec = + new TestingProcessMemorySpec( + MemorySize.parse("32mb"), + MemorySize.parse("64mb"), + MemorySize.parse("128mb"), + MemorySize.parse("256mb")); + + String jvmOptions = "-XX:+PrintGCDetails -XX:+PrintGCDateStamps -XX:ParallelGCThreads=4"; + + String command = ProcessMemoryUtils.generateJvmArgsStr(memorySpec, jvmOptions); + + assertThat( + command, + is( + "-Xmx" + + 32 * 1024 * 1024 + + " -Xms" + + 32 * 1024 * 1024 + + " -XX:MaxDirectMemorySize=" + + 64 * 1024 * 1024 + + " -XX:MaxMetaspaceSize=" + + 128 * 1024 * 1024 + + " -XX:+PrintGCDetails -XX:+PrintGCDateStamps -XX:ParallelGCThreads=4")); + } + + /** Simple {@link ProcessMemorySpec} implementation for testing purposes. */ + public static class TestingProcessMemorySpec implements ProcessMemorySpec { + + private final MemorySize heapMemory; + private final MemorySize directMemory; + private final MemorySize metaspace; + private final MemorySize overhead; + + public TestingProcessMemorySpec( + MemorySize heapMemory, + MemorySize directMemory, + MemorySize metaspace, + MemorySize overhead) { + this.directMemory = directMemory; + this.heapMemory = heapMemory; + this.metaspace = metaspace; + this.overhead = overhead; + } + + @Override + public MemorySize getJvmHeapMemorySize() { + return heapMemory; + } + + @Override + public MemorySize getJvmDirectMemorySize() { + return directMemory; + } + + @Override + public MemorySize getJvmMetaspaceSize() { + return metaspace; + } + + @Override + public MemorySize getJvmOverheadSize() { + return overhead; + } + + @Override + public MemorySize getTotalProcessMemorySize() { + return heapMemory.add(directMemory).add(metaspace).add(overhead); + } + } +} diff --git a/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/config/memory/ShuffleManagerProcessSpecTest.java b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/config/memory/ShuffleManagerProcessSpecTest.java new file mode 100644 index 00000000..4fb02291 --- /dev/null +++ b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/config/memory/ShuffleManagerProcessSpecTest.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.config.memory; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.config.MemorySize; +import com.alibaba.flink.shuffle.core.config.ManagerOptions; + +import org.junit.Test; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +/** Test for {@link ShuffleManagerProcessSpec}. 
*/ +public class ShuffleManagerProcessSpecTest { + + @Test + public void testShuffleManagerProcessSpec() { + Configuration memConfig = new Configuration(); + + memConfig.setMemorySize(ManagerOptions.FRAMEWORK_HEAP_MEMORY, MemorySize.parse("256m")); + memConfig.setMemorySize(ManagerOptions.FRAMEWORK_OFF_HEAP_MEMORY, MemorySize.parse("128m")); + memConfig.setMemorySize(ManagerOptions.JVM_METASPACE, MemorySize.parse("32m")); + memConfig.setMemorySize(ManagerOptions.JVM_OVERHEAD, MemorySize.parse("32m")); + + ShuffleManagerProcessSpec processSpec = new ShuffleManagerProcessSpec(memConfig); + + assertThat(processSpec.getJvmHeapMemorySize(), is(MemorySize.parse("256m"))); + assertThat(processSpec.getJvmDirectMemorySize(), is(MemorySize.parse("128m"))); + assertThat(processSpec.getJvmOverheadSize(), is(MemorySize.parse("32m"))); + assertThat(processSpec.getJvmMetaspaceSize(), is(MemorySize.parse("32m"))); + assertThat(processSpec.getTotalProcessMemorySize(), is(MemorySize.parse("448m"))); + } +} diff --git a/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/config/memory/ShuffleWorkerProcessSpecTest.java b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/config/memory/ShuffleWorkerProcessSpecTest.java new file mode 100644 index 00000000..dfadc70a --- /dev/null +++ b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/config/memory/ShuffleWorkerProcessSpecTest.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.config.memory; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.config.MemorySize; +import com.alibaba.flink.shuffle.core.config.MemoryOptions; +import com.alibaba.flink.shuffle.core.config.WorkerOptions; + +import org.junit.Test; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +/** Test for {@link ShuffleWorkerProcessSpec}. 
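One number in the following worker test is easy to miss: no 192m appears in the configuration. The worker's JVM direct memory is derived as the framework off-heap memory plus the two data buffer pools, which the `MemorySize` arithmetic below spells out:

```java
// 128m (framework off-heap) + 32m (data reading) + 32m (data writing) = 192m
MemorySize expectedDirect =
        MemorySize.parse("128m")
                .add(MemorySize.parse("32m"))
                .add(MemorySize.parse("32m")); // equals MemorySize.parse("192m")
```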
*/ +public class ShuffleWorkerProcessSpecTest { + + @Test + public void testShuffleWorkerProcessSpec() { + Configuration memConfig = new Configuration(); + + memConfig.setMemorySize(WorkerOptions.FRAMEWORK_HEAP_MEMORY, MemorySize.parse("256m")); + memConfig.setMemorySize(WorkerOptions.FRAMEWORK_OFF_HEAP_MEMORY, MemorySize.parse("128m")); + memConfig.setMemorySize(WorkerOptions.JVM_METASPACE, MemorySize.parse("32m")); + memConfig.setMemorySize(WorkerOptions.JVM_OVERHEAD, MemorySize.parse("32m")); + memConfig.setMemorySize(MemoryOptions.MEMORY_BUFFER_SIZE, MemorySize.parse("1m")); + memConfig.setMemorySize( + MemoryOptions.MEMORY_SIZE_FOR_DATA_READING, MemorySize.parse("32m")); + memConfig.setMemorySize( + MemoryOptions.MEMORY_SIZE_FOR_DATA_WRITING, MemorySize.parse("32m")); + + ShuffleWorkerProcessSpec processSpec = new ShuffleWorkerProcessSpec(memConfig); + + assertThat(processSpec.getJvmHeapMemorySize(), is(MemorySize.parse("256m"))); + assertThat(processSpec.getJvmDirectMemorySize(), is(MemorySize.parse("192m"))); + assertThat(processSpec.getJvmMetaspaceSize(), is(MemorySize.parse("32m"))); + assertThat(processSpec.getJvmOverheadSize(), is(MemorySize.parse("32m"))); + } +} diff --git a/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/executor/SimpleSingleThreadExecutorPoolTest.java b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/executor/SimpleSingleThreadExecutorPoolTest.java new file mode 100644 index 00000000..7d0ad3f8 --- /dev/null +++ b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/executor/SimpleSingleThreadExecutorPoolTest.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.executor; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; + +import org.junit.Test; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.Assert.assertEquals; + +/** Tests for {@link SimpleSingleThreadExecutorPool}. 
*/
+public class SimpleSingleThreadExecutorPoolTest {
+
+    @Test
+    public void testFairness() throws Exception {
+        int numExecutors = 10;
+        int numExecutorRequests = 10;
+        CountDownLatch latch = new CountDownLatch(numExecutors);
+        SingleThreadExecutorPool executorPool = new SimpleSingleThreadExecutorPool(10, "test");
+        ExecutorService executorService = Executors.newFixedThreadPool(numExecutors);
+
+        try {
+            ConcurrentHashMap<SingleThreadExecutor, AtomicInteger> executorCounters =
+                    new ConcurrentHashMap<>();
+            for (int i = 0; i < numExecutors; ++i) {
+                TestTask testTask =
+                        new TestTask(executorPool, executorCounters, numExecutorRequests, latch);
+                executorService.submit(testTask);
+            }
+            latch.await();
+
+            for (AtomicInteger counter : executorCounters.values()) {
+                assertEquals(numExecutorRequests, counter.get());
+            }
+        } finally {
+            executorService.shutdown();
+            executorPool.destroy();
+        }
+    }
+
+    @Test(expected = ShuffleException.class)
+    public void testDestroy() {
+        SingleThreadExecutorPool executorPool = new SimpleSingleThreadExecutorPool(10, "test");
+        executorPool.destroy();
+        executorPool.getSingleThreadExecutor();
+    }
+
+    private static class TestTask implements Runnable {
+
+        private final SingleThreadExecutorPool executorPool;
+
+        private final ConcurrentHashMap<SingleThreadExecutor, AtomicInteger> executorCounters;
+
+        private final int numExecutorRequests;
+
+        private final CountDownLatch latch;
+
+        private TestTask(
+                SingleThreadExecutorPool executorPool,
+                ConcurrentHashMap<SingleThreadExecutor, AtomicInteger> executorCounters,
+                int numExecutorRequests,
+                CountDownLatch latch) {
+            CommonUtils.checkArgument(executorPool != null, "Must be not null.");
+            CommonUtils.checkArgument(executorCounters != null, "Must be not null.");
+            CommonUtils.checkArgument(numExecutorRequests > 0, "Must be positive.");
+            CommonUtils.checkArgument(latch != null, "Must be not null.");
+
+            this.executorPool = executorPool;
+            this.executorCounters = executorCounters;
+            this.numExecutorRequests = numExecutorRequests;
+            this.latch = latch;
+        }
+
+        @Override
+        public void run() {
+            for (int i = 0; i < numExecutorRequests; ++i) {
+                SingleThreadExecutor executor = executorPool.getSingleThreadExecutor();
+                AtomicInteger counter =
+                        executorCounters.computeIfAbsent(
+                                executor, (ignored) -> new AtomicInteger(0));
+                counter.incrementAndGet();
+            }
+            latch.countDown();
+        }
+    }
+}
diff --git a/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/executor/SingleThreadExecutorTest.java b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/executor/SingleThreadExecutorTest.java
new file mode 100644
index 00000000..6e63f8ff
--- /dev/null
+++ b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/executor/SingleThreadExecutorTest.java
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package com.alibaba.flink.shuffle.core.executor; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; + +import org.junit.Test; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.RejectedExecutionException; + +import static org.junit.Assert.assertEquals; + +/** Tests for {@link SingleThreadExecutor}. */ +public class SingleThreadExecutorTest { + + @Test(timeout = 60000, expected = RejectedExecutionException.class) + public void testShutDown() throws Exception { + SingleThreadExecutor executor = new SingleThreadExecutor("test-thread"); + executor.shutDown(); + executor.getExecutorThread().join(); + executor.execute(() -> {}); + } + + @Test + public void testSingleThreadExecution() throws Exception { + SingleThreadExecutor executor = new SingleThreadExecutor("test-thread"); + + try { + int count = 1000; + CountDownLatch latch = new CountDownLatch(count); + TestTask testTask = new TestTask(latch); + + for (int i = 0; i < count; ++i) { + executor.execute(testTask); + } + + latch.await(); + assertEquals(count, testTask.counter); + } finally { + executor.shutDown(); + } + } + + private static class TestTask implements Runnable { + + private final CountDownLatch latch; + + private int counter; + + private TestTask(CountDownLatch latch) { + CommonUtils.checkArgument(latch != null, "Must be not null."); + this.latch = latch; + } + + @Override + public void run() { + ++counter; + latch.countDown(); + } + } +} diff --git a/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/memory/BufferDispatcherTest.java b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/memory/BufferDispatcherTest.java new file mode 100644 index 00000000..4103b435 --- /dev/null +++ b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/memory/BufferDispatcherTest.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.memory; + +import org.junit.Test; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** Test for {@link BufferDispatcher}. 
*/
+public class BufferDispatcherTest {
+
+    @Test
+    public void testRequestMaxBufferMoreThanPoolSize() throws Exception {
+        BufferDispatcher bufferDispatcher = new BufferDispatcher("BUFFER POOL", 2, 16);
+        Object lock = new Object();
+        AtomicReference<List<ByteBuffer>> res = new AtomicReference<>();
+        bufferDispatcher.requestBuffer(
+                null,
+                null,
+                null,
+                1,
+                3,
+                (buffers, exception) -> {
+                    synchronized (lock) {
+                        res.set(buffers);
+                        lock.notify();
+                    }
+                });
+        synchronized (lock) {
+            if (res.get() == null) {
+                lock.wait();
+            }
+        }
+        assertTrue(res.get() != null);
+        assertEquals(2, res.get().size());
+        res.get().forEach((buffer) -> bufferDispatcher.recycleBuffer(buffer, null, null, null));
+        bufferDispatcher.destroy();
+    }
+
+    @Test
+    public void testRequestInsufficientBuffer() throws Exception {
+        BufferDispatcher bufferDispatcher = new BufferDispatcher("BUFFER POOL", 2, 16);
+        Object lock = new Object();
+        AtomicReference<List<ByteBuffer>> res0 = new AtomicReference<>();
+        bufferDispatcher.requestBuffer(
+                null,
+                null,
+                null,
+                1,
+                1,
+                (buffers, exception) -> {
+                    synchronized (lock) {
+                        res0.set(buffers);
+                        lock.notify();
+                    }
+                });
+        synchronized (lock) {
+            if (res0.get() == null) {
+                lock.wait();
+            }
+        }
+        assertTrue(res0.get() != null);
+        assertEquals(1, res0.get().size());
+
+        AtomicReference<List<ByteBuffer>> res1 = new AtomicReference<>();
+        bufferDispatcher.requestBuffer(
+                null,
+                null,
+                null,
+                2,
+                2,
+                (buffers, exception) -> {
+                    synchronized (lock) {
+                        res1.set(buffers);
+                        lock.notify();
+                    }
+                });
+        synchronized (lock) {
+            if (res1.get() == null) {
+                lock.wait(1000);
+            }
+        }
+        assertTrue(res1.get() == null);
+
+        res0.get().forEach((buffer) -> bufferDispatcher.recycleBuffer(buffer, null, null, null));
+        synchronized (lock) {
+            if (res1.get() == null) {
+                lock.wait();
+            }
+        }
+        assertTrue(res1.get() != null);
+        assertEquals(2, res1.get().size());
+
+        res1.get().forEach((buffer) -> bufferDispatcher.recycleBuffer(buffer, null, null, null));
+        bufferDispatcher.destroy();
+    }
+}
diff --git a/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/memory/BufferTest.java b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/memory/BufferTest.java
new file mode 100644
index 00000000..7a61bd8b
--- /dev/null
+++ b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/memory/BufferTest.java
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.core.memory;
+
+import com.alibaba.flink.shuffle.common.utils.CommonUtils;
+
+import org.junit.Test;
+
+import java.nio.ByteBuffer;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests for {@link Buffer}.
*/ +public class BufferTest { + + @Test + public void testRecycle() { + int numBuffers = 100; + int bufferSize = 1024; + TestBufferRecycler recycler = new TestBufferRecycler(); + for (int i = 0; i < numBuffers; ++i) { + Buffer buffer = createBuffer(bufferSize, 0, recycler); + buffer.release(); + } + assertEquals(numBuffers, recycler.getNumRecycledBuffers()); + } + + @Test + public void testReadableBytes() { + int bufferSize = 1024; + int readableBytes = 512; + Buffer buffer = createBuffer(bufferSize, readableBytes, new TestBufferRecycler()); + assertEquals(readableBytes, buffer.readableBytes()); + } + + private Buffer createBuffer(int bufferSize, int readableBytes, BufferRecycler recycler) { + CommonUtils.checkArgument(bufferSize >= readableBytes); + + ByteBuffer byteBuffer = ByteBuffer.allocateDirect(bufferSize); + return new Buffer(byteBuffer, recycler, readableBytes); + } +} diff --git a/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/memory/ByteBufferPoolTest.java b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/memory/ByteBufferPoolTest.java new file mode 100644 index 00000000..e1a9db60 --- /dev/null +++ b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/memory/ByteBufferPoolTest.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.memory; + +import org.junit.Test; + +import java.nio.ByteBuffer; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +/** Test for {@link ByteBufferPool}. 
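+ *
+ * <p>Usage sketch derived from the assertions below (the argument of {@code requestBlocking}
+ * appears to act as a timeout; both request methods may return null when no buffer is granted):
+ *
+ * <pre>{@code
+ * ByteBufferPool pool = new ByteBufferPool("POOL", 2, 16); // 2 buffers of 16 bytes each
+ * ByteBuffer first = pool.requestBuffer();     // null once the pool is drained
+ * ByteBuffer second = pool.requestBlocking(1); // blocks, null if nothing becomes available
+ * pool.recycle(first);
+ * pool.recycle(second);
+ * pool.destroy();
+ * }</pre>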
 */
+public class ByteBufferPoolTest {
+
+    @Test
+    public void testRequestBuffers() throws Exception {
+        ByteBufferPool bufferPool = new ByteBufferPool("BUFFER POOL", 2, 16);
+        assertEquals(2, bufferPool.numAvailableBuffers());
+
+        ByteBuffer buffer = bufferPool.requestBuffer();
+        assertEquals(16, buffer.capacity());
+        assertEquals(1, bufferPool.numAvailableBuffers());
+
+        buffer = bufferPool.requestBlocking(1);
+        assertTrue(buffer != null);
+        assertEquals(0, bufferPool.numAvailableBuffers());
+        assertTrue(bufferPool.requestBuffer() == null);
+        assertNull(bufferPool.requestBlocking(1));
+    }
+
+    @Test
+    public void testDestroyBufferPool() {
+        ByteBufferPool bufferPool = new ByteBufferPool("BUFFER POOL", 2, 16);
+        assertFalse(bufferPool.isDestroyed());
+        bufferPool.destroy();
+        assertTrue(bufferPool.isDestroyed());
+    }
+
+    @Test
+    public void testRequestAndRecycleBuffers() throws Exception {
+        ByteBufferPool bufferPool = new ByteBufferPool("BUFFER POOL", 2, 16);
+
+        ByteBuffer buffer0 = bufferPool.requestBuffer();
+        assertEquals(1, bufferPool.numAvailableBuffers());
+        ByteBuffer buffer1 = bufferPool.requestBuffer();
+        assertEquals(0, bufferPool.numAvailableBuffers());
+
+        bufferPool.recycle(buffer0);
+        assertEquals(1, bufferPool.numAvailableBuffers());
+        bufferPool.recycle(buffer1);
+        assertEquals(2, bufferPool.numAvailableBuffers());
+
+        buffer0 = bufferPool.requestBuffer();
+        buffer1 = bufferPool.requestBuffer();
+        assertEquals(0, bufferPool.numAvailableBuffers());
+
+        AtomicReference<ByteBuffer> gotBuffer = new AtomicReference<>(null);
+        Object lock = new Object();
+        Thread t =
+                new Thread(
+                        () -> {
+                            try {
+                                synchronized (lock) {
+                                    ByteBuffer buffer = bufferPool.requestBlocking(1);
+                                    assertTrue(buffer != null);
+                                    gotBuffer.set(buffer);
+                                    lock.notify();
+                                }
+                            } catch (Exception e) {
+                                // Ignored: if the blocking request fails, gotBuffer stays null
+                                // and the assertions below fail the test.
+                            }
+                        });
+        t.start();
+        bufferPool.recycle(buffer0);
+        synchronized (lock) {
+            if (gotBuffer.get() == null) {
+                lock.wait();
+            }
+        }
+        t.join(500);
+        assertTrue(gotBuffer.get() != null);
+        bufferPool.recycle(gotBuffer.get());
+        bufferPool.recycle(buffer1);
+        assertEquals(2, bufferPool.numAvailableBuffers());
+        bufferPool.destroy();
+    }
+}
diff --git a/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/memory/TestBufferRecycler.java b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/memory/TestBufferRecycler.java
new file mode 100644
index 00000000..8b4a9dde
--- /dev/null
+++ b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/memory/TestBufferRecycler.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.core.memory;
+
+import java.nio.ByteBuffer;
+
+import static org.junit.Assert.assertEquals;
+
+/** A {@link BufferRecycler} implementation for tests.
*/ +public class TestBufferRecycler implements BufferRecycler { + + private int numRecycledBuffers; + + @Override + public void recycle(ByteBuffer buffer) { + ++numRecycledBuffers; + + assertEquals(buffer.capacity(), buffer.limit()); + assertEquals(0, buffer.position()); + } + + public int getNumRecycledBuffers() { + return numRecycledBuffers; + } +} diff --git a/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/storage/BufferQueueTest.java b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/storage/BufferQueueTest.java new file mode 100644 index 00000000..7134d088 --- /dev/null +++ b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/storage/BufferQueueTest.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.storage; + +import org.junit.Test; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; + +/** Tests for {@link BufferQueue}. 
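+ *
+ * <p>The semantics exercised below, as a sketch ({@code b1} and {@code b2} are assumed to be
+ * previously allocated buffers):
+ *
+ * <pre>{@code
+ * BufferQueue queue = new BufferQueue(new ArrayList<>());
+ * queue.add(ByteBuffer.allocate(1024));         // single buffer
+ * queue.add(Arrays.asList(b1, b2));             // or a whole batch
+ * ByteBuffer next = queue.poll();               // null once the queue is empty
+ * List<ByteBuffer> remaining = queue.release(); // afterwards add() throws IllegalStateException
+ * }</pre>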
 */
+public class BufferQueueTest {
+
+    @Test
+    public void testPollBuffer() {
+        ByteBuffer buffer1 = ByteBuffer.allocate(1024);
+        ByteBuffer buffer2 = ByteBuffer.allocate(1024);
+        BufferQueue bufferQueue = new BufferQueue(Arrays.asList(buffer1, buffer2));
+
+        assertEquals(2, bufferQueue.size());
+        assertEquals(buffer1, bufferQueue.poll());
+
+        assertEquals(1, bufferQueue.size());
+        assertEquals(buffer2, bufferQueue.poll());
+
+        assertEquals(0, bufferQueue.size());
+        assertNull(bufferQueue.poll());
+    }
+
+    @Test
+    public void testAddBuffer() {
+        ByteBuffer buffer1 = ByteBuffer.allocate(1024);
+        ByteBuffer buffer2 = ByteBuffer.allocate(1024);
+        BufferQueue bufferQueue = new BufferQueue(new ArrayList<>());
+
+        assertEquals(0, bufferQueue.size());
+        bufferQueue.add(buffer1);
+
+        assertEquals(1, bufferQueue.size());
+        bufferQueue.add(buffer2);
+
+        assertEquals(2, bufferQueue.size());
+        bufferQueue.add(Arrays.asList(ByteBuffer.allocate(1024), ByteBuffer.allocate(1024)));
+        assertEquals(4, bufferQueue.size());
+    }
+
+    @Test
+    public void testReleaseBufferQueue() {
+        ByteBuffer buffer1 = ByteBuffer.allocate(1024);
+        ByteBuffer buffer2 = ByteBuffer.allocate(1024);
+        BufferQueue bufferQueue = new BufferQueue(Arrays.asList(buffer1, buffer2));
+
+        assertEquals(2, bufferQueue.size());
+        List<ByteBuffer> buffers = bufferQueue.release();
+        assertEquals(buffer1, buffers.get(0));
+        assertEquals(buffer2, buffers.get(1));
+
+        assertEquals(0, bufferQueue.size());
+        try {
+            bufferQueue.add(buffer1);
+        } catch (IllegalStateException exception) {
+            assertNull(bufferQueue.poll());
+            return;
+        }
+
+        fail("IllegalStateException expected.");
+    }
+}
diff --git a/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/storage/DataSetTest.java b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/storage/DataSetTest.java
new file mode 100644
index 00000000..48dbca62
--- /dev/null
+++ b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/storage/DataSetTest.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.core.storage;
+
+import com.alibaba.flink.shuffle.common.exception.ShuffleException;
+import com.alibaba.flink.shuffle.common.utils.CommonUtils;
+import com.alibaba.flink.shuffle.core.ids.DataPartitionID;
+import com.alibaba.flink.shuffle.core.ids.DataSetID;
+import com.alibaba.flink.shuffle.core.ids.JobID;
+import com.alibaba.flink.shuffle.core.ids.MapPartitionID;
+
+import org.junit.Test;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+/** Tests for {@link DataSet}.
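+ *
+ * <p>A sketch of the {@code DataSet} operations covered below ({@code jobID}, {@code dataSetID},
+ * {@code partition} and {@code partitionID} are assumed to be created as in the tests):
+ *
+ * <pre>{@code
+ * DataSet dataSet = new DataSet(jobID, dataSetID);
+ * dataSet.addDataPartition(partition);          // adding the same partition ID twice throws
+ * DataPartition p = dataSet.getDataPartition(partitionID); // null for unknown IDs
+ * dataSet.removeDataPartition(partitionID);     // unknown IDs are silently ignored
+ * List<DataPartition> cleared = dataSet.clearDataPartitions(); // empties the data set
+ * }</pre>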
 */
+public class DataSetTest {
+
+    private static final JobID JOB_ID = new JobID(CommonUtils.randomBytes(16));
+
+    private static final DataSetID DATA_SET_ID = new DataSetID(CommonUtils.randomBytes(16));
+
+    @Test
+    public void testAddDataPartition() {
+        int numDataPartitions = 10;
+        DataSet dataSet = new DataSet(JOB_ID, DATA_SET_ID);
+        Set<DataPartitionID> partitionIDS = addDataPartitions(dataSet, numDataPartitions);
+
+        assertEquals(numDataPartitions, dataSet.getNumDataPartitions());
+        assertEquals(partitionIDS, dataSet.getDataPartitionIDs());
+    }
+
+    @Test(expected = ShuffleException.class)
+    public void testAddExistingDataPartition() {
+        DataSet dataSet = new DataSet(JOB_ID, DATA_SET_ID);
+        DataPartitionID partitionID = new MapPartitionID(CommonUtils.randomBytes(16));
+        DataPartition dataPartition = new NoOpDataPartition(JOB_ID, DATA_SET_ID, partitionID);
+
+        dataSet.addDataPartition(dataPartition);
+        dataSet.addDataPartition(dataPartition);
+    }
+
+    @Test
+    public void testGetDataPartition() {
+        int numDataPartitions = 10;
+        DataSet dataSet = new DataSet(JOB_ID, DATA_SET_ID);
+        Set<DataPartitionID> partitionIDS = addDataPartitions(dataSet, numDataPartitions);
+
+        for (DataPartitionID partitionID : partitionIDS) {
+            assertNotNull(dataSet.getDataPartition(partitionID));
+        }
+        assertNull(dataSet.getDataPartition(new MapPartitionID(CommonUtils.randomBytes(16))));
+    }
+
+    @Test
+    public void testRemoveDataPartition() {
+        int numDataPartitions = 10;
+        DataSet dataSet = new DataSet(JOB_ID, DATA_SET_ID);
+        Set<DataPartitionID> partitionIDS = addDataPartitions(dataSet, numDataPartitions);
+
+        dataSet.removeDataPartition(new MapPartitionID(CommonUtils.randomBytes(16)));
+        assertEquals(numDataPartitions, dataSet.getNumDataPartitions());
+
+        int count = 0;
+        for (DataPartitionID partitionID : partitionIDS) {
+            ++count;
+            dataSet.removeDataPartition(partitionID);
+            assertEquals(numDataPartitions - count, dataSet.getNumDataPartitions());
+        }
+    }
+
+    @Test
+    public void testClearDataPartition() {
+        int numDataPartitions = 10;
+        DataSet dataSet = new DataSet(JOB_ID, DATA_SET_ID);
+        Set<DataPartitionID> partitionIDS = addDataPartitions(dataSet, numDataPartitions);
+
+        List<DataPartition> dataPartitions = dataSet.clearDataPartitions();
+        for (DataPartition dataPartition : dataPartitions) {
+            assertTrue(
+                    partitionIDS.contains(dataPartition.getPartitionMeta().getDataPartitionID()));
+        }
+        assertEquals(0, dataSet.getNumDataPartitions());
+    }
+
+    private Set<DataPartitionID> addDataPartitions(DataSet dataSet, int numDataPartitions) {
+        Set<DataPartitionID> partitionIDS = new HashSet<>();
+        for (int i = 0; i < numDataPartitions; ++i) {
+            DataPartitionID partitionID = new MapPartitionID(CommonUtils.randomBytes(16));
+            partitionIDS.add(partitionID);
+            dataSet.addDataPartition(new NoOpDataPartition(JOB_ID, DATA_SET_ID, partitionID));
+        }
+        return partitionIDS;
+    }
+}
diff --git a/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/storage/NoOpDataPartition.java b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/storage/NoOpDataPartition.java
new file mode 100644
index 00000000..b201f3af
--- /dev/null
+++ b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/storage/NoOpDataPartition.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.storage; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.listener.BacklogListener; +import com.alibaba.flink.shuffle.core.listener.DataListener; +import com.alibaba.flink.shuffle.core.listener.DataRegionCreditListener; +import com.alibaba.flink.shuffle.core.listener.FailureListener; + +import javax.annotation.Nullable; + +import java.util.concurrent.CompletableFuture; + +/** No-op {@link DataPartition} implementation for tests. */ +public class NoOpDataPartition implements DataPartition { + + private final DataPartitionMeta partitionMeta; + + public NoOpDataPartition(JobID jobID, DataSetID dataSetID, DataPartitionID partitionID) { + CommonUtils.checkArgument(jobID != null); + CommonUtils.checkArgument(dataSetID != null); + CommonUtils.checkArgument(partitionID != null); + + this.partitionMeta = + new TestDataPartitionMeta( + jobID, dataSetID, partitionID, new StorageMeta("/tmp", StorageType.SSD)); + } + + @Override + public DataPartitionMeta getPartitionMeta() { + return partitionMeta; + } + + @Override + public DataPartitionType getPartitionType() { + return null; + } + + @Override + public DataPartitionWriter createPartitionWriter( + MapPartitionID mapPartitionID, + DataRegionCreditListener dataRegionCreditListener, + FailureListener failureListener) { + return null; + } + + @Override + public DataPartitionReader createPartitionReader( + int startPartitionIndex, + int endPartitionIndex, + DataListener dataListener, + BacklogListener backlogListener, + FailureListener failureListener) { + return null; + } + + @Override + public CompletableFuture releasePartition(@Nullable Throwable releaseCause) { + return null; + } + + @Override + public boolean isConsumable() { + return false; + } +} diff --git a/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/storage/TestDataPartitionMeta.java b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/storage/TestDataPartitionMeta.java new file mode 100644 index 00000000..67e4c669 --- /dev/null +++ b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/storage/TestDataPartitionMeta.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.core.storage;
+
+import com.alibaba.flink.shuffle.common.utils.CommonUtils;
+import com.alibaba.flink.shuffle.core.ids.DataPartitionID;
+import com.alibaba.flink.shuffle.core.ids.DataSetID;
+import com.alibaba.flink.shuffle.core.ids.JobID;
+import com.alibaba.flink.shuffle.core.ids.MapPartitionID;
+
+import java.io.DataOutput;
+import java.util.List;
+
+/** A {@link DataPartitionMeta} implementation for tests. */
+public class TestDataPartitionMeta extends DataPartitionMeta {
+
+    private final DataPartitionID partitionID;
+
+    public TestDataPartitionMeta(
+            JobID jobID,
+            DataSetID dataSetID,
+            DataPartitionID partitionID,
+            StorageMeta storageMeta) {
+        super(jobID, dataSetID, storageMeta);
+
+        CommonUtils.checkArgument(partitionID != null);
+        this.partitionID = partitionID;
+    }
+
+    @Override
+    public DataPartitionID getDataPartitionID() {
+        return partitionID;
+    }
+
+    @Override
+    public String getPartitionFactoryClassName() {
+        return null;
+    }
+
+    @Override
+    public List<MapPartitionID> getMapPartitionIDs() {
+        return null;
+    }
+
+    @Override
+    public void writeTo(DataOutput dataOutput) {}
+}
diff --git a/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/utils/OneShotLatch.java b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/utils/OneShotLatch.java
new file mode 100644
index 00000000..25ea7c3e
--- /dev/null
+++ b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/utils/OneShotLatch.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.core.utils;
+
+import java.util.Collections;
+import java.util.IdentityHashMap;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+/**
+ * Latch for synchronizing parts of code in tests. Once the latch has fired, subsequent calls to
+ * {@link #await()} will return immediately.
+ *
+ * <p>A part of the code that should only run after other code has finished calls {@link #await()};
+ * the call returns once the other part finishes and calls {@link #trigger()}.
+ */
+public final class OneShotLatch {
+
+    private final Object lock = new Object();
+    private final Set<Thread> waitersSet = Collections.newSetFromMap(new IdentityHashMap<>());
+
+    private volatile boolean triggered;
+
+    /** Fires the latch. Code that is blocked on {@link #await()} will now return. */
+    public void trigger() {
+        synchronized (lock) {
+            triggered = true;
+            lock.notifyAll();
+        }
+    }
+
+    /**
+     * Waits until {@link OneShotLatch#trigger()} is called. Once {@code trigger()} has been called
+     * this call will always return immediately.
+     *
+     * @throws InterruptedException Thrown if the thread is interrupted while waiting.
+     */
+    public void await() throws InterruptedException {
+        synchronized (lock) {
+            while (!triggered) {
+                Thread thread = Thread.currentThread();
+                try {
+                    waitersSet.add(thread);
+                    lock.wait();
+                } finally {
+                    waitersSet.remove(thread);
+                }
+            }
+        }
+    }
+
+    /**
+     * Waits until {@link OneShotLatch#trigger()} is called. Once {@code trigger()} has been called
+     * this call will always return immediately.
+     *
+     * <p>If the latch is not triggered within the given timeout, a {@code TimeoutException} will
+     * be thrown after the timeout.
+     *
+     * <p>A timeout value of zero means infinite timeout and makes this call equivalent to
+     * {@link #await()}.
+     *
+     * @param timeout The value of the timeout, a value of zero indicating infinite timeout.
+     * @param timeUnit The unit of the timeout
+     * @throws InterruptedException Thrown if the thread is interrupted while waiting.
+     * @throws TimeoutException Thrown if the latch is not triggered within the timeout.
+     */
+    public void await(long timeout, TimeUnit timeUnit)
+            throws InterruptedException, TimeoutException {
+        if (timeout < 0) {
+            throw new IllegalArgumentException("time may not be negative");
+        }
+        if (timeUnit == null) {
+            throw new NullPointerException("timeUnit");
+        }
+
+        if (timeout == 0) {
+            await();
+        } else {
+            final long deadline = System.nanoTime() + timeUnit.toNanos(timeout);
+            long millisToWait;
+
+            synchronized (lock) {
+                while (!triggered
+                        && (millisToWait = (deadline - System.nanoTime()) / 1_000_000) > 0) {
+                    lock.wait(millisToWait);
+                }
+
+                if (!triggered) {
+                    throw new TimeoutException();
+                }
+            }
+        }
+    }
+
+    /**
+     * Checks if the latch was triggered.
+     *
+     * @return True, if the latch was triggered, false if not.
+     */
+    public boolean isTriggered() {
+        return triggered;
+    }
+
+    /** Returns the number of threads currently blocked in {@link #await()}. */
+    public int getWaitersCount() {
+        synchronized (lock) {
+            return waitersSet.size();
+        }
+    }
+
+    /** Resets the latch so that {@link #isTriggered()} returns false. */
+    public void reset() {
+        synchronized (lock) {
+            triggered = false;
+        }
+    }
+
+    @Override
+    public String toString() {
+        return "Latch " + (triggered ? "TRIGGERED" : "PENDING");
+    }
+}
diff --git a/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/utils/TestLogger.java b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/utils/TestLogger.java
new file mode 100644
index 00000000..08b9843c
--- /dev/null
+++ b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/utils/TestLogger.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.core.utils;
+
+import org.junit.Rule;
+import org.junit.rules.TestRule;
+import org.junit.rules.TestWatcher;
+import org.junit.runner.Description;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.PrintWriter;
+import java.io.StringWriter;
+
+/**
+ * Adds automatic test name logging. Every test that wants to log which test is currently
+ * executing, and why it failed, simply has to extend this class.
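+ *
+ * <p>For example (sketch; {@code MyComponentTest} is a hypothetical test class):
+ *
+ * <pre>{@code
+ * public class MyComponentTest extends TestLogger {
+ *     @Test
+ *     public void testSomething() {
+ *         // the watchman rule logs the test name and its outcome automatically
+ *     }
+ * }
+ * }</pre>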
+ */ +public class TestLogger { + protected final Logger log = LoggerFactory.getLogger(getClass()); + + static { + TestSignalHandler.register(); + } + + @Rule + public TestRule watchman = + new TestWatcher() { + + @Override + public void starting(Description description) { + log.info( + "\n================================================================================" + + "\nTest {} is running." + + "\n--------------------------------------------------------------------------------", + description); + } + + @Override + public void succeeded(Description description) { + log.info( + "\n--------------------------------------------------------------------------------" + + "\nTest {} successfully run." + + "\n================================================================================", + description); + } + + @Override + public void failed(Throwable e, Description description) { + log.error( + "\n--------------------------------------------------------------------------------" + + "\nTest {} failed with:\n{}" + + "\n================================================================================", + description, + exceptionToString(e)); + } + }; + + @Rule public final TestRule nameProvider = new TestNameProvider(); + + private static String exceptionToString(Throwable t) { + if (t == null) { + return "(null)"; + } + + try { + StringWriter stm = new StringWriter(); + PrintWriter wrt = new PrintWriter(stm); + t.printStackTrace(wrt); + wrt.close(); + return stm.toString(); + } catch (Throwable ignored) { + return t.getClass().getName() + " (error while printing stack trace)"; + } + } +} diff --git a/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/utils/TestNameProvider.java b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/utils/TestNameProvider.java new file mode 100644 index 00000000..1f27b974 --- /dev/null +++ b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/utils/TestNameProvider.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.utils; + +import org.junit.rules.TestRule; +import org.junit.runner.Description; +import org.junit.runners.model.Statement; + +import javax.annotation.Nullable; + +/** + * A rule that provides the current test name per thread. Currently, the test name is available for + * all tests that extend {@link TestLogger}. 
+ */ +public class TestNameProvider implements TestRule { + + private static final ThreadLocal testName = new ThreadLocal<>(); + + @Nullable + public static String getCurrentTestName() { + return testName.get(); + } + + @Override + public Statement apply(Statement base, Description description) { + return new Statement() { + @Override + public void evaluate() throws Throwable { + testName.set(description.getDisplayName()); + try { + base.evaluate(); + } finally { + testName.set(null); + } + } + }; + } +} diff --git a/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/utils/TestSignalHandler.java b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/utils/TestSignalHandler.java new file mode 100644 index 00000000..aea1983d --- /dev/null +++ b/shuffle-core/src/test/java/com/alibaba/flink/shuffle/core/utils/TestSignalHandler.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.core.utils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import sun.misc.Signal; + +/** + * This signal handler / signal logger is based on Apache Hadoop's + * org.apache.hadoop.util.SignalLogger. + */ +public class TestSignalHandler { + + private static final Logger LOG = LoggerFactory.getLogger(TestSignalHandler.class); + + private static boolean registered = false; + + /** Our signal handler. */ + private static class Handler implements sun.misc.SignalHandler { + + private final sun.misc.SignalHandler prevHandler; + + Handler(String name) { + prevHandler = Signal.handle(new Signal(name), this); + } + + /** + * Handle an incoming signal. + * + * @param signal The incoming signal + */ + @Override + public void handle(Signal signal) { + LOG.warn( + "RECEIVED SIGNAL {}: SIG{}. Shutting down as requested.", + signal.getNumber(), + signal.getName()); + prevHandler.handle(signal); + } + } + + /** Register some signal handlers. */ + public static void register() { + synchronized (TestSignalHandler.class) { + if (registered) { + return; + } + registered = true; + + final String[] signals = + System.getProperty("os.name").startsWith("Windows") + ? 
new String[] {"TERM", "INT"} + : new String[] {"TERM", "HUP", "INT"}; + + for (String signalName : signals) { + try { + new TestSignalHandler.Handler(signalName); + } catch (Exception e) { + LOG.info("Error while registering signal handler", e); + } + } + } + } +} diff --git a/shuffle-core/src/test/resources/log4j2-test.properties b/shuffle-core/src/test/resources/log4j2-test.properties new file mode 100644 index 00000000..d7fcb327 --- /dev/null +++ b/shuffle-core/src/test/resources/log4j2-test.properties @@ -0,0 +1,26 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +# Set root logger level to OFF to not flood build logs +# set manually to INFO for debugging purposes +rootLogger.level=OFF +rootLogger.appenderRef.test.ref=TestLogger +appender.testlogger.name=TestLogger +appender.testlogger.type=CONSOLE +appender.testlogger.target=SYSTEM_ERR +appender.testlogger.layout.type=PatternLayout +appender.testlogger.layout.pattern=%-4r [%t] %-5p %c %x - %m%n diff --git a/shuffle-dist/pom.xml b/shuffle-dist/pom.xml new file mode 100644 index 00000000..01fd6f7d --- /dev/null +++ b/shuffle-dist/pom.xml @@ -0,0 +1,307 @@ + + + + + + com.alibaba.flink.shuffle + flink-shuffle-parent + 1.0-SNAPSHOT + + 4.0.0 + + shuffle-dist + + + + + com.alibaba.flink.shuffle + shuffle-common + ${project.version} + + + + com.alibaba.flink.shuffle + shuffle-coordinator + ${project.version} + + + org.apache.flink + * + + + + + + com.alibaba.flink.shuffle + shuffle-core + ${project.version} + + + + com.alibaba.flink.shuffle + shuffle-kubernetes + ${project.version} + + + + com.alibaba.flink.shuffle + shuffle-kubernetes-operator + ${project.version} + + provided + + + + com.alibaba.flink.shuffle + shuffle-storage + ${project.version} + + + + com.alibaba.flink.shuffle + shuffle-transfer + ${project.version} + + + + com.alibaba.flink.shuffle + shuffle-metrics + ${project.version} + + + + com.alibaba.flink.shuffle + shuffle-yarn + ${project.version} + + + + + com.alibaba.flink.shuffle + shuffle-examples + ${project.version} + provided + + + + + + org.apache.flink + flink-shaded-netty + 4.1.49.Final-${flink.shaded.version} + + + + org.apache.flink + flink-shaded-zookeeper-3 + ${zookeeper.version} + + + + + + org.apache.logging.log4j + log4j-slf4j-impl + + + + org.apache.logging.log4j + log4j-api + + + + org.apache.logging.log4j + log4j-core + + + + + org.apache.logging.log4j + log4j-1.2-api + + + + + + + + symlink-build-target + + + unix + + + + + + org.codehaus.mojo + exec-maven-plugin + 1.5.0 + + + remove-build-target-link + clean + + exec + + + rm + + -f + ${project.basedir}/../build-target + + + + + create-build-target-link + 
package + + exec + + + ln + + -sfn + + ${project.basedir}/target/flink-remote-shuffle-${project.version}-bin/flink-remote-shuffle-${project.version} + + ${project.basedir}/../build-target + + + + + + + + + + + + + + org.apache.maven.plugins + maven-enforcer-plugin + + + dependency-convergence + + enforce + + + true + + + + + + + org.apache.maven.plugins + maven-shade-plugin + + + + shade-remote-shuffle + none + + + package + + shade + + + false + false + ${project.artifactId}-${project.version} + + + + * + + log4j.properties + log4j-test.properties + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + flink-rpc-akka.jar + + + + + + + org.apache.logging.log4j:* + + + + + reference.conf + + + + + + javax.ws.rs + + com.alibaba.flink.shuffle.shaded.javax.ws.rs + + + + + + + + + + maven-assembly-plugin + + + bin + package + + single + + + + src/main/assemblies/bin.xml + + flink-remote-shuffle-${project.version}-bin + false + + + + + + + + diff --git a/shuffle-dist/src/main/assemblies/bin.xml b/shuffle-dist/src/main/assemblies/bin.xml new file mode 100644 index 00000000..f851badd --- /dev/null +++ b/shuffle-dist/src/main/assemblies/bin.xml @@ -0,0 +1,113 @@ + + + bin + + dir + + + true + flink-remote-shuffle-${project.version} + + + + lib + false + false + false + true + true + 0644 + + + org.apache.logging.log4j:log4j-api + org.apache.logging.log4j:log4j-core + org.apache.logging.log4j:log4j-slf4j-impl + org.apache.logging.log4j:log4j-1.2-api + + + + + + + + target/shuffle-dist-${project.version}.jar + lib/ + 0644 + + + + + + ../shuffle-kubernetes-operator/target/shuffle-kubernetes-operator-${project.version}.jar + + opt/ + 0644 + + + + + ../shuffle-plugin/target/shuffle-plugin-${project.version}.jar + lib/ + 0644 + + + + + + + src/main/shuffle-bin/bin + bin + 0755 + + + + + src/main/shuffle-bin/conf + conf + 0644 + + + + + src/main/shuffle-bin/ + log + 0644 + + **/* + + + + + ../shuffle-examples/target + examples + 0644 + + *.jar + + + shuffle-examples*.jar + original-*.jar + + + + diff --git a/shuffle-dist/src/main/shuffle-bin/bin/config.sh b/shuffle-dist/src/main/shuffle-bin/bin/config.sh new file mode 100755 index 00000000..53dcb1c6 --- /dev/null +++ b/shuffle-dist/src/main/shuffle-bin/bin/config.sh @@ -0,0 +1,367 @@ +#!/usr/bin/env bash +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +bin=`dirname "$0"` +SHUFFLE_BIN_DIR=`cd "$bin"; pwd` + +# Define shuffle dir. 
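+# (Layout sketch, inferred from the defaults below: lib/, conf/ and log/ live under
+# $SHUFFLE_HOME unless SHUFFLE_LIB_DIR, SHUFFLE_CONF_DIR or SHUFFLE_LOG_DIR are
+# already set in the environment.)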
+SHUFFLE_HOME=`dirname "$SHUFFLE_BIN_DIR"` +if [ -z "$SHUFFLE_LIB_DIR" ]; then SHUFFLE_LIB_DIR=$SHUFFLE_HOME/lib; fi +if [ -z "$SHUFFLE_CONF_DIR" ]; then SHUFFLE_CONF_DIR=$SHUFFLE_HOME/conf; fi +if [ -z "$SHUFFLE_LOG_DIR" ]; then SHUFFLE_LOG_DIR=$SHUFFLE_HOME/log; fi + +# Define HOSTNAME if it is not already set +if [ -z "${HOSTNAME}" ]; then + HOSTNAME=`hostname` +fi + +UNAME=$(uname -s) +if [ "${UNAME:0:6}" == "CYGWIN" ]; then + JAVA_RUN=java +else + if [[ -d "$JAVA_HOME" ]]; then + JAVA_RUN="$JAVA_HOME"/bin/java + else + JAVA_RUN=java + fi +fi + +SHUFFLE_CONF_FILE="remote-shuffle-conf.yaml" +YAML_CONF=${SHUFFLE_CONF_DIR}/${SHUFFLE_CONF_FILE} + +manglePathList() { + UNAME=$(uname -s) + # a path list, for example a java classpath + if [ "${UNAME:0:6}" == "CYGWIN" ]; then + echo `cygpath -wp "$1"` + else + echo $1 + fi +} + +# Looks up a config value by key from a simple YAML-style key-value map. +# $1: key to look up +# $2: default value to return if key does not exist +# $3: config file to read from +readFromConfig() { + local key=$1 + local defaultValue=$2 + local configFile=$3 + + # first extract the value with the given key (1st sed), then trim the result (2nd sed) + # if a key exists multiple times, take the "last" one (tail) + local value=`sed -n "s/^[ ]*${key}[ ]*: \([^#]*\).*$/\1/p" "${configFile}" | sed "s/^ *//;s/ *$//" | tail -n 1` + + [ -z "$value" ] && echo "$defaultValue" || echo "$value" +} + +constructShuffleClassPath() { + local SHUFFLE_DIST + local SHUFFLE_CLASSPATH + + while read -d '' -r jarfile ; do + if [[ "$jarfile" =~ .*/shuffle-dist[^/]*.jar$ ]]; then + SHUFFLE_DIST="$SHUFFLE_DIST":"$jarfile" + elif [[ "$SHUFFLE_CLASSPATH" == "" ]]; then + SHUFFLE_CLASSPATH="$jarfile"; + else + SHUFFLE_CLASSPATH="$SHUFFLE_CLASSPATH":"$jarfile" + fi + done < <(find "$SHUFFLE_LIB_DIR" ! -type d -name '*.jar' -print0 | sort -z) + + if [[ "$SHUFFLE_DIST" == "" ]]; then + # write error message to stderr since stdout is stored as the classpath + (>&2 echo "[ERROR] Shuffle distribution jar not found in $SHUFFLE_LIB_DIR.") + + # exit function with empty classpath to force process failure + exit 1 + fi + + echo "$SHUFFLE_CLASSPATH""$SHUFFLE_DIST" +} + +findShuffleDistJar() { + local SHUFFLE_DIST="`find "$SHUFFLE_LIB_DIR" -name 'shuffle-dist*.jar'`" + + if [[ "$SHUFFLE_DIST" == "" ]]; then + # write error message to stderr since stdout is stored as the classpath + (>&2 echo "[ERROR] Shuffle distribution jar not found in $SHUFFLE_LIB_DIR.") + + # exit function with empty classpath to force process failure + exit 1 + fi + + echo "$SHUFFLE_DIST" +} + +######################################################################################################################## +# DEFAULT CONFIG VALUES: These values will be used when nothing has been specified in conf/remote-shuffle-conf.yaml +# -or- the respective environment variables are not set. 
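+# For example, a (hypothetical) entry in conf/remote-shuffle-conf.yaml read by readFromConfig:
+#   env.ssh.opts: -o StrictHostKeyChecking=no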
+######################################################################################################################## + +DEFAULT_ENV_SSH_OPTS="" # Optional SSH parameters running in cluster mode + +######################################################################################################################## +# CONFIG KEYS: The default values can be overwritten by the following keys in conf/remote-shuffle-conf.yaml +######################################################################################################################## + +KEY_ENV_SSH_OPTS="env.ssh.opts" + +if [ -z "${SHUFFLE_SSH_OPTS}" ]; then + SHUFFLE_SSH_OPTS=$(readFromConfig ${KEY_ENV_SSH_OPTS} "${DEFAULT_ENV_SSH_OPTS}" "${YAML_CONF}") +fi + +extractLoggingOutputs() { + local output="$1" + local EXECUTION_PREFIX="BASH_JAVA_UTILS_EXEC_RESULT:" + + echo "${output}" | grep -v ${EXECUTION_PREFIX} +} + +extractExecutionResults() { + local output="$1" + local expected_lines="$2" + local EXECUTION_PREFIX="BASH_JAVA_UTILS_EXEC_RESULT:" + local execution_results + local num_lines + + execution_results=$(echo "${output}" | grep ${EXECUTION_PREFIX}) + num_lines=$(echo "${execution_results}" | wc -l) + # explicit check for empty result, becuase if execution_results is empty, then wc returns 1 + if [[ -z ${execution_results} ]]; then + echo "[ERROR] The execution result is empty." 1>&2 + exit 1 + fi + if [[ ${num_lines} -ne ${expected_lines} ]]; then + echo "[ERROR] The execution results has unexpected number of lines, expected: ${expected_lines}, actual: ${num_lines}." 1>&2 + echo "[ERROR] An execution result line is expected following the prefix '${EXECUTION_PREFIX}'" 1>&2 + echo "$output" 1>&2 + exit 1 + fi + + echo "${execution_results//${EXECUTION_PREFIX}/}" +} + +runBashJavaUtilsCmd() { + local cmd=$1 + local class_path="$(constructShuffleClassPath)" + local dynamic_args=("${@:2}") + local java_utils_log_conf="/tmp/log4j2-bash-java-utils.properties" + + # log config for bash java utils. + cat > ${java_utils_log_conf} << EOF + rootLogger.level = INFO + rootLogger.appenderRef.console.ref = ConsoleAppender + appender.console.name = ConsoleAppender + appender.console.type = CONSOLE + appender.console.layout.type = PatternLayout + appender.console.layout.pattern = %-5p %x - %m%n +EOF + + local log_setting=("-Dlog4j.configurationFile=file:${java_utils_log_conf}") + local output=`"${JAVA_RUN}" "${log_setting}" -classpath "${class_path}" com.alibaba.flink.shuffle.core.utils.BashJavaUtils ${cmd} "${dynamic_args[@]}" 2>&1 | tail -n 1000` + if [[ $? -ne 0 ]]; then + echo "[ERROR] Cannot run BashJavaUtils to execute command ${cmd}." 1>&2 + # Print the output in case the user redirect the log to console. + echo "$output" 1>&2 + exit 1 + fi + + echo "$output" +} + +parseJvmArgsAndExportLogs() { + args=("${@}") + java_utils_output=$(runBashJavaUtilsCmd "${args[@]}") + logging_output=$(extractLoggingOutputs "${java_utils_output}") + params_output=$(extractExecutionResults "${java_utils_output}" 1) + + if [[ $? -ne 0 ]]; then + echo "[ERROR] Could not get JVM parameters properly." 
+ echo "[ERROR] Raw output from BashJavaUtils:" + echo "$java_utils_output" + exit 1 + fi + + export JVM_ARGS="$(echo "${params_output}" | head -n1)" + export SHUFFLE_INHERITED_LOGS=" +JVM params and resource extraction logs: +jvm_opts: $JVM_ARGS +logs: $logging_output +" +} + +parseShuffleManagerJvmArgsAndExportLogs() { + args=("${@}") + parseJvmArgsAndExportLogs GET_SHUFFLE_MANAGER_JVM_PARAMS "${args[@]}" +} + +parseShuffleWorkerJvmArgsAndExportLogs() { + args=("${@}") + parseJvmArgsAndExportLogs GET_SHUFFLE_WORKER_JVM_PARAMS "${args[@]}" +} + +extractHostName() { + # handle comments: extract first part of string (before first # character) + WORKER=`echo $1 | cut -d'#' -f 1` + + # Extract the hostname from the network hierarchy + if [[ "$WORKER" =~ ^.*/([0-9a-zA-Z.-]+)$ ]]; then + WORKER=${BASH_REMATCH[1]} + fi + + echo $WORKER +} + +readManagers() { + MANAGERS_FILE="${SHUFFLE_CONF_DIR}/managers" + + if [[ ! -f "${MANAGERS_FILE}" ]]; then + echo "No managers file. Please specify managers in 'conf/managers'." + exit 1 + fi + + MANAGERS=() + + MANAGERS_ALL_LOCALHOST=true + GOON=true + while $GOON; do + read line || GOON=false + HOST=$( extractHostName $line) + if [ -n "$HOST" ] ; then + if [ "${HOST}" == "localhost" ] ; then + HOST="127.0.0.1" + fi + MANAGERS+=(${HOST}) + if [ "${HOST}" != "127.0.0.1" ] ; then + MANAGERS_ALL_LOCALHOST=false + fi + fi + done < "$MANAGERS_FILE" + + # when supports multiple shuffle managers, only need to remove the code below. + if [[ "${#MANAGERS[@]}" -gt 1 ]]; then + echo "at present, only one shuffle manager can be started, using the first ip address in managers file." + MANAGERS=(${MANAGERS[0]}) + fi +} + +readWorkers() { + WORKERS_FILE="${SHUFFLE_CONF_DIR}/workers" + + if [[ ! -f "$WORKERS_FILE" ]]; then + echo "No workers file. Please specify workers in 'conf/workers'." + exit 1 + fi + + WORKERS=() + + WORKERS_ALL_LOCALHOST=true + GOON=true + while $GOON; do + read line || GOON=false + HOST=$( extractHostName $line) + if [ -n "$HOST" ] ; then + WORKERS+=(${HOST}) + if [ "${HOST}" != "localhost" ] && [ "${HOST}" != "127.0.0.1" ] ; then + WORKERS_ALL_LOCALHOST=false + fi + fi + done < "$WORKERS_FILE" +} + +# starts or stops shuffleManagers on specified IP address. +# note that this is for HA in the future. At present, only one shuffle manager can be started. +# usage: ShuffleManagers start|stop +ShuffleManagers() { + CMD=$1 + + readManagers + MANAGERS_ARGS=("${@:2}") + MANAGERS_ARGS+=(-D remote-shuffle.manager.rpc-address=${MANAGERS[@]}) + + if [ ${MANAGERS_ALL_LOCALHOST} = true ] ; then + # all-local setup + for manager in ${MANAGERS[@]}; do + "${SHUFFLE_BIN_DIR}"/shufflemanager.sh "${CMD}" "${MANAGERS_ARGS[@]}" + done + else + # non-local setup + # start/stop shuffleManager instance(s) using pdsh (Parallel Distributed Shell) when available + TMP_ARGS=() + for marg in ${MANAGERS_ARGS[@]}; do + if [ $marg = "-D" ] ; then + TMP_ARGS+=(-D) + else + TMP_ARGS+=(\"$marg\") + fi + done + SSH_MANAGER_ARGS=${TMP_ARGS[@]} + + command -v pdsh >/dev/null 2>&1 + if [[ $? 
-ne 0 ]]; then + for manager in ${MANAGERS[@]}; do + ssh -n $SHUFFLE_SSH_OPTS $manager -- "nohup /bin/bash -l \"${SHUFFLE_BIN_DIR}/shufflemanager.sh\" \"${CMD}\" $SSH_MANAGER_ARGS &" + done + else + PDSH_SSH_ARGS="" PDSH_SSH_ARGS_APPEND=$SHUFFLE_SSH_OPTS pdsh -w $(IFS=, ; echo "${MANAGERS[*]}") \ + "nohup /bin/bash -l \"${SHUFFLE_BIN_DIR}/shufflemanager.sh\" \"${CMD}\" $SSH_MANAGER_ARGS " + fi + fi +} + +# starts or stops shuffleWorkers on all workers +# usage: ShuffleWorkers start|stop +ShuffleWorkers() { + CMD=$1 + + WORKER_ARGS=("${@:2}") + WORKER_ARGS+=(-D remote-shuffle.manager.rpc-address="${MANAGERS[@]}") + + readWorkers + + if [ ${WORKERS_ALL_LOCALHOST} = true ] ; then + # all-local setup + for worker in ${WORKERS[@]}; do + "${SHUFFLE_BIN_DIR}"/shuffleworker.sh "${CMD}" "${WORKER_ARGS[@]}" + done + else + # non-local setup + # start/stop shuffleWorker instance(s) using pdsh (Parallel Distributed Shell) when available + TMP_ARGS=() + for warg in ${WORKER_ARGS[@]}; do + if [ $warg = "-D" ] ; then + TMP_ARGS+=(-D) + else + TMP_ARGS+=(\"$warg\") + fi + done + SSH_WORKER_ARGS=${TMP_ARGS[@]} + command -v pdsh >/dev/null 2>&1 + if [[ $? -ne 0 ]]; then + for worker in ${WORKERS[@]}; do + ssh -n $SHUFFLE_SSH_OPTS $worker -- "nohup /bin/bash -l \"${SHUFFLE_BIN_DIR}/shuffleworker.sh\" \"${CMD}\" $SSH_WORKER_ARGS &" + done + else + PDSH_SSH_ARGS="" PDSH_SSH_ARGS_APPEND=$SHUFFLE_SSH_OPTS pdsh -w $(IFS=, ; echo "${WORKERS[*]}") \ + "nohup /bin/bash -l \"${SHUFFLE_BIN_DIR}/shuffleworker.sh\" \"${CMD}\" $SSH_WORKER_ARGS " + fi + fi +} + diff --git a/shuffle-dist/src/main/shuffle-bin/bin/kubernetes-shufflemanager.sh b/shuffle-dist/src/main/shuffle-bin/bin/kubernetes-shufflemanager.sh new file mode 100755 index 00000000..557652d6 --- /dev/null +++ b/shuffle-dist/src/main/shuffle-bin/bin/kubernetes-shufflemanager.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +# Start a ShuffleManager on Kubernetes. + +USAGE="Usage: kubernetes-shufflemanager.sh [args]" + +ARGS=("${@:1}") + +bin=`dirname "$0"` +bin=`cd "$bin"; pwd` + +. 
"$bin"/config.sh + +export LOG_PREFIX=${SHUFFLE_LOG_DIR}"/shufflemanager" + +parseShuffleManagerJvmArgsAndExportLogs "${ARGS[@]}" + +exec "${SHUFFLE_BIN_DIR}"/shuffle-console.sh "kubernetes-shufflemanager" "${ARGS[@]}" diff --git a/shuffle-dist/src/main/shuffle-bin/bin/kubernetes-shuffleworker.sh b/shuffle-dist/src/main/shuffle-bin/bin/kubernetes-shuffleworker.sh new file mode 100755 index 00000000..fdf099db --- /dev/null +++ b/shuffle-dist/src/main/shuffle-bin/bin/kubernetes-shuffleworker.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +# Start a ShuffleWorker on Kubernetes. + +USAGE="Usage: kubernetes-shuffleworker.sh [args]" + +ARGS=("${@:1}") + +bin=`dirname "$0"` +bin=`cd "$bin"; pwd` + +. "$bin"/config.sh + +export LOG_PREFIX=${SHUFFLE_LOG_DIR}"/shuffleworker" + +parseShuffleWorkerJvmArgsAndExportLogs "${ARGS[@]}" + +exec "${SHUFFLE_BIN_DIR}"/shuffle-console.sh "kubernetes-shuffleworker" "${ARGS[@]}" diff --git a/shuffle-dist/src/main/shuffle-bin/bin/shuffle-console.sh b/shuffle-dist/src/main/shuffle-bin/bin/shuffle-console.sh new file mode 100755 index 00000000..b1b0301b --- /dev/null +++ b/shuffle-dist/src/main/shuffle-bin/bin/shuffle-console.sh @@ -0,0 +1,94 @@ +#!/usr/bin/env bash +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +# Start a shuffle component as a console application. Must be stopped with Ctrl-C +# or with SIGTERM by kill or the controlling process. +USAGE="Usage: shuffle-console.sh (kubernetes-shufflemanager|kubernetes-shuffleworker|shufflemanager|shuffleworker) [args]" + +SERVICE=$1 +ARGS=("${@:2}") # get remaining arguments as array + +bin=`dirname "$0"` +bin=`cd "$bin"; pwd` + +. 
"$bin"/config.sh + +case $SERVICE in + (kubernetes-shufflemanager) + CLASS_TO_RUN=com.alibaba.flink.shuffle.kubernetes.manager.KubernetesShuffleManagerRunner + ;; + + (kubernetes-shuffleworker) + CLASS_TO_RUN=com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerRunner + ;; + + (shufflemanager) + CLASS_TO_RUN=com.alibaba.flink.shuffle.coordinator.manager.ShuffleManagerRunner + ;; + + (shuffleworker) + CLASS_TO_RUN=com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerRunner + ;; + + (*) + echo "Unknown service '${SERVICE}'. $USAGE." + exit 1 + ;; +esac + +SHUFFLE_CLASSPATH=`constructShuffleClassPath` + +SHUFFLE_PID_DIR="/tmp" +pid=$SHUFFLE_PID_DIR/$SERVICE.pid +# The lock needs to be released after use because this script is started foreground +command -v flock >/dev/null 2>&1 +flock_exist=$? +if [[ ${flock_exist} -eq 0 ]]; then + exec 200<"$SHUFFLE_PID_DIR" + flock 200 +fi +# Remove the pid file when all the processes are dead +if [ -f "$pid" ]; then + all_dead=0 + while read each_pid; do + # Check whether the process is still running + kill -0 $each_pid > /dev/null 2>&1 + [[ $? -eq 0 ]] && all_dead=1 + done < "$pid" + [ ${all_dead} -eq 0 ] && rm $pid +fi +id=$([ -f "$pid" ] && echo $(wc -l < "$pid") || echo "0") + +if [ -z "${LOG_PREFIX}" ]; then + LOG_PREFIX="${SHUFFLE_LOG_DIR}/${SERVICE}-${id}-${HOSTNAME}" +fi + +log="${LOG_PREFIX}.log" + +log_setting=("-Dlog.file=${log}" "-Dlog4j.configurationFile=file:${SHUFFLE_CONF_DIR}/log4j2.properties") + +echo "Starting $SERVICE as a console application on host $HOSTNAME." + +# Add the current process id to pid file +echo $$ >> "$pid" 2>/dev/null + +# Release the lock because the java process runs in the foreground and would block other processes from modifying the pid file +[[ ${flock_exist} -eq 0 ]] && flock -u 200 + +exec "$JAVA_RUN" "${log_setting[@]}" $JVM_ARGS -classpath "`manglePathList "$SHUFFLE_CLASSPATH"`" ${CLASS_TO_RUN} "${ARGS[@]}" diff --git a/shuffle-dist/src/main/shuffle-bin/bin/shuffle-daemon.sh b/shuffle-dist/src/main/shuffle-bin/bin/shuffle-daemon.sh new file mode 100755 index 00000000..49d06079 --- /dev/null +++ b/shuffle-dist/src/main/shuffle-bin/bin/shuffle-daemon.sh @@ -0,0 +1,176 @@ +#!/usr/bin/env bash +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +# Start/stop a rss daemon. +USAGE="Usage: shuffle-daemon.sh (start|stop|stop-all) (shufflemanager|shuffleworker) [args]" + +STARTSTOP=$1 +DAEMON=$2 +ARGS=("${@:3}") # get remaining arguments as array +bin=`dirname "$0"` +bin=`cd "$bin"; pwd` + +. 
"$bin"/config.sh
+
+case $DAEMON in
+    (shufflemanager)
+        CLASS_TO_RUN=com.alibaba.flink.shuffle.coordinator.manager.ShuffleManagerRunner
+    ;;
+
+    (shuffleworker)
+        CLASS_TO_RUN=com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerRunner
+    ;;
+
+    (*)
+        echo "Unknown daemon '${DAEMON}'. $USAGE."
+        exit 1
+    ;;
+esac
+
+SHUFFLE_CLASSPATH=`constructShuffleClassPath`
+
+SHUFFLE_PID_DIR="/tmp"
+pid=$SHUFFLE_PID_DIR/$DAEMON.pid
+
+# Log files for daemons are indexed from the process ID's position in the PID
+# file. The following lock prevents a race condition during daemon startup
+# when multiple daemons read, index, and write to the PID file concurrently.
+# The lock is created on the PID directory since a lock file cannot be safely
+# removed. The daemon is started with the lock closed and the lock remains
+# active in this script until the script exits.
+command -v flock >/dev/null 2>&1
+if [[ $? -eq 0 ]]; then
+    exec 200<"$SHUFFLE_PID_DIR"
+    flock 200
+fi
+
+# Ascending ID depending on the number of lines in the pid file.
+# This allows us to start multiple daemons of each type.
+id=$([ -f "$pid" ] && echo $(wc -l < "$pid") || echo "0")
+
+if [ -z "${LOG_PREFIX}" ]; then
+    LOG_PREFIX="${SHUFFLE_LOG_DIR}/${DAEMON}-${id}-${HOSTNAME}"
+fi
+
+log="${LOG_PREFIX}.log"
+out="${LOG_PREFIX}.out"
+
+log_setting=("-Dlog.file=${log}" "-Dlog4j.configurationFile=file:${SHUFFLE_CONF_DIR}/log4j2.properties")
+IS_NUMBER="^[0-9]+$"
+function guaranteed_kill {
+    to_stop_pid=$1
+    daemon=$2
+
+    # send SIGTERM for a graceful shutdown
+    kill $to_stop_pid
+    # if timeout exists, use it
+    if command -v timeout &> /dev/null ; then
+        # wait 10 seconds for the process to stop; by default, the JVM is killed 5 seconds after SIGTERM
+        timeout 10 tail --pid=$to_stop_pid -f /dev/null
+        if [ "$?" -eq 124 ]; then
+            echo "Daemon $daemon didn't stop within 10 seconds. Killing it."
+            # send SIGKILL
+            kill -9 $to_stop_pid
+        fi
+    fi
+}
+
+case $STARTSTOP in
+
+    (start)
+
+        # Print a warning if daemons are already running on the host
+        if [ -f "$pid" ]; then
+            active=()
+            while IFS='' read -r p || [[ -n "$p" ]]; do
+                kill -0 $p >/dev/null 2>&1
+                if [ $? -eq 0 ]; then
+                    active+=($p)
+                fi
+            done < "${pid}"
+
+            count="${#active[@]}"
+
+            if [ ${count} -gt 0 ]; then
+                echo "[INFO] $count instance(s) of $DAEMON are already running on $HOSTNAME."
+            fi
+        fi
+
+        echo "Starting $DAEMON daemon on host $HOSTNAME."
+
+        "$JAVA_RUN" "${log_setting[@]}" $JVM_ARGS -classpath "`manglePathList "$SHUFFLE_CLASSPATH"`" ${CLASS_TO_RUN} "${ARGS[@]}" > "$out" 200<&- 2>&1 < /dev/null &
+
+        mypid=$!
+
+        # Add to the pid file if the start was successful
+        if [[ ${mypid} =~ ${IS_NUMBER} ]] && kill -0 $mypid > /dev/null 2>&1 ; then
+            echo $mypid >> "$pid"
+        else
+            echo "Error starting $DAEMON daemon."
+            exit 1
+        fi
+    ;;
+
+    (stop)
+        if [ -f "$pid" ]; then
+            # Remove the last entry from the pid file
+            to_stop=$(tail -n 1 "$pid")
+
+            if [ -z $to_stop ]; then
+                rm "$pid" # If all stopped, clean up the pid file
+                echo "No $DAEMON daemon to stop on host $HOSTNAME."
+            else
+                sed \$d "$pid" > "$pid.tmp" # all but the last line
+
+                # If all stopped, clean up the pid file
+                [ $(wc -l < "$pid.tmp") -eq 0 ] && rm "$pid" "$pid.tmp" || mv "$pid.tmp" "$pid"
+
+                if kill -0 $to_stop > /dev/null 2>&1; then
+                    echo "Stopping $DAEMON daemon (pid: $to_stop) on host $HOSTNAME."
+                    guaranteed_kill $to_stop $DAEMON
+                else
+                    echo "No $DAEMON daemon (pid: $to_stop) is running anymore on $HOSTNAME."
+                fi
+            fi
+        else
+            echo "No $DAEMON daemon to stop on host $HOSTNAME."
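+            # Note, derived from the (start) and (stop) logic above: 'start' appends
+            # each new pid to the pid file and 'stop' always takes the last line, so
+            # repeated start/stop calls pair up in LIFO order, i.e. the most recently
+            # started instance is the first one stopped.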
+        fi
+    ;;
+
+    (stop-all)
+        if [ -f "$pid" ]; then
+            mv "$pid" "${pid}.tmp"
+
+            while read to_stop; do
+                if kill -0 $to_stop > /dev/null 2>&1; then
+                    echo "Stopping $DAEMON daemon (pid: $to_stop) on host $HOSTNAME."
+                    guaranteed_kill $to_stop $DAEMON
+                else
+                    echo "Skipping $DAEMON daemon (pid: $to_stop), because it is not running anymore on $HOSTNAME."
+                fi
+            done < "${pid}.tmp"
+            rm "${pid}.tmp"
+        fi
+    ;;
+
+    (*)
+        echo "Unexpected argument '$STARTSTOP'. $USAGE."
+        exit 1
+    ;;
+
+esac
diff --git a/shuffle-dist/src/main/shuffle-bin/bin/shufflemanager.sh b/shuffle-dist/src/main/shuffle-bin/bin/shufflemanager.sh
new file mode 100755
index 00000000..7ef6c34f
--- /dev/null
+++ b/shuffle-dist/src/main/shuffle-bin/bin/shufflemanager.sh
@@ -0,0 +1,45 @@
+#!/usr/bin/env bash
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+# Start/stop a rss shufflemanager.
+USAGE="Usage: shufflemanager.sh (start|start-foreground|stop|stop-all) [args]"
+ARGS=("${@:2}") # get remaining arguments as array
+STARTSTOP=$1
+
+if [[ $STARTSTOP != "start" ]] && [[ $STARTSTOP != "start-foreground" ]] && [[ $STARTSTOP != "stop" ]] && [[ $STARTSTOP != "stop-all" ]]; then
+    echo $USAGE
+    exit 1
+fi
+
+bin=`dirname "$0"`
+bin=`cd "$bin"; pwd`
+
+. "$bin"/config.sh
+
+ENTRYPOINT=shufflemanager
+
+if [[ $STARTSTOP == "start" ]] || [[ $STARTSTOP == "start-foreground" ]]; then
+    export LOG_PREFIX=${SHUFFLE_LOG_DIR}"/shufflemanager"
+    parseShuffleManagerJvmArgsAndExportLogs "${ARGS[@]}"
+fi
+
+if [[ $STARTSTOP == "start-foreground" ]]; then
+    exec "${SHUFFLE_BIN_DIR}"/shuffle-console.sh $ENTRYPOINT "${ARGS[@]}"
+else
+    "${SHUFFLE_BIN_DIR}"/shuffle-daemon.sh $STARTSTOP $ENTRYPOINT "${ARGS[@]}"
+fi
diff --git a/shuffle-dist/src/main/shuffle-bin/bin/shuffleworker.sh b/shuffle-dist/src/main/shuffle-bin/bin/shuffleworker.sh
new file mode 100755
index 00000000..789639d4
--- /dev/null
+++ b/shuffle-dist/src/main/shuffle-bin/bin/shuffleworker.sh
@@ -0,0 +1,49 @@
+#!/usr/bin/env bash
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +# Start/stop a rss shuffleworker. +USAGE="Usage: shuffleworker.sh (start|start-foreground|stop|stop-all) [args]" + +STARTSTOP=$1 + +ARGS=("${@:2}") + +if [[ $STARTSTOP != "start" ]] && [[ $STARTSTOP != "start-foreground" ]] && [[ $STARTSTOP != "stop" ]] && [[ $STARTSTOP != "stop-all" ]]; then + echo $USAGE + exit 1 +fi + +bin=`dirname "$0"` +bin=`cd "$bin"; pwd` + +. "$bin"/config.sh + +ENTRYPOINT=shuffleworker + +if [[ $STARTSTOP == "start" ]] || [[ $STARTSTOP == "start-foreground" ]]; then + export LOG_PREFIX=${SHUFFLE_LOG_DIR}"/shuffleworker" + parseShuffleWorkerJvmArgsAndExportLogs "${ARGS[@]}" +fi + +if [[ $STARTSTOP == "start-foreground" ]]; then + exec "${SHUFFLE_BIN_DIR}"/shuffle-console.sh $ENTRYPOINT "${ARGS[@]}" +else + # Start a single shuffleWorker + "${SHUFFLE_BIN_DIR}"/shuffle-daemon.sh $STARTSTOP $ENTRYPOINT "${ARGS[@]}" +fi diff --git a/shuffle-dist/src/main/shuffle-bin/bin/start-cluster.sh b/shuffle-dist/src/main/shuffle-bin/bin/start-cluster.sh new file mode 100755 index 00000000..2e1f162d --- /dev/null +++ b/shuffle-dist/src/main/shuffle-bin/bin/start-cluster.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +ARGS=("${@:1}") + +bin=`dirname "$0"` +bin=`cd "$bin"; pwd` + +. "$bin"/config.sh + +# Start the shuffleManager instance +shopt -s nocasematch + +echo "Starting cluster." +# Start single shuffleManager +ShuffleManagers start "${ARGS[@]}" + +shopt -u nocasematch + +# Start ShuffleWorker instance(s) +ShuffleWorkers start "${ARGS[@]}" diff --git a/shuffle-dist/src/main/shuffle-bin/bin/stop-cluster.sh b/shuffle-dist/src/main/shuffle-bin/bin/stop-cluster.sh new file mode 100755 index 00000000..e40c4746 --- /dev/null +++ b/shuffle-dist/src/main/shuffle-bin/bin/stop-cluster.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+bin=`dirname "$0"`
+bin=`cd "$bin"; pwd`
+
+. "$bin"/config.sh
+
+# Stop shuffleWorker instance(s)
+ShuffleWorkers stop
+
+# Stop shuffleManager instance
+shopt -s nocasematch
+ShuffleManagers stop
+shopt -u nocasematch
diff --git a/shuffle-dist/src/main/shuffle-bin/bin/yarn-shufflemanager.sh b/shuffle-dist/src/main/shuffle-bin/bin/yarn-shufflemanager.sh
new file mode 100755
index 00000000..4bbaf008
--- /dev/null
+++ b/shuffle-dist/src/main/shuffle-bin/bin/yarn-shufflemanager.sh
@@ -0,0 +1,262 @@
+#!/usr/bin/env bash
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Start a ShuffleManager on Yarn.
+
+usage() {
+    echo "
+Usage:
+    start              start shuffle manager on Yarn.
+    stop               stop shuffle manager on Yarn.
+                       If the application name is not specified, the default shufflemanager with the
+                       name 'Flink-Remote-Shuffle-Manager' will be killed.
+    restart            restart shuffle manager on Yarn.
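+                       restart is equivalent to running stop followed by start
+                       again with the same arguments.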
+
+    Optional args:
+    <-q, --queue>          queue name to start shuffle manager, default queue is 'default'
+    <-n, --name>           application name, the name should not contain spaces
+    <-c, --clusterid>      the cluster id to start shuffle manager, default cluster id is 'default-cluster'
+    <-p, --priority>       application priority when deploying shuffle manager service on Yarn, default is 0
+    <-a, --attempts>       application master max attempt count, default is 1000000
+    <--am-mem-mb>          size of memory in megabytes allocated for starting shuffle manager, default is 2048
+    <--am-overhead-mb>     size of overhead memory in megabytes allocated for starting shuffle manager, default is 512
+    <--rm-heartbeat-ms>    heartbeat interval between shuffle manager am and yarn resource manager, default is 1000
+    <--check-old-job>      flag indicating whether to check that an old job with the same name already exists, default is False, set to True to enable it
+    -h, --help             display this help and exit
+
+    example1: yarn-shufflemanager.sh start --queue default --am-mem-mb 1024
+    example2: yarn-shufflemanager.sh start -q default -D remote-shuffle.xxx=xxx
+    example3: yarn-shufflemanager.sh stop
+"
+}
+
+JAR_NAME="shuffle-dist"
+RUN_CLASS="com.alibaba.flink.shuffle.yarn.entry.manager.YarnShuffleManagerEntrypoint"
+
+confDir="$(cd "$(dirname "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
+homeDir=$(dirname $confDir)
+libDir=$homeDir/lib
+logDir=$homeDir/log
+runCmd=""
+jarFile=""
+cmdType=""
+
+# Set default values for all args
+yarnCmd=yarn
+queue=default
+name="Flink-Remote-Shuffle-Manager"
+priority=0
+attempts=1000000
+memSize=2048
+memOverheadSize=512
+heartbeatInterval=1000
+checkOldJobExist=False
+clusterid=default-cluster
+
+parseArgs() {
+    runCmd=$runCmd" -D remote-shuffle.yarn.manager-home-dir=$homeDir"
+    while [[ $# -gt 0 ]]; do
+        case "$1" in
+            start | stop | restart)
+                cmdType="$1"
+                ;;
+            -q | --queue)
+                queue="$2"
+                echo "arg, queue: $queue"
+                runCmd=$runCmd" -D remote-shuffle.yarn.manager-app-queue-name=$queue"
+                shift
+                ;;
+            -n | --name)
+                name="$2"
+                splitNames=(${name// / })
+                if [[ ${#splitNames[*]} -gt 1 ]]; then
+                    echo "The name specified by -n or --name cannot contain spaces."
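+                    # The name must remain a single token: getYarnAppIdByName below
+                    # matches it against field $2 of the 'yarn application -list'
+                    # output, which is whitespace-delimited.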
+                    exit 1
+                fi
+                echo "arg, name: $name"
+                runCmd=$runCmd" -D remote-shuffle.yarn.manager-app-name=$name"
+                shift
+                ;;
+            -c | --clusterid)
+                clusterid="$2"
+                echo "arg, clusterid: $clusterid"
+                runCmd=$runCmd" -D remote-shuffle.cluster.id=$clusterid"
+                shift
+                ;;
+            -p | --priority)
+                priority="$2"
+                echo "arg, priority: $priority"
+                runCmd=$runCmd" -D remote-shuffle.yarn.manager-app-priority=$priority"
+                shift
+                ;;
+            -a | --attempts)
+                attempts="$2"
+                echo "arg, attempts: $attempts"
+                runCmd=$runCmd" -D remote-shuffle.yarn.manager-am-max-attempts=$attempts"
+                shift
+                ;;
+            --am-mem-mb)
+                memSize="$2"
+                echo "arg, memSize: $memSize"
+                runCmd=$runCmd" -D remote-shuffle.yarn.manager-am-memory-size-mb=$memSize"
+                shift
+                ;;
+            --am-overhead-mb)
+                memOverheadSize="$2"
+                echo "arg, memOverheadSize: $memOverheadSize"
+                runCmd=$runCmd" -D remote-shuffle.yarn.manager-am-memory-overhead-mb=$memOverheadSize"
+                shift
+                ;;
+            --rm-heartbeat-ms)
+                heartbeatInterval="$2"
+                echo "arg, heartbeatInterval: $heartbeatInterval"
+                runCmd=$runCmd" -D remote-shuffle.yarn.manager-rm-heartbeat-interval-ms=$heartbeatInterval"
+                shift
+                ;;
+            --check-old-job)
+                checkOldJobExist="$2"
+                echo "arg, checkOldJobExist: $checkOldJobExist"
+                shift
+                ;;
+            -D)
+                runCmd=$runCmd" -D $2"
+                echo "arg, internal arg: $2"
+                shift
+                ;;
+            -h | --help)
+                usage
+                exit
+                ;;
+            --)
+                shift
+                break
+                ;;
+            *)
+                echo "$1 is not a valid option"
+                exit 1
+                ;;
+        esac
+        shift
+    done
+
+    if [[ $cmdType != "" ]]; then
+        echo
+        echo "cmd type: $cmdType"
+        echo
+    fi
+}
+
+checkYarnEnv() {
+    isYarnExist=false
+    if [[ -x "$(command -v yarn)" ]]; then
+        isYarnExist=true
+        yarnCmd="yarn"
+    fi
+
+    if [[ $HADOOP_YARN_HOME != "" && -x "$(command -v $HADOOP_YARN_HOME/bin/yarn)" ]]; then
+        isYarnExist=true
+        yarnCmd="$HADOOP_YARN_HOME/bin/yarn"
+    fi
+
+    if [[ $isYarnExist == false ]]; then
+        echo "The 'yarn' command does not exist. Please make sure 'yarn' is executable or export HADOOP_YARN_HOME."
+        exit 1
+    fi
+}
+
+checkPath() {
+    if [[ ! -d $libDir ]]; then
+        echo "Please make sure the script is in the compiled directory, which contains lib, conf, log, etc."
+        exit 1
+    fi
+    jarFile=$(ls "$libDir"/$JAR_NAME*.jar)
+    if [[ ! -f $jarFile ]]; then
+        echo "Please make sure the script is in the compiled directory, and the compiled shuffle-dist jar is in the lib directory."
+        exit 1
+    fi
+}
+
+checkEnv() {
+    checkYarnEnv
+    checkPath
+}
+
+getYarnAppIdByName() {
+    appIds=$($yarnCmd application -list | egrep '^application_' | grep "$name" | grep "RUNNING" | awk -v matchName="$name" '{if ($2==matchName) {print $1}}')
+    echo $appIds
+}
+
+checkAppExist() {
+    echo "Checking whether a shuffle manager yarn application with name $name exists"
+    oldAppId=$(getYarnAppIdByName)
+    if [[ $oldAppId != "" ]]; then
+        echo "Shuffle manager with name $name is still running."
+        echo "If you want to start multiple shuffle managers, please specify different names for different shuffle managers by -n"
+        exit 1
+    fi
+}
+
+submitYarnJob() {
+    if [[ $checkOldJobExist == "True" ]]; then
+        checkAppExist
+    fi
+    runCmd=$(echo "$yarnCmd jar $jarFile $RUN_CLASS $runCmd" | cat)
+    echo $runCmd
+    $runCmd
+}
+
+stopYarnJob() {
+    echo "Getting shuffle manager yarn application with name $name"
+    appids=$(getYarnAppIdByName)
+    if [[ $appids == "" ]]; then
+        echo "The shuffle manager with name $name is not running. Stop is not required."
+    else
+        for appid in $(echo $appids); do
+            $yarnCmd application -kill "$appid"
+            echo "The shuffle manager application $appid with name $name is stopped"
+        done
+    fi
+
+}
+
+restartYarnJob() {
+    stopYarnJob
+    submitYarnJob
+}
+
+unrecognizedCmd() {
+    echo "Unknown command type. Valid commands include start, stop and restart"
+    exit 1
+}
+
+main() {
+    parseArgs "$@"
+    checkEnv
+    if [[ $cmdType == "start" ]]; then
+        submitYarnJob
+    elif [[ $cmdType == "stop" ]]; then
+        stopYarnJob
+    elif [[ $cmdType == "restart" ]]; then
+        restartYarnJob
+    else
+        unrecognizedCmd
+    fi
+}
+
+main "$@"
diff --git a/shuffle-dist/src/main/shuffle-bin/conf/kubernetes-shuffle-cluster-template.yaml b/shuffle-dist/src/main/shuffle-bin/conf/kubernetes-shuffle-cluster-template.yaml
new file mode 100644
index 00000000..cc06a48e
--- /dev/null
+++ b/shuffle-dist/src/main/shuffle-bin/conf/kubernetes-shuffle-cluster-template.yaml
@@ -0,0 +1,64 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+apiVersion: shuffleoperator.alibaba.com/v1
+kind: RemoteShuffle
+metadata:
+  namespace: flink-system-rss
+  name: flink-remote-shuffle
+spec:
+  shuffleDynamicConfigs:
+    remote-shuffle.manager.jvm-opts: -verbose:gc -Xloggc:/flink-remote-shuffle/log/gc.log
+    remote-shuffle.worker.jvm-opts: -verbose:gc -Xloggc:/flink-remote-shuffle/log/gc.log
+    remote-shuffle.kubernetes.manager.cpu: 4
+    remote-shuffle.kubernetes.worker.cpu: 4
+    remote-shuffle.kubernetes.worker.limit-factor.cpu: 8
+    remote-shuffle.kubernetes.container.image:
+    remote-shuffle.kubernetes.worker.volume.host-paths: name:disk,path:,mountPath:/data
+    remote-shuffle.storage.local-data-dirs: '[SSD]/data'
+    remote-shuffle.high-availability.mode: ZOOKEEPER
+    remote-shuffle.ha.zookeeper.quorum:
+
+  shuffleFileConfigs:
+    log4j2.properties: |
+      monitorInterval=30
+
+      rootLogger.level = INFO
+      rootLogger.appenderRef.console.ref = ConsoleAppender
+      rootLogger.appenderRef.rolling.ref = RollingFileAppender
+
+      # Log all info to the console
+      appender.console.name = ConsoleAppender
+      appender.console.type = CONSOLE
+      appender.console.layout.type = PatternLayout
+      appender.console.layout.pattern = %d{yyyy-MM-dd HH:mm:ss,SSS} %-5p [%t] %-60c %x - %m%n
+
+      # Log all info in the given rolling file
+      appender.rolling.name = RollingFileAppender
+      appender.rolling.type = RollingFile
+      appender.rolling.append = true
+      appender.rolling.fileName = ${sys:log.file}
+      appender.rolling.filePattern = ${sys:log.file}.%i
+      appender.rolling.layout.type = PatternLayout
+      appender.rolling.layout.pattern = %d{yyyy-MM-dd HH:mm:ss,SSS} %-5p [%t] %-60c %x - %m%n
+      appender.rolling.policies.type = Policies
+      appender.rolling.policies.size.type = SizeBasedTriggeringPolicy
+      appender.rolling.policies.size.size=256MB
+      appender.rolling.policies.startup.type = OnStartupTriggeringPolicy
+      appender.rolling.strategy.type = DefaultRolloverStrategy
+      appender.rolling.strategy.max = ${env:MAX_LOG_FILE_NUMBER:-10}
diff --git a/shuffle-dist/src/main/shuffle-bin/conf/kubernetes-shuffle-operator-template.yaml b/shuffle-dist/src/main/shuffle-bin/conf/kubernetes-shuffle-operator-template.yaml
new file mode 100644
index 00000000..8357b639
--- /dev/null
+++ b/shuffle-dist/src/main/shuffle-bin/conf/kubernetes-shuffle-operator-template.yaml
@@ -0,0 +1,95 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License. 
+################################################################################
+
+---
+apiVersion: rbac.authorization.k8s.io/v1
+kind: ClusterRole
+metadata:
+  name: flink-rss-cr
+rules:
+  - apiGroups: ["apiextensions.k8s.io"]
+    resources:
+      - customresourcedefinitions
+    verbs:
+      - '*'
+  - apiGroups: ["shuffleoperator.alibaba.com"]
+    resources:
+      - remoteshuffles
+    verbs:
+      - '*'
+  - apiGroups: ["shuffleoperator.alibaba.com"]
+    resources:
+      - remoteshuffles/status
+    verbs:
+      - update
+  - apiGroups: ["apps"]
+    resources:
+      - deployments
+      - daemonsets
+    verbs:
+      - '*'
+  - apiGroups: [""]
+    resources:
+      - configmaps
+    verbs:
+      - '*'
+---
+apiVersion: rbac.authorization.k8s.io/v1
+kind: ClusterRoleBinding
+metadata:
+  name: flink-rss-crb
+roleRef:
+  apiGroup: rbac.authorization.k8s.io
+  kind: ClusterRole
+  name: flink-rss-cr
+subjects:
+  - kind: ServiceAccount
+    name: flink-rss-sa
+    namespace: flink-system-rss
+---
+apiVersion: v1
+kind: ServiceAccount
+metadata:
+  name: flink-rss-sa
+  namespace: flink-system-rss
+---
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+  namespace: flink-system-rss
+  name: flink-remote-shuffle-operator
+spec:
+  replicas: 1
+  selector:
+    matchLabels:
+      app: flink-remote-shuffle-operator
+  template:
+    metadata:
+      labels:
+        app: flink-remote-shuffle-operator
+    spec:
+      serviceAccountName: flink-rss-sa
+      containers:
+        - name: flink-remote-shuffle-operator
+          image: # You need to configure the docker image to be used here.
+          imagePullPolicy: Always
+          command:
+            - bash
+          args:
+            - -c
+            - $JAVA_HOME/bin/java -classpath '/flink-remote-shuffle/opt/*' -Dlog4j.configurationFile=file:/flink-remote-shuffle/conf/log4j2-operator.properties -Dlog.file=/flink-remote-shuffle/log/operator.log com.alibaba.flink.shuffle.kubernetes.operator.RemoteShuffleApplicationOperatorEntrypoint
diff --git a/shuffle-dist/src/main/shuffle-bin/conf/log4j2-operator.properties b/shuffle-dist/src/main/shuffle-bin/conf/log4j2-operator.properties
new file mode 100644
index 00000000..eb13a240
--- /dev/null
+++ b/shuffle-dist/src/main/shuffle-bin/conf/log4j2-operator.properties
@@ -0,0 +1,45 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+# Allows this configuration to be modified at runtime. The file will be checked every 30 seconds.
+monitorInterval=30
+# This affects logging for both user code and remote shuffle service
+rootLogger.level=INFO
+rootLogger.appenderRef.console.ref=ConsoleAppender
+rootLogger.appenderRef.rolling.ref=RollingFileAppender
+# The following lines set the log level of third party libraries. 
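+# Other chatty third-party loggers can be quieted in the same way, e.g. (an
+# illustration only, not shipped with this patch):
+# logger.curator.name=org.apache.curator
+# logger.curator.level=WARN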
+logger.fabric8.name=com.alibaba.flink.shuffle.kubernetes.shaded.io.fabric8 +logger.fabric8.level=WARN +# Log all info to the console +appender.console.name=ConsoleAppender +appender.console.type=CONSOLE +appender.console.layout.type=PatternLayout +appender.console.layout.pattern=%d{yyyy-MM-dd HH:mm:ss,SSS} %-5p [%t] %-60c %x - %m%n +# Log all info in the given rolling file +appender.rolling.name=RollingFileAppender +appender.rolling.type=RollingFile +appender.rolling.append=true +appender.rolling.fileName=${sys:log.file} +appender.rolling.filePattern=${sys:log.file}.%i +appender.rolling.layout.type=PatternLayout +appender.rolling.layout.pattern=%d{yyyy-MM-dd HH:mm:ss,SSS} %-5p [%t] %-60c %x - %m%n +appender.rolling.policies.type=Policies +appender.rolling.policies.size.type=SizeBasedTriggeringPolicy +appender.rolling.policies.size.size=100MB +appender.rolling.policies.startup.type=OnStartupTriggeringPolicy +appender.rolling.strategy.type=DefaultRolloverStrategy +appender.rolling.strategy.max=${env:MAX_LOG_FILE_NUMBER:-10} diff --git a/shuffle-dist/src/main/shuffle-bin/conf/log4j2.properties b/shuffle-dist/src/main/shuffle-bin/conf/log4j2.properties new file mode 100644 index 00000000..2f5022a8 --- /dev/null +++ b/shuffle-dist/src/main/shuffle-bin/conf/log4j2.properties @@ -0,0 +1,42 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +# Allows this configuration to be modified at runtime. The file will be checked every 30 seconds. 
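+# Rolling policy note derived from the settings below: with
+# appender.rolling.policies.size.size=100MB and appender.rolling.strategy.max
+# defaulting to 10 (via MAX_LOG_FILE_NUMBER), each process keeps at most about
+# 1GB of rotated log files plus the active file.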
+monitorInterval=30 +# This affects logging for both user code and remote shuffle service +rootLogger.level=INFO +rootLogger.appenderRef.console.ref=ConsoleAppender +rootLogger.appenderRef.rolling.ref=RollingFileAppender +# Log all info to the console +appender.console.name=ConsoleAppender +appender.console.type=CONSOLE +appender.console.layout.type=PatternLayout +appender.console.layout.pattern=%d{yyyy-MM-dd HH:mm:ss,SSS} %-5p [%t] %-60c %x - %m%n +# Log all info in the given rolling file +appender.rolling.name=RollingFileAppender +appender.rolling.type=RollingFile +appender.rolling.append=true +appender.rolling.fileName=${sys:log.file} +appender.rolling.filePattern=${sys:log.file}.%i +appender.rolling.layout.type=PatternLayout +appender.rolling.layout.pattern=%d{yyyy-MM-dd HH:mm:ss,SSS} %-5p [%t] %-60c %x - %m%n +appender.rolling.policies.type=Policies +appender.rolling.policies.size.type=SizeBasedTriggeringPolicy +appender.rolling.policies.size.size=100MB +appender.rolling.policies.startup.type=OnStartupTriggeringPolicy +appender.rolling.strategy.type=DefaultRolloverStrategy +appender.rolling.strategy.max=${env:MAX_LOG_FILE_NUMBER:-10} diff --git a/shuffle-dist/src/main/shuffle-bin/conf/managers b/shuffle-dist/src/main/shuffle-bin/conf/managers new file mode 100644 index 00000000..7b9ad531 --- /dev/null +++ b/shuffle-dist/src/main/shuffle-bin/conf/managers @@ -0,0 +1 @@ +127.0.0.1 diff --git a/shuffle-dist/src/main/shuffle-bin/conf/remote-shuffle-conf.yaml b/shuffle-dist/src/main/shuffle-bin/conf/remote-shuffle-conf.yaml new file mode 100644 index 00000000..59473c6f --- /dev/null +++ b/shuffle-dist/src/main/shuffle-bin/conf/remote-shuffle-conf.yaml @@ -0,0 +1,30 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################
+
+# Change it to your pre-created shuffle data storage directory for production usage
+
+remote-shuffle.storage.local-data-dirs: [SSD]/tmp
+
+# Configure the following options for high availability deployment mode
+
+# remote-shuffle.high-availability.mode: ZOOKEEPER
+# remote-shuffle.ha.zookeeper.quorum:
+
+# Configure the following shuffle manager address if high availability is not enabled
+
+# remote-shuffle.manager.rpc-address: for example, 127.0.0.1 for local cluster
diff --git a/shuffle-dist/src/main/shuffle-bin/conf/workers b/shuffle-dist/src/main/shuffle-bin/conf/workers
new file mode 100644
index 00000000..7b9ad531
--- /dev/null
+++ b/shuffle-dist/src/main/shuffle-bin/conf/workers
@@ -0,0 +1 @@
+127.0.0.1
diff --git a/shuffle-e2e-tests/pom.xml b/shuffle-e2e-tests/pom.xml
new file mode 100644
index 00000000..12883997
--- /dev/null
+++ b/shuffle-e2e-tests/pom.xml
@@ -0,0 +1,148 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+
+    <parent>
+        <groupId>com.alibaba.flink.shuffle</groupId>
+        <artifactId>flink-shuffle-parent</artifactId>
+        <version>1.0-SNAPSHOT</version>
+    </parent>
+    <modelVersion>4.0.0</modelVersion>
+
+    <artifactId>shuffle-e2e-tests</artifactId>
+
+    <dependencies>
+
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-runtime</artifactId>
+            <version>${flink.version}</version>
+            <scope>provided</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>com.alibaba.flink.shuffle</groupId>
+            <artifactId>shuffle-common</artifactId>
+            <version>${project.version}</version>
+        </dependency>
+
+        <dependency>
+            <groupId>com.alibaba.flink.shuffle</groupId>
+            <artifactId>shuffle-core</artifactId>
+            <version>${project.version}</version>
+        </dependency>
+
+        <dependency>
+            <groupId>com.alibaba.flink.shuffle</groupId>
+            <artifactId>shuffle-plugin</artifactId>
+            <version>${project.version}</version>
+        </dependency>
+
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-java</artifactId>
+            <version>${flink.version}</version>
+            <scope>provided</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-streaming-java_${scala.binary.version}</artifactId>
+            <version>${flink.version}</version>
+            <scope>provided</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-table-planner_${scala.binary.version}</artifactId>
+            <version>${flink.version}</version>
+            <scope>provided</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-clients_${scala.binary.version}</artifactId>
+            <version>${flink.version}</version>
+            <scope>provided</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-hadoop-fs</artifactId>
+            <version>${flink.version}</version>
+            <scope>provided</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-rpc-akka</artifactId>
+            <version>${flink.version}</version>
+            <scope>test</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>com.alibaba.flink.shuffle</groupId>
+            <artifactId>shuffle-coordinator</artifactId>
+            <version>${project.version}</version>
+            <type>test-jar</type>
+            <scope>test</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>com.alibaba.flink.shuffle</groupId>
+            <artifactId>shuffle-transfer</artifactId>
+            <version>${project.version}</version>
+            <type>test-jar</type>
+            <scope>test</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-shaded-hadoop-2</artifactId>
+            <version>2.4.1-9.0</version>
+            <scope>provided</scope>
+            <exclusions>
+                <exclusion>
+                    <groupId>log4j</groupId>
+                    <artifactId>log4j</artifactId>
+                </exclusion>
+                <exclusion>
+                    <groupId>org.slf4j</groupId>
+                    <artifactId>slf4j-log4j12</artifactId>
+                </exclusion>
+            </exclusions>
+        </dependency>
+
+        <dependency>
+            <groupId>org.apache.curator</groupId>
+            <artifactId>curator-test</artifactId>
+            <version>${curator.version}</version>
+            <scope>test</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>org.apache.logging.log4j</groupId>
+            <artifactId>log4j-1.2-api</artifactId>
+            <scope>test</scope>
+        </dependency>
+
+    </dependencies>
+</project>
diff --git a/shuffle-e2e-tests/src/main/java/com/alibaba/flink/shuffle/e2e/TPCDSE2E.java b/shuffle-e2e-tests/src/main/java/com/alibaba/flink/shuffle/e2e/TPCDSE2E.java
new file mode 100644
index 00000000..6bff6ce6
--- /dev/null
+++ b/shuffle-e2e-tests/src/main/java/com/alibaba/flink/shuffle/e2e/TPCDSE2E.java
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.e2e; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.table.api.EnvironmentSettings; +import org.apache.flink.table.api.TableEnvironment; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** TPCDS end to end test. */ +public class TPCDSE2E { + + private static final Logger LOG = LoggerFactory.getLogger(TPCDSE2E.class); + + public static void main(String[] args) { + if (args.length != 3) { + throw new IllegalArgumentException("Args: (tpcds-home, parallelism, sqlIdx)."); + } + String tpcdsHome = args[0]; + int parallelism = Integer.valueOf(args[1]); + LOG.info("Running E2E under TPCDS: {}.", tpcdsHome); + + EnvironmentSettings settings = + EnvironmentSettings.newInstance().useBlinkPlanner().inBatchMode().build(); + TableEnvironment stEnv = TableEnvironment.create(settings); + Configuration configuration = stEnv.getConfig().getConfiguration(); + configuration.setInteger("table.exec.resource.default-parallelism", parallelism); + configuration.setString("table.exec.shuffle-mode", "ALL_EDGES_BLOCKING"); + + registerTables(stEnv, tpcdsHome); + + Integer sqlIdx = Integer.valueOf(args[2]); + stEnv.executeSql(sqls()[sqlIdx]).collect(); + } + + private static String[] sqls() { + String[] sqls = new String[3]; + sqls[0] = + "" + + "insert into result_table select count(1) from (select distinct * from store_sales) x"; + sqls[1] = + "" + + "insert into result_table" + + " select count(1) from" + + " (select * from store_sales union select * from store_sales union select * from store_sales union select * from store_sales) x"; + sqls[2] = + "" + + "insert into result_table" + + " select count(1) from" + + " (select distinct * from" + + " store_sales join item on store_sales.ss_item_sk=item.i_item_sk " + + " join date_dim on store_sales.ss_sold_date_sk=d_date_sk) t"; + return sqls; + } + + private static String createTableFooter(String tpcdsHome, String tablePath) { + return String.format( + "" + + "WITH (\n" + + " 'connector' = 'filesystem',\n" + + " 'path' = '%s/%s',\n" + + " 'format' = 'csv',\n" + + " 'csv.field-delimiter' = '|',\n" + + " 'csv.ignore-parse-errors' = 'true'\n" + + ")", + tpcdsHome, tablePath); + } + + private static void registerTables(TableEnvironment tenv, String tpcdsHome) { + tenv.executeSql( + "" + + "CREATE TABLE result_table (res BIGINT)\n" + + createTableFooter(tpcdsHome, "result/")); + tenv.executeSql( + "" + + "CREATE TABLE store_sales (\n" + + " ss_sold_date_sk BIGINT,\n" + + " ss_sold_time_sk BIGINT,\n" + + " ss_item_sk BIGINT,\n" + + " ss_customer_sk BIGINT,\n" + + " ss_cdemo_sk BIGINT,\n" + + " ss_hdemo_sk BIGINT,\n" + + " ss_addr_sk BIGINT,\n" + + " ss_store_sk BIGINT,\n" + + " ss_promo_sk BIGINT,\n" + + " ss_ticket_number BIGINT,\n" + + " ss_quantity BIGINT,\n" + + " ss_wholesale_cost DOUBLE,\n" + + " ss_list_price DOUBLE,\n" + + " ss_sales_price DOUBLE,\n" + + " ss_ext_discount_amt DOUBLE,\n" + + " ss_ext_sales_price DOUBLE,\n" + + " ss_ext_wholesale_cost DOUBLE,\n" + + " ss_ext_list_price DOUBLE,\n" + + " ss_ext_tax DOUBLE,\n" + + " ss_coupon_amt DOUBLE,\n" + + " ss_net_paid DOUBLE,\n" + + " ss_net_paid_inc_tax DOUBLE,\n" + + " ss_net_profit DOUBLE)\n" + + createTableFooter(tpcdsHome, "store_sales/")); + + tenv.executeSql( + "" + + "CREATE TABLE web_sales (\n" + + " ws_sold_date_sk BIGINT,\n" + + " ws_sold_time_sk BIGINT,\n" + + " ws_ship_date_sk BIGINT,\n" + + " ws_item_sk BIGINT,\n" + + " ws_bill_customer_sk BIGINT,\n" + + " ws_bill_cdemo_sk BIGINT,\n" + + " ws_bill_hdemo_sk 
BIGINT,\n" + + " ws_bill_addr_sk BIGINT,\n" + + " ws_ship_customer_sk BIGINT,\n" + + " ws_ship_cdemo_sk BIGINT,\n" + + " ws_ship_hdemo_sk BIGINT,\n" + + " ws_ship_addr_sk BIGINT,\n" + + " ws_web_page_sk BIGINT,\n" + + " ws_web_site_sk BIGINT,\n" + + " ws_ship_mode_sk BIGINT,\n" + + " ws_warehouse_sk BIGINT,\n" + + " ws_promo_sk BIGINT,\n" + + " ws_order_number BIGINT,\n" + + " ws_quantity BIGINT,\n" + + " ws_wholesale_cost DOUBLE,\n" + + " ws_list_price DOUBLE,\n" + + " ws_sales_price DOUBLE,\n" + + " ws_ext_discount_amt DOUBLE,\n" + + " ws_ext_sales_price DOUBLE,\n" + + " ws_ext_wholesale_cost DOUBLE,\n" + + " ws_ext_list_price DOUBLE,\n" + + " ws_ext_tax DOUBLE,\n" + + " ws_coupon_amt DOUBLE,\n" + + " ws_ext_ship_cost DOUBLE,\n" + + " ws_net_paid DOUBLE,\n" + + " ws_net_paid_inc_tax DOUBLE,\n" + + " ws_net_paid_inc_ship DOUBLE,\n" + + " ws_net_paid_inc_ship_tax DOUBLE,\n" + + " ws_net_profit DOUBLE)\n" + + createTableFooter(tpcdsHome, "web_sales/")); + tenv.executeSql( + "" + + "CREATE TABLE item (\n" + + " i_item_sk BIGINT,\n" + + " i_item_id STRING,\n" + + " i_rec_start_date STRING,\n" + + " i_rec_end_date STRING,\n" + + " i_item_desc STRING,\n" + + " i_current_price DOUBLE,\n" + + " i_wholesale_cost DOUBLE,\n" + + " i_brand_id BIGINT,\n" + + " i_brand STRING,\n" + + " i_class_id BIGINT,\n" + + " i_class STRING,\n" + + " i_category_id BIGINT,\n" + + " i_category STRING,\n" + + " i_manufact_id BIGINT,\n" + + " i_manufact STRING,\n" + + " i_size STRING,\n" + + " i_formulation STRING,\n" + + " i_color STRING,\n" + + " i_units STRING,\n" + + " i_container STRING,\n" + + " i_manager_id BIGINT,\n" + + " i_product_name STRING)\n" + + createTableFooter(tpcdsHome, "item/")); + + tenv.executeSql( + "" + + "CREATE TABLE date_dim (\n" + + " d_date_sk BIGINT,\n" + + " d_date_id STRING,\n" + + " d_date STRING,\n" + + " d_month_seq BIGINT,\n" + + " d_week_seq BIGINT,\n" + + " d_quarter_seq BIGINT,\n" + + " d_year BIGINT,\n" + + " d_dow BIGINT,\n" + + " d_moy BIGINT,\n" + + " d_dom BIGINT,\n" + + " d_qoy BIGINT,\n" + + " d_fy_year BIGINT,\n" + + " d_fy_quarter_seq BIGINT,\n" + + " d_fy_week_seq BIGINT,\n" + + " d_day_name STRING,\n" + + " d_quarter_name STRING,\n" + + " d_holiday STRING,\n" + + " d_weekend STRING,\n" + + " d_following_holiday STRING,\n" + + " d_first_dom BIGINT,\n" + + " d_last_dom BIGINT,\n" + + " d_same_day_ly BIGINT,\n" + + " d_same_day_lq BIGINT,\n" + + " d_current_day STRING,\n" + + " d_current_week STRING,\n" + + " d_current_month STRING,\n" + + " d_current_quarter STRING,\n" + + " d_current_year STRING)\n" + + createTableFooter(tpcdsHome, "date_dim/")); + } +} diff --git a/shuffle-e2e-tests/src/main/java/com/alibaba/flink/shuffle/e2e/WordCount.java b/shuffle-e2e-tests/src/main/java/com/alibaba/flink/shuffle/e2e/WordCount.java new file mode 100644 index 00000000..7c711e8e --- /dev/null +++ b/shuffle-e2e-tests/src/main/java/com/alibaba/flink/shuffle/e2e/WordCount.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.e2e; + +import com.alibaba.flink.shuffle.e2e.util.WordCountData; + +import org.apache.flink.api.common.ExecutionMode; +import org.apache.flink.api.common.InputDependencyConstraint; +import org.apache.flink.api.common.RuntimeExecutionMode; +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.utils.MultipleParameterTool; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.ExecutionOptions; +import org.apache.flink.runtime.jobgraph.JobType; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.graph.GlobalStreamExchangeMode; +import org.apache.flink.streaming.api.graph.StreamGraph; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +/** + * Implements the "WordCount" program that computes a simple word occurrence histogram over text + * files in a streaming fashion. + * + *

<p>The input is a plain text file with lines separated by newline characters.
 + *
 + * <p>Usage: <code>WordCount --input &lt;path&gt; --output &lt;path&gt;</code><br>
 + * If no parameters are provided, the program is run with default data from {@link WordCountData}.
 + *
 + * <p>This example shows how to:
 + *
 + * <ul>
 + *   <li>write a simple Flink Streaming program,
 + *   <li>use tuple data types,
 + *   <li>write and use user-defined functions.
 + * </ul>
+ */
+public class WordCount {
+
+    // *************************************************************************
+    // PROGRAM
+    // *************************************************************************
+
+    public static void main(String[] args) throws Exception {
+        // Checking input parameters
+        final MultipleParameterTool params = MultipleParameterTool.fromArgs(args);
+
+        // set up the execution environment
+        Configuration configuration = new Configuration();
+        configuration.set(ExecutionOptions.RUNTIME_MODE, RuntimeExecutionMode.BATCH);
+        final StreamExecutionEnvironment env =
+                StreamExecutionEnvironment.getExecutionEnvironment(configuration);
+
+        // make parameters available in the web interface
+        env.getConfig().setGlobalJobParameters(params);
+        env.getConfig().setExecutionMode(ExecutionMode.BATCH);
+        env.getConfig().setParallelism(1);
+        env.getConfig().setDefaultInputDependencyConstraint(InputDependencyConstraint.ALL);
+
+        // get input data
+        DataStream<String> text = null;
+        if (params.has("input")) {
+            // union all the inputs from text files
+            for (String input : params.getMultiParameterRequired("input")) {
+                if (text == null) {
+                    text = env.readTextFile(input);
+                } else {
+                    text = text.union(env.readTextFile(input));
+                }
+            }
+            Preconditions.checkNotNull(text, "Input DataStream should not be null.");
+        } else {
+            System.out.println("Executing WordCount example with default input data set.");
+            System.out.println("Use --input to specify file input.");
+            // get default test text data
+            text = env.fromElements(WordCountData.WORDS);
+        }
+
+        DataStream<Tuple2<String, Integer>> counts =
+                text.flatMap(new Tokenizer()).keyBy(value -> value.f0).sum(1);
+
+        // emit result
+        if (params.has("output")) {
+            counts.writeAsText(params.get("output"));
+        } else {
+            System.out.println("Printing result to stdout. Use --output to specify output path.");
+            counts.print();
+        }
+
+        StreamGraph streamGraph = env.getStreamGraph();
+        streamGraph.setGlobalStreamExchangeMode(GlobalStreamExchangeMode.ALL_EDGES_BLOCKING);
+        streamGraph.setJobType(JobType.BATCH);
+        // execute program
+        env.execute(streamGraph);
+    }
+
+    // *************************************************************************
+    // USER FUNCTIONS
+    // *************************************************************************
+
+    /**
+     * Implements the string tokenizer that splits sentences into words as a user-defined
+     * FlatMapFunction. The function takes a line (String) and splits it into multiple pairs in the
+     * form of "(word,1)" ({@code Tuple2<String, Integer>}).
+     */
+    public static final class Tokenizer
+            implements FlatMapFunction<String, Tuple2<String, Integer>> {
+
+        @Override
+        public void flatMap(String value, Collector<Tuple2<String, Integer>> out) {
+            // normalize and split the line
+            String[] tokens = value.toLowerCase().split("\\W+");
+
+            // emit the pairs
+            for (String token : tokens) {
+                if (token.length() > 0) {
+                    out.collect(new Tuple2<>(token, 1));
+                }
+            }
+        }
+    }
+}
diff --git a/shuffle-e2e-tests/src/main/java/com/alibaba/flink/shuffle/e2e/util/WordCountData.java b/shuffle-e2e-tests/src/main/java/com/alibaba/flink/shuffle/e2e/util/WordCountData.java
new file mode 100644
index 00000000..84e8799c
--- /dev/null
+++ b/shuffle-e2e-tests/src/main/java/com/alibaba/flink/shuffle/e2e/util/WordCountData.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.e2e.util; + +/** + * Provides the default data sets used for the WordCount example program. The default data sets are + * used, if no parameters are given to the program. + */ +public class WordCountData { + + public static final String[] WORDS = + new String[] { + "To be, or not to be,--that is the question:--", + "Whether 'tis nobler in the mind to suffer", + "The slings and arrows of outrageous fortune", + "Or to take arms against a sea of troubles,", + "And by opposing end them?--To die,--to sleep,--", + "No more; and by a sleep to say we end", + "The heartache, and the thousand natural shocks", + "That flesh is heir to,--'tis a consummation", + "Devoutly to be wish'd. To die,--to sleep;--", + "To sleep! perchance to dream:--ay, there's the rub;", + "For in that sleep of death what dreams may come,", + "When we have shuffled off this mortal coil,", + "Must give us pause: there's the respect", + "That makes calamity of so long life;", + "For who would bear the whips and scorns of time,", + "The oppressor's wrong, the proud man's contumely,", + "The pangs of despis'd love, the law's delay,", + "The insolence of office, and the spurns", + "That patient merit of the unworthy takes,", + "When he himself might his quietus make", + "With a bare bodkin? who would these fardels bear,", + "To grunt and sweat under a weary life,", + "But that the dread of something after death,--", + "The undiscover'd country, from whose bourn", + "No traveller returns,--puzzles the will,", + "And makes us rather bear those ills we have", + "Than fly to others that we know not of?", + "Thus conscience does make cowards of us all;", + "And thus the native hue of resolution", + "Is sicklied o'er with the pale cast of thought;", + "And enterprises of great pith and moment,", + "With this regard, their currents turn awry,", + "And lose the name of action.--Soft you now!", + "The fair Ophelia!--Nymph, in thy orisons", + "Be all my sins remember'd." + }; +} diff --git a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/AbstractInstableE2ETest.java b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/AbstractInstableE2ETest.java new file mode 100644 index 00000000..55bd677d --- /dev/null +++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/AbstractInstableE2ETest.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.e2e; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.e2e.flinkcluster.FlinkLocalCluster; +import com.alibaba.flink.shuffle.e2e.shufflecluster.LocalShuffleCluster; +import com.alibaba.flink.shuffle.e2e.zookeeper.ZooKeeperTestEnvironment; + +import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFramework; + +import org.apache.commons.io.FileUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.TemporaryFolder; +import org.junit.rules.TestName; + +import java.io.File; +import java.nio.file.Path; + +/** Base class for instable tests. */ +public class AbstractInstableE2ETest { + + @Rule public TemporaryFolder tmpFolder = new TemporaryFolder(); + + @Rule public TestName name = new TestName(); + + protected ZooKeeperTestEnvironment zkEnv; + + protected CuratorFramework zkClient; + + protected LocalShuffleCluster shuffleCluster; + + protected FlinkLocalCluster flinkCluster; + + @Before + public void setup() throws Exception { + File logDir = + new File( + System.getProperty("buildDirectory") + + "/" + + getClass().getSimpleName() + + "-" + + name.getMethodName()); + if (logDir.exists()) { + FileUtils.deleteDirectory(logDir); + } + + zkEnv = new ZooKeeperTestEnvironment(1); + String zkConnect = zkEnv.getConnect(); + String logPath = logDir.getAbsolutePath(); + shuffleCluster = + createLocalShuffleCluster(logPath, zkConnect, tmpFolder.newFolder().toPath()); + shuffleCluster.start(); + + Exception exception = null; + for (int i = 0; i < 3; ++i) { + try { + flinkCluster = createFlinkCluster(logDir.getAbsolutePath(), tmpFolder, zkConnect); + flinkCluster.start(); + break; + } catch (Exception throwable) { + exception = exception == null ? throwable : exception; + flinkCluster.shutdown(); + flinkCluster = null; + } + } + + if (flinkCluster == null) { + throw new Exception(exception); + } + zkClient = flinkCluster.getZKClient(); + } + + @After + public void cleanup() throws Exception { + flinkCluster.shutdown(); + + shuffleCluster.shutdown(); + + zkEnv.shutdown(); + } + + protected LocalShuffleCluster createLocalShuffleCluster( + String logPath, String zkConnect, Path dataPath) { + return new LocalShuffleCluster(logPath, 2, zkConnect, dataPath, new Configuration()); + } + + protected FlinkLocalCluster createFlinkCluster( + String logPath, TemporaryFolder tmpFolder, String zkConnect) throws Exception { + return new FlinkLocalCluster( + logPath, + 2, + tmpFolder, + zkConnect, + new org.apache.flink.configuration.Configuration()); + } +} diff --git a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/ArbitraryKillingE2ETest.java b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/ArbitraryKillingE2ETest.java new file mode 100644 index 00000000..dd8ab856 --- /dev/null +++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/ArbitraryKillingE2ETest.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.e2e;
+
+import com.alibaba.flink.shuffle.common.config.Configuration;
+import com.alibaba.flink.shuffle.e2e.flinkcluster.FlinkLocalCluster;
+import com.alibaba.flink.shuffle.e2e.shufflecluster.LocalShuffleCluster;
+
+import org.apache.flink.configuration.TaskManagerOptions;
+
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.nio.file.Path;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Random;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+
+/** Tests arbitrarily killing task managers, shuffle workers and the shuffle manager. */
+@RunWith(Parameterized.class)
+public class ArbitraryKillingE2ETest extends AbstractInstableE2ETest {
+
+    private static final int NUM_ROUNDS = 5;
+
+    public ArbitraryKillingE2ETest(int ignore) {}
+
+    @Parameterized.Parameters
+    public static Collection<Object[]> data() {
+        Object[][] params = new Object[NUM_ROUNDS][1];
+        params[0] = new Object[1];
+        for (int i = 0; i < params.length; i++) {
+            params[i][0] = i;
+        }
+        return Arrays.asList(params);
+    }
+
+    @Override
+    protected FlinkLocalCluster createFlinkCluster(
+            String logPath, TemporaryFolder tmpFolder, String zkConnect) throws Exception {
+        org.apache.flink.configuration.Configuration conf =
+                new org.apache.flink.configuration.Configuration();
+        Random random = new Random();
+        conf.set(TaskManagerOptions.NUM_TASK_SLOTS, 1 + random.nextInt(2));
+        return new FlinkLocalCluster(logPath, 4, tmpFolder, zkConnect, conf);
+    }
+
+    @Override
+    protected LocalShuffleCluster createLocalShuffleCluster(
+            String logPath, String zkConnect, Path dataPath) {
+        return new LocalShuffleCluster(logPath, 4, zkConnect, dataPath, new Configuration());
+    }
+
+    @Test
+    public void test() {
+        final ExecutorService executor = Executors.newSingleThreadExecutor();
+        final Future<?> future = executor.submit(this::testRoutine);
+        try {
+            future.get(360, TimeUnit.SECONDS);
+        } catch (Exception e) {
+            flinkCluster.printProcessLog();
+            shuffleCluster.printProcessLog();
+            throw new AssertionError("Test failure.", e);
+        } finally {
+            executor.shutdown();
+        }
+    }
+
+    private void testRoutine() {
+        AtomicReference<StringBuilder> killingScript = new AtomicReference<>(new StringBuilder());
+        AtomicReference<Throwable> cause = new AtomicReference<>(null);
+        try {
+            AtomicBoolean finished = new AtomicBoolean(false);
+            JobForShuffleTesting job =
+                    new JobForShuffleTesting(
+                            flinkCluster, 4, JobForShuffleTesting.DataScale.NORMAL);
+            startArbitraryKiller(finished, cause, killingScript);
+            job.run();
+            finished.set(true);
+            if 
(cause.get() != null) { + throw cause.get(); + } + } catch (Throwable t) { + flinkCluster.printProcessLog(); + shuffleCluster.printProcessLog(); + throw new AssertionError("Killing script: " + killingScript.get().toString(), t); + } + } + + private void startArbitraryKiller( + AtomicBoolean finished, + AtomicReference<Throwable> cause, + AtomicReference<StringBuilder> killingScript) { + Thread t = + new Thread( + () -> { + try { + Random random = new Random(); + int killTiming = random.nextInt(10); + killingScript + .get() + .append("Sleep for ") + .append(killTiming) + .append(" seconds.\n"); + Thread.sleep(killTiming * 1000); + if (finished.get()) { + return; + } + int x = random.nextInt(3); + if (x == 0) { + // kill a task manager + int tmIdx = random.nextInt(4); + killingScript.get().append("Kill task manager ").append(tmIdx); + flinkCluster.killTaskManager(tmIdx); + } else if (x == 1) { + // kill a shuffle worker + int swIdx = random.nextInt(3); + killingScript + .get() + .append("Kill shuffle worker ") + .append(swIdx); + shuffleCluster.killShuffleWorkerForcibly(swIdx); + } else { + killingScript.get().append("Kill shuffle manager"); + // TODO kill a shuffle manager + } + } catch (Exception e) { + cause.set(e); + } + }); + t.setDaemon(true); + t.start(); + } +} diff --git a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/InstableFlinkJobE2ETest.java b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/InstableFlinkJobE2ETest.java new file mode 100644 index 00000000..93fff5c9 --- /dev/null +++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/InstableFlinkJobE2ETest.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.e2e; + +import com.alibaba.flink.shuffle.common.functions.ConsumerWithException; +import com.alibaba.flink.shuffle.common.functions.RunnableWithException; +import com.alibaba.flink.shuffle.e2e.JobForShuffleTesting.TaskStat; + +import org.junit.Test; + +import java.io.File; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.runQuietly; +import static com.alibaba.flink.shuffle.e2e.JobForShuffleTesting.STAGE0_NAME; +import static com.alibaba.flink.shuffle.e2e.JobForShuffleTesting.STAGE1_NAME; +import static com.alibaba.flink.shuffle.e2e.JobForShuffleTesting.STAGE2_NAME; + +/** Tests for Flink jobs running under injected instability. 
*/ +public class InstableFlinkJobE2ETest extends AbstractInstableE2ETest { + + @Test + public void testBasicRoutine() { + try { + JobForShuffleTesting job = new JobForShuffleTesting(flinkCluster); + job.run(); + shuffleCluster.checkResourceReleased(); + flinkCluster.checkResourceReleased(); + } catch (Throwable t) { + flinkCluster.printProcessLog(); + shuffleCluster.printProcessLog(); + throw new AssertionError(t); + } + } + + @Test + public void testCancelJob() { + try { + JobForShuffleTesting job = new JobForShuffleTesting(flinkCluster); + RunnableWithException procedure = + () -> { + ConsumerWithException runnable = ignore -> flinkCluster.cancelJobs(); + job.planOperation(STAGE0_NAME, 0, 0, TaskStat.RUNNING, runnable); + job.run(); + }; + runQuietly(procedure::run); + job.checkNoResult(); + shuffleCluster.checkResourceReleased(); + flinkCluster.checkResourceReleased(); + } catch (Throwable t) { + flinkCluster.printProcessLog(); + shuffleCluster.printProcessLog(); + throw new AssertionError(t); + } + } + + @Test + public void testShuffleWriteTaskFailureAndRecovery() { + try { + JobForShuffleTesting job = new JobForShuffleTesting(flinkCluster); + job.planFailingARunningTask(STAGE0_NAME, 0, 0); + job.run(); + shuffleCluster.checkResourceReleased(); + flinkCluster.checkResourceReleased(); + } catch (Throwable t) { + flinkCluster.printProcessLog(); + shuffleCluster.printProcessLog(); + throw new AssertionError(t); + } + } + + @Test + public void testShuffleReadWriteTaskFailureAndRecovery() { + try { + JobForShuffleTesting job = new JobForShuffleTesting(flinkCluster); + job.planFailingARunningTask(STAGE1_NAME, 0, 0); + job.run(); + shuffleCluster.checkResourceReleased(); + flinkCluster.checkResourceReleased(); + } catch (Throwable t) { + flinkCluster.printProcessLog(); + shuffleCluster.printProcessLog(); + throw new AssertionError(t); + } + } + + @Test + public void testShuffleReadTaskFailureAndRecovery() { + try { + JobForShuffleTesting job = new JobForShuffleTesting(flinkCluster); + job.planFailingARunningTask(STAGE2_NAME, 0, 0); + job.run(); + shuffleCluster.checkResourceReleased(); + flinkCluster.checkResourceReleased(); + } catch (Throwable t) { + flinkCluster.printProcessLog(); + shuffleCluster.printProcessLog(); + throw new AssertionError(t); + } + } + + @Test + public void testShuffleWriteTaskIOFailure() { + try { + JobForShuffleTesting job = new JobForShuffleTesting(flinkCluster); + ConsumerWithException r = + ignore -> { + for (int i = 0; i < shuffleCluster.shuffleWorkers.length; ++i) { + Stream paths = + Files.list( + new File(shuffleCluster.getDataDirForWorker(i)) + .toPath()); + for (Path p : paths.collect(Collectors.toList())) { + if (!p.toString().endsWith("_meta")) { + Files.deleteIfExists(p); + } + } + } + }; + job.planOperation(STAGE0_NAME, 0, 0, TaskStat.RUNNING, r); + job.run(); + shuffleCluster.checkResourceReleased(); + flinkCluster.checkResourceReleased(); + } catch (Throwable t) { + flinkCluster.printProcessLog(); + shuffleCluster.printProcessLog(); + throw new AssertionError(t); + } + } + + @Test + public void testShuffleReadWriteTaskIOFailure() { + try { + JobForShuffleTesting job = new JobForShuffleTesting(flinkCluster); + ConsumerWithException r = + ignore -> { + for (int i = 0; i < shuffleCluster.shuffleWorkers.length; ++i) { + Stream paths = + Files.list( + new File(shuffleCluster.getDataDirForWorker(i)) + .toPath()); + for (Path p : paths.collect(Collectors.toList())) { + if (!p.toString().endsWith("_meta") + && !p.toString().endsWith("partial")) { + 
Files.deleteIfExists(p); + } + } + } + }; + job.planOperation(STAGE1_NAME, 0, 0, TaskStat.RUNNING, r); + job.run(); + + // Longest failover is as below: + // 1. Failure when reading&writing, because all remote files removed; + // 2. Retry but fetch failed due to PartitionException on (STAGE0_NAME, 0); + // 3. Retry but fetch failed due to PartitionException on (STAGE0_NAME, 1); + // 4. Retry and succeed; + // job.checkTaskInfoOnZK(STAGE1_NAME, 0, 4, false); + shuffleCluster.checkResourceReleased(); + flinkCluster.checkResourceReleased(); + } catch (Throwable t) { + flinkCluster.printProcessLog(); + shuffleCluster.printProcessLog(); + throw new AssertionError(t); + } + } + + @Test + public void testShuffleReadTaskIOFailure() { + try { + JobForShuffleTesting job = new JobForShuffleTesting(flinkCluster); + ConsumerWithException r = + ignore -> { + for (int i = 0; i < shuffleCluster.shuffleWorkers.length; ++i) { + Stream paths = + Files.list(Paths.get(shuffleCluster.getDataDirForWorker(i))); + for (Path p : paths.collect(Collectors.toList())) { + if (!p.toString().endsWith("_meta")) { + Files.deleteIfExists(p); + } + } + } + }; + job.planOperation(STAGE2_NAME, 0, 0, TaskStat.RUNNING, r); + job.run(); + shuffleCluster.checkResourceReleased(); + flinkCluster.checkResourceReleased(); + } catch (Throwable t) { + flinkCluster.printProcessLog(); + shuffleCluster.printProcessLog(); + throw new AssertionError(t); + } + } +} diff --git a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/JobForShuffleTesting.java b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/JobForShuffleTesting.java new file mode 100644 index 00000000..73ca8525 --- /dev/null +++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/JobForShuffleTesting.java @@ -0,0 +1,640 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.e2e; + +import com.alibaba.flink.shuffle.common.functions.ConsumerWithException; +import com.alibaba.flink.shuffle.e2e.flinkcluster.FlinkLocalCluster; +import com.alibaba.flink.shuffle.e2e.utils.LogErrorHandler; +import com.alibaba.flink.shuffle.e2e.zookeeper.ZooKeeperTestUtils; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.ExecutionMode; +import org.apache.flink.api.common.InputDependencyConstraint; +import org.apache.flink.api.common.eventtime.NoWatermarksGenerator; +import org.apache.flink.api.common.eventtime.WatermarkGenerator; +import org.apache.flink.api.common.eventtime.WatermarkGeneratorSupplier; +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.connector.source.lib.NumberSequenceSource; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.client.deployment.StandaloneClusterId; +import org.apache.flink.client.program.rest.RestClusterClient; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.fs.FileSystem; +import org.apache.flink.runtime.client.JobStatusMessage; +import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.util.ZooKeeperUtils; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.graph.GlobalStreamExchangeMode; +import org.apache.flink.streaming.api.graph.StreamConfig; +import org.apache.flink.streaming.api.graph.StreamGraph; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.ChainingStrategy; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.Output; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.StreamTask; + +import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFramework; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.recipes.cache.ChildData; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.recipes.cache.NodeCache; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.recipes.cache.NodeCacheListener; +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled; +import org.apache.flink.shaded.zookeeper3.org.apache.zookeeper.CreateMode; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileReader; +import java.nio.ByteBuffer; +import java.nio.file.Files; +import java.util.Collection; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +import static com.alibaba.flink.shuffle.common.utils.ProcessUtils.getProcessID; +import static org.apache.flink.streaming.api.environment.StreamExecutionEnvironment.createRemoteEnvironment; +import static org.junit.Assert.assertNotNull; +import static 
org.junit.Assert.assertNull; + +/** Flink job for shuffle testing. */ +public class JobForShuffleTesting { + + private static final Logger LOG = LoggerFactory.getLogger(JobForShuffleTesting.class); + + private static final int LARGE_SCALE_DATA_RECORDS = 30000000; + + private static final int NORMAL_SCALE_DATA_RECORDS = 3000000; + + private final int numRecords; + + private static final int EXPECT_CHECKSUM_WHEN_LARGE = 55032; + + private static final int EXPECT_CHECKSUM_WHEN_NORMAL = 50418; + + static final int PARALLELISM = 2; + + static final String STAGE0_NAME = "shuffleWrite"; + + static final String STAGE1_NAME = "shuffleRead&Write"; + + static final String STAGE2_NAME = "shuffleRead"; + + private final FlinkLocalCluster cluster; + + private final int parallelism; + + private final Configuration config; + + private final String outputPath; + + private final String zkConnect; + + private final String zkPath; + + private Consumer executionConfigModifier; + + public JobForShuffleTesting(FlinkLocalCluster cluster) throws Exception { + this(cluster, PARALLELISM, DataScale.LARGE); + } + + public JobForShuffleTesting(FlinkLocalCluster cluster, DataScale scale) throws Exception { + this(cluster, PARALLELISM, scale); + } + + public JobForShuffleTesting(FlinkLocalCluster cluster, int parallelism, DataScale scale) + throws Exception { + this.cluster = cluster; + this.parallelism = parallelism; + this.config = cluster.getConfig(); + this.outputPath = cluster.temporaryFolder.newFolder().getPath(); + this.zkConnect = cluster.getZKConnect(); + this.zkPath = cluster.getZKPath(); + switch (scale) { + case LARGE: + numRecords = LARGE_SCALE_DATA_RECORDS; + break; + case NORMAL: + default: + numRecords = NORMAL_SCALE_DATA_RECORDS; + } + } + + static class NoWatermark implements WatermarkStrategy { + private static final long serialVersionUID = 1L; + + @Override + public WatermarkGenerator createWatermarkGenerator( + WatermarkGeneratorSupplier.Context context) { + return new NoWatermarksGenerator<>(); + } + } + + public void run() throws Exception { + LOG.info("Starting job to process {} records and write to {}.", numRecords, outputPath); + + StreamExecutionEnvironment env = createRemoteEnvironment("localhost", 1337, config); + env.setParallelism(parallelism); + + TupleTypeInfo> typeInfo = + new TupleTypeInfo<>(BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.LONG_TYPE_INFO); + + env.fromSource( + new NumberSequenceSource(1, numRecords), + new NoWatermark(), + "Sequence Source") + .map( + new MapFunction>() { + @Override + public Tuple2 map(Long aLong) throws Exception { + return Tuple2.of(aLong, aLong + 1); + } + }) + .transform( + STAGE0_NAME, + typeInfo, + new MapOpWithStub(zkConnect, zkPath, STAGE0_NAME, parallelism, numRecords)) + .name(STAGE0_NAME) + .keyBy(0) + .transform( + STAGE1_NAME, + typeInfo, + new MapOpWithStub(zkConnect, zkPath, STAGE1_NAME, parallelism, numRecords)) + .name(STAGE1_NAME) + .keyBy(1) + .transform( + STAGE2_NAME, + typeInfo, + new MapOpWithStub(zkConnect, zkPath, STAGE2_NAME, parallelism, numRecords)) + .name(STAGE2_NAME) + .writeAsText(outputPath, FileSystem.WriteMode.OVERWRITE); + + ExecutionConfig executionConfig = env.getConfig(); + executionConfig.setDefaultInputDependencyConstraint(InputDependencyConstraint.ALL); + executionConfig.setExecutionMode(ExecutionMode.BATCH); + + if (executionConfigModifier != null) { + executionConfigModifier.accept(executionConfig); + } + + StreamGraph streamGraph = env.getStreamGraph(); + 
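+ // All edges are set to BLOCKING below, so every exchange is fully materialized through the remote shuffle service instead of being pipelined.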
streamGraph.setGlobalStreamExchangeMode(GlobalStreamExchangeMode.ALL_EDGES_BLOCKING); + streamGraph.setJobName("Streaming WordCount"); + env.execute(streamGraph); + + verifyResult(); + } + + public void setExecutionConfigModifier(Consumer executionConfigModifier) { + this.executionConfigModifier = executionConfigModifier; + } + + public void planFailingARunningTask(String stageName, int taskID, int attemptID) + throws Exception { + String path = pathOfTaskCmd(stageName, taskID, attemptID, TaskStat.RUNNING); + CuratorFramework zk = cluster.getZKClient(); + zk.create().forPath(path); + zk.setData().forPath(path, new byte[] {(byte) TaskCmd.SELF_DESTROY.ordinal()}); + } + + public void planOperation( + String stageName, + int taskID, + int attemptID, + TaskStat taskStat, + ConsumerWithException listener) + throws Exception { + planAsyncOperation( + stageName, + taskID, + attemptID, + taskStat, + (processIdAndTaskStat, nodeCache) -> { + if (processIdAndTaskStat.getStat() == taskStat) { + listener.accept(processIdAndTaskStat); + } + return CompletableFuture.completedFuture(null); + }); + } + + public void planAsyncOperation( + String stageName, + int taskID, + int attemptID, + TaskStat taskStat, + AsyncPlanListener listener) + throws Exception { + String pathCmd = pathOfTaskCmd(stageName, taskID, attemptID, taskStat); + CuratorFramework zk = cluster.getZKClient(); + zk.create().forPath(pathCmd); + zk.setData().forPath(pathCmd, new byte[] {(byte) TaskCmd.WAIT_KEEP_GOING.ordinal()}); + + String pathInfo = pathOfTaskInfo(stageName, taskID, attemptID); + NodeCache nodeCache = new NodeCache(zk, pathInfo, false); + NodeCacheListener l = + () -> { + ChildData childData = nodeCache.getCurrentData(); + if (nodeCache.getCurrentData() == null || childData.getData() == null) { + return; + } + ProcessIDAndTaskStat processIDAndTaskStat = + ProcessIDAndTaskStat.fromBytes(childData.getData()); + + try { + if (processIDAndTaskStat.getStat() == taskStat) { + CompletableFuture resultFuture = + listener.onPlanedTimeReached(processIDAndTaskStat, nodeCache); + resultFuture.whenComplete( + (ret, throwable) -> { + if (throwable != null) { + cancelJob(throwable); + } else { + try { + zk.setData() + .forPath( + pathCmd, + new byte[] { + (byte) + TaskCmd.KEEP_GOING + .ordinal() + }); + } catch (Exception e) { + cancelJob(e); + } + } + }); + } + } catch (Throwable throwable) { + // Exception caught, let's fail the job + LOG.error("The check failed and would stop the flink cluster", throwable); + cancelJob(throwable); + } + }; + + nodeCache.start(); + nodeCache.getListenable().addListener(l); + } + + private void cancelJob(Throwable throwable) { + try { + // If check failed, we would then cancel all the jobs to make the test + // fail... 
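+ // Cancellation goes through the Flink REST API, so the failed check surfaces as cancelled jobs rather than a hung test.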
+ RestClusterClient restClusterClient = + new RestClusterClient<>(config, StandaloneClusterId.getInstance()); + Collection jobStatusMessages = restClusterClient.listJobs().get(); + for (JobStatusMessage jobStatusMessage : jobStatusMessages) { + LOG.error("Canceling " + jobStatusMessage.getJobId() + " due to ", throwable); + restClusterClient.cancel(jobStatusMessage.getJobId()).get(); + } + + restClusterClient.close(); + } catch (Exception e) { + LOG.info("Failed to cancel job.", e); + } + } + + public void checkNoResult() throws Exception { + if (Files.list(new File(outputPath).toPath()).findAny().isPresent()) { + throw new AssertionError("Result exists."); + } + } + + private static final class MapOpWithStub extends AbstractStreamOperator> + implements OneInputStreamOperator, Tuple2> { + + private static final long serialVersionUID = -4276646161067243517L; + + private final String zkConnect; + + private final String zkPath; + + private final String name; + + private final int parallelism; + + private Environment env; + + private int numRecordsProcessed; + + private CuratorFramework zkClient; + + private int taskIdx; + + private int attempt; + + private final int numRecords; + + public MapOpWithStub( + String zkConnect, String zkPath, String name, int parallelism, int numRecords) { + this.zkConnect = zkConnect; + this.zkPath = zkPath; + this.name = name; + this.parallelism = parallelism; + this.numRecords = numRecords; + this.setChainingStrategy(ChainingStrategy.ALWAYS); + } + + @Override + public void setup( + StreamTask containingTask, + StreamConfig config, + Output>> output) { + super.setup(containingTask, config, output); + + this.env = containingTask.getEnvironment(); + } + + @Override + public void open() throws Exception { + super.open(); + Configuration conf = + ZooKeeperTestUtils.createZooKeeperHAConfigForFlink(zkConnect, zkPath); + zkClient = ZooKeeperUtils.startCuratorFramework(conf, LogErrorHandler.INSTANCE); + taskIdx = getRuntimeContext().getIndexOfThisSubtask(); + attempt = getRuntimeContext().getAttemptNumber(); + final ResultPartitionID resultPartitionID; + if (env.getAllWriters().length > 0) { + resultPartitionID = env.getWriter(0).getPartitionId(); + } else { + resultPartitionID = null; + } + ProcessIDAndTaskStat data = + new ProcessIDAndTaskStat(resultPartitionID, getProcessID(), TaskStat.OPENED); + zkClient.create() + .withMode(CreateMode.PERSISTENT) + .forPath(pathOfTaskInfo(name, taskIdx, attempt), data.toBytes()); + } + + @Override + public void processElement(StreamRecord> streamRecord) { + if (numRecordsProcessed == numRecords / parallelism * 0.5) { + try { + final ResultPartitionID resultPartitionID; + if (env.getAllWriters().length > 0) { + resultPartitionID = env.getWriter(0).getPartitionId(); + } else { + resultPartitionID = null; + } + ProcessIDAndTaskStat data = + new ProcessIDAndTaskStat( + resultPartitionID, getProcessID(), TaskStat.RUNNING); + zkClient.setData() + .forPath(pathOfTaskInfo(name, taskIdx, attempt), data.toBytes()); + TaskCmd cmd = readCmd(TaskStat.RUNNING); + if (cmd != null) { + execCmd(cmd, TaskStat.RUNNING); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + numRecordsProcessed++; + + int mod = 99999; + Tuple2 longs = streamRecord.getValue(); + output.collect(streamRecord.replace(Tuple2.of(longs.f0 * 2 % mod, longs.f1 * 3 % mod))); + } + + @Override + public void close() throws Exception { + super.close(); + + final ResultPartitionID resultPartitionID; + if (env.getAllWriters().length > 0) { + resultPartitionID = 
env.getWriter(0).getPartitionId(); + } else { + resultPartitionID = null; + } + ProcessIDAndTaskStat data = + new ProcessIDAndTaskStat(resultPartitionID, getProcessID(), TaskStat.CLOSED); + zkClient.setData().forPath(pathOfTaskInfo(name, taskIdx, attempt), data.toBytes()); + zkClient.close(); + } + + private TaskCmd readCmd(TaskStat taskStat) throws Exception { + TaskCmd cmd = null; + String cmdPath = pathOfTaskCmd(name, taskIdx, attempt, taskStat); + if (zkClient.checkExists().forPath(cmdPath) != null) { + byte b = zkClient.getData().forPath(cmdPath)[0]; + cmd = TaskCmd.values()[b]; + } + return cmd; + } + + private void execCmd(TaskCmd cmd, TaskStat taskStat) throws Exception { + switch (cmd) { + case KEEP_GOING: + break; + case WAIT_KEEP_GOING: + while (readCmd(taskStat) == TaskCmd.WAIT_KEEP_GOING) { + Thread.sleep(1000); + } + break; + case SELF_DESTROY: + try { + zkClient.delete().forPath(pathOfTaskCmd(name, taskIdx, attempt, taskStat)); + } catch (Exception e) { + throw new RuntimeException(e); + } + throw new RuntimeException("Task self destroy."); + } + } + } + + static String pathOfTaskInfo(String name, int taskIdx, int attempt) { + return "/" + name + "-task" + taskIdx + "-attempt" + attempt + ".info"; + } + + private static String pathOfTaskCmd(String name, int taskIdx, int attempt, TaskStat taskStat) { + return "/" + name + "-task" + taskIdx + "-attempt" + attempt + "-" + taskStat + ".cmd"; + } + + private enum TaskCmd { + KEEP_GOING, + WAIT_KEEP_GOING, + SELF_DESTROY; + } + + /** Task status when running. */ + public static class ProcessIDAndTaskStat { + + /** Only the single result partition id is recorded. */ + ResultPartitionID resultPartitionID; + + int processID; + + TaskStat stat; + + public ProcessIDAndTaskStat( + ResultPartitionID resultPartitionID, int processID, TaskStat stat) { + this.resultPartitionID = resultPartitionID; + this.processID = processID; + this.stat = stat; + } + + public int getProcessID() { + return processID; + } + + public TaskStat getStat() { + return stat; + } + + static ProcessIDAndTaskStat fromBytes(byte[] bytes) { + + ByteBuffer wrapped = ByteBuffer.wrap(bytes); + int resultPartitionIDBytesSize = wrapped.getInt(); + final ResultPartitionID resultPartitionID; + if (resultPartitionIDBytesSize > 0) { + byte[] resultPartitionIDBytes = new byte[resultPartitionIDBytesSize]; + wrapped.get(resultPartitionIDBytes); + + ByteBuf resultPartitionIDByteBuf = Unpooled.wrappedBuffer(resultPartitionIDBytes); + IntermediateResultPartitionID intermediateResultPartitionID = + IntermediateResultPartitionID.fromByteBuf(resultPartitionIDByteBuf); + ExecutionAttemptID executionAttemptID = + ExecutionAttemptID.fromByteBuf(resultPartitionIDByteBuf); + resultPartitionID = + new ResultPartitionID(intermediateResultPartitionID, executionAttemptID); + } else { + resultPartitionID = null; + } + int processID = wrapped.getInt(); + TaskStat stat = TaskStat.values()[wrapped.get()]; + return new ProcessIDAndTaskStat(resultPartitionID, processID, stat); + } + + byte[] toBytes() { + ByteBuf resultPartitionIDsByteBuf = Unpooled.buffer(); + if (resultPartitionID != null) { + resultPartitionID.getPartitionId().writeTo(resultPartitionIDsByteBuf); + resultPartitionID.getProducerId().writeTo(resultPartitionIDsByteBuf); + } + + byte[] resultPartitionIDsBytes = new byte[resultPartitionIDsByteBuf.readableBytes()]; + resultPartitionIDsByteBuf.readBytes(resultPartitionIDsBytes); + + ByteBuffer buf = ByteBuffer.allocate(4 + resultPartitionIDsBytes.length + 5); + 
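+ // Serialized layout: an int length of the partition-ID bytes, the bytes themselves, an int process ID, and one byte for the TaskStat ordinal (hence the 4 + length + 5 allocation).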
buf.putInt(resultPartitionIDsBytes.length); + buf.put(resultPartitionIDsBytes); + buf.putInt(processID); + buf.put((byte) stat.ordinal()); + return buf.array(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ProcessIDAndTaskStat that = (ProcessIDAndTaskStat) o; + return processID == that.processID + && Objects.equals(resultPartitionID, that.resultPartitionID) + && stat == that.stat; + } + + @Override + public int hashCode() { + return Objects.hash(resultPartitionID, processID, stat); + } + } + + /** Task state on ZK. */ + public enum TaskStat { + OPENED, + RUNNING, + CLOSED + } + + public int getParallelism() { + return parallelism; + } + + public void checkAttemptsNum(String stageName, int taskID, int expectNum) throws Exception { + for (int i = 0; i < expectNum; i++) { + String taskInfoPath = pathOfTaskInfo(stageName, taskID, i); + assertNotNull(cluster.getZKClient().checkExists().forPath(taskInfoPath)); + } + String taskInfoPath = pathOfTaskInfo(stageName, taskID, expectNum); + assertNull(cluster.getZKClient().checkExists().forPath(taskInfoPath)); + } + + private void verifyResult() throws Exception { + int checksum = 0; + int numRecords = 0; + for (int i = 1; i <= parallelism; i++) { + try (BufferedReader br = new BufferedReader(new FileReader(outputPath + "/" + i))) { + String line = null; + while ((line = br.readLine()) != null) { + String strippedBrackets = line.substring(1, line.length() - 1); + String[] numbers = strippedBrackets.split(","); + checksum += Integer.parseInt(numbers[0]); + checksum %= 77777; + checksum += Integer.parseInt(numbers[1]); + checksum %= 77777; + numRecords += 1; + } + } + } + if (this.numRecords == NORMAL_SCALE_DATA_RECORDS + && checksum != EXPECT_CHECKSUM_WHEN_NORMAL) { + throw new IllegalStateException( + String.format( + "Expect checksum %d, but found %d.", + EXPECT_CHECKSUM_WHEN_NORMAL, checksum)); + } + + if (this.numRecords == LARGE_SCALE_DATA_RECORDS && checksum != EXPECT_CHECKSUM_WHEN_LARGE) { + throw new IllegalStateException( + String.format( + "Expect checksum %d, but found %d.", + EXPECT_CHECKSUM_WHEN_LARGE, checksum)); + } + + if (numRecords != this.numRecords) { + throw new IllegalStateException( + String.format( + "Expect numRecords %d, but found %d.", this.numRecords, numRecords)); + } + } + + /** Asynchronous Listener of the planned time has reached. */ + public interface AsyncPlanListener { + + CompletableFuture onPlanedTimeReached( + ProcessIDAndTaskStat taskStat, NodeCache nodeCache) throws Exception; + } + + enum DataScale { + NORMAL, + LARGE + } +} diff --git a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/JobForShuffleTestingE2ETest.java b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/JobForShuffleTestingE2ETest.java new file mode 100644 index 00000000..a553bfce --- /dev/null +++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/JobForShuffleTestingE2ETest.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.e2e; + +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** Tests the {@link JobForShuffleTesting}. */ +public class JobForShuffleTestingE2ETest { + + @Test + public void testSerializeProcessIDAndTaskStat() { + JobForShuffleTesting.ProcessIDAndTaskStat processIDAndTaskStat = + new JobForShuffleTesting.ProcessIDAndTaskStat( + new ResultPartitionID(), 10, JobForShuffleTesting.TaskStat.RUNNING); + + byte[] bytes = processIDAndTaskStat.toBytes(); + JobForShuffleTesting.ProcessIDAndTaskStat deserialized = + JobForShuffleTesting.ProcessIDAndTaskStat.fromBytes(bytes); + assertEquals(processIDAndTaskStat, deserialized); + } +} diff --git a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/ShuffleManagerHAE2ETest.java b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/ShuffleManagerHAE2ETest.java new file mode 100644 index 00000000..e675f4b4 --- /dev/null +++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/ShuffleManagerHAE2ETest.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.e2e; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper.ZooKeeperHaServices; +import com.alibaba.flink.shuffle.core.config.ClusterOptions; +import com.alibaba.flink.shuffle.core.config.HeartbeatOptions; +import com.alibaba.flink.shuffle.e2e.flinkcluster.FlinkLocalCluster; +import com.alibaba.flink.shuffle.e2e.shufflecluster.LocalShuffleCluster; +import com.alibaba.flink.shuffle.e2e.zookeeper.ZooKeeperTestEnvironment; +import com.alibaba.flink.shuffle.e2e.zookeeper.ZooKeeperTestUtils; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.configuration.PipelineOptionsInternal; +import org.apache.flink.configuration.WebOptions; + +import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFramework; + +import org.apache.commons.io.FileUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.rules.TestName; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.time.Duration; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; + +/** Shuffle manager ha test. */ +public class ShuffleManagerHAE2ETest { + + private static final Logger LOG = LoggerFactory.getLogger(ShuffleManagerHAE2ETest.class); + + @ClassRule public static TemporaryFolder temporaryFolder = new TemporaryFolder(); + + @Rule public TestName name = new TestName(); + + private File logDir; + + private ZooKeeperTestEnvironment zkCluster; + + private LocalShuffleCluster shuffleCluster; + + @Before + public void setup() throws Exception { + zkCluster = new ZooKeeperTestEnvironment(1); + + logDir = + new File( + System.getProperty("buildDirectory") + + "/" + + getClass().getSimpleName() + + "-" + + name.getMethodName()); + if (logDir.exists()) { + FileUtils.deleteDirectory(logDir); + } + + Configuration configuration = new Configuration(); + configuration.setDuration( + HeartbeatOptions.HEARTBEAT_WORKER_INTERVAL, Duration.ofMillis(2000L)); + configuration.setDuration( + HeartbeatOptions.HEARTBEAT_WORKER_TIMEOUT, Duration.ofMillis(10000L)); + configuration.setDuration( + HeartbeatOptions.HEARTBEAT_JOB_INTERVAL, Duration.ofMillis(5000L)); + configuration.setDuration( + HeartbeatOptions.HEARTBEAT_JOB_TIMEOUT, Duration.ofMillis(30000L)); + shuffleCluster = + new LocalShuffleCluster( + logDir.getAbsolutePath(), + 2, + zkCluster.getConnect(), + temporaryFolder.newFolder().toPath(), + configuration); + shuffleCluster.start(); + + LOG.info("========== Test started =========="); + } + + @After + public void teardown() throws Exception { + LOG.info("========== Test end =========="); + + try { + shuffleCluster.shutdown(); + } catch (Exception e) { + LOG.info("Failed to stop shuffle cluster", e); + } + + try { + zkCluster.shutdown(); + } catch (Exception e) { + LOG.info("Failed to stop zk cluster", e); + } + } + + @Test + public void testRevokeAndGrantLeadership() throws Exception { + JobID jobId = new JobID(); + + org.apache.flink.configuration.Configuration flinkConfiguration = + new org.apache.flink.configuration.Configuration(); + flinkConfiguration.set(WebOptions.TIMEOUT, 15000L); + flinkConfiguration.set(PipelineOptionsInternal.PIPELINE_FIXED_JOB_ID, jobId.toHexString()); + + FlinkLocalClusterResource resource = new FlinkLocalClusterResource(flinkConfiguration); + 
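+ // The planned operation below deletes the ShuffleManager leader node on ZooKeeper mid-job to revoke leadership, then kills a shuffle worker to exercise recovery.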
ScheduledExecutorService scheduledExecutor = Executors.newSingleThreadScheduledExecutor(); + try { + FlinkLocalCluster flinkCluster = resource.getFlinkLocalCluster(); + JobForShuffleTesting job = new JobForShuffleTesting(flinkCluster); + job.planOperation( + JobForShuffleTesting.STAGE0_NAME, + 0, + 0, + JobForShuffleTesting.TaskStat.RUNNING, + (processIdAndStat) -> { + CuratorFramework zkClient = + ZooKeeperTestUtils.createZKClientForRemoteShuffle( + ZooKeeperTestUtils.createZooKeeperHAConfig( + zkCluster.getConnect())); + // Delete the leader node so the ShuffleManager loses its leadership. + zkClient.delete() + .forPath( + ClusterOptions.REMOTE_SHUFFLE_CLUSTER_ID.defaultValue() + + ZooKeeperHaServices + .SHUFFLE_MANAGER_LEADER_RETRIEVAL_PATH); + + // Take one shuffle worker down after 5 seconds. + Thread.sleep(5000); + shuffleCluster.killShuffleWorker(0); + + // Wait until its heartbeat times out. + Thread.sleep(10000); + }); + + job.run(); + checkFlinkResourceReleased(flinkCluster); + checkShuffleResourceReleased(); + } catch (Throwable throwable) { + resource.getFlinkLocalCluster().printProcessLog(); + shuffleCluster.printProcessLog(); + throw new AssertionError("Test failure.", throwable); + } finally { + scheduledExecutor.shutdownNow(); + resource.close(); + } + } + + private void checkShuffleResourceReleased() throws Exception { + shuffleCluster.checkStorageResourceReleased(); + shuffleCluster.checkNetworkReleased(); + shuffleCluster.checkBuffersReleased(); + } + + private void checkFlinkResourceReleased(FlinkLocalCluster flinkLocalCluster) throws Exception { + flinkLocalCluster.checkResourceReleased(); + } + + private class FlinkLocalClusterResource implements AutoCloseable { + + private final FlinkLocalCluster flinkLocalCluster; + + public FlinkLocalClusterResource( + org.apache.flink.configuration.Configuration flinkConfiguration) throws Exception { + flinkLocalCluster = + new FlinkLocalCluster( + logDir.getAbsolutePath(), + JobForShuffleTesting.PARALLELISM, + temporaryFolder, + zkCluster.getConnect(), + flinkConfiguration); + flinkLocalCluster.start(); + } + + public FlinkLocalCluster getFlinkLocalCluster() { + return flinkLocalCluster; + } + + @Override + public void close() throws Exception { + flinkLocalCluster.shutdown(); + } + } +} diff --git a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/ShuffleManagerLostE2ETest.java b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/ShuffleManagerLostE2ETest.java new file mode 100644 index 00000000..89b799d6 --- /dev/null +++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/ShuffleManagerLostE2ETest.java @@ -0,0 +1,369 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.e2e; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerMetricKeys; +import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerMetrics; +import com.alibaba.flink.shuffle.core.config.HeartbeatOptions; +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.e2e.flinkcluster.FlinkLocalCluster; +import com.alibaba.flink.shuffle.e2e.shufflecluster.LocalShuffleCluster; +import com.alibaba.flink.shuffle.e2e.zookeeper.ZooKeeperTestEnvironment; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.time.Deadline; +import org.apache.flink.configuration.PipelineOptionsInternal; +import org.apache.flink.configuration.WebOptions; + +import org.apache.commons.io.FileUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.rules.TestName; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.Assert.fail; + +/** Tests for the scenario where the ShuffleManager is lost. */ +public class ShuffleManagerLostE2ETest { + + private static final Logger LOG = LoggerFactory.getLogger(ShuffleManagerLostE2ETest.class); + + @ClassRule public static TemporaryFolder temporaryFolder = new TemporaryFolder(); + + @Rule public TestName name = new TestName(); + + private File logDir; + + private ZooKeeperTestEnvironment zkCluster; + + private LocalShuffleCluster localShuffleCluster; + + @Before + public void setup() throws Exception { + zkCluster = new ZooKeeperTestEnvironment(1); + + logDir = + new File( + System.getProperty("buildDirectory") + + "/" + + getClass().getSimpleName() + + "-" + + name.getMethodName()); + if (logDir.exists()) { + FileUtils.deleteDirectory(logDir); + } + + Configuration configuration = new Configuration(); + configuration.setDuration( + HeartbeatOptions.HEARTBEAT_WORKER_INTERVAL, Duration.ofSeconds(5)); + configuration.setDuration( + HeartbeatOptions.HEARTBEAT_WORKER_TIMEOUT, Duration.ofSeconds(20)); + configuration.setDuration(HeartbeatOptions.HEARTBEAT_JOB_INTERVAL, Duration.ofSeconds(5)); + configuration.setDuration(HeartbeatOptions.HEARTBEAT_JOB_TIMEOUT, Duration.ofSeconds(30)); + localShuffleCluster = + new LocalShuffleCluster( + logDir.getAbsolutePath(), + 2, + zkCluster.getConnect(), + temporaryFolder.newFolder().toPath(), + configuration); + localShuffleCluster.start(); + + LOG.info("========== Test started =========="); + } + + @After + public void teardown() throws Exception { + LOG.info("========== Test end =========="); + + try { + localShuffleCluster.shutdown(); + } catch (Exception e) { + LOG.info("Failed to stop the local shuffle cluster", e); + } + + try { + zkCluster.shutdown(); + } catch (Exception e) { + LOG.info("Failed to stop zk cluster", e); + } + } + + @Test + public void testShuffleManagerRecoveredAfterLostOnWrite() throws Exception { + JobID jobId = new JobID(); + + 
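+ // PIPELINE_FIXED_JOB_ID (set below) pins the job ID so the test can refer to the job deterministically.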
org.apache.flink.configuration.Configuration flinkConfiguration = + new org.apache.flink.configuration.Configuration(); + flinkConfiguration.set(WebOptions.TIMEOUT, 15000L); + flinkConfiguration.set(PipelineOptionsInternal.PIPELINE_FIXED_JOB_ID, jobId.toHexString()); + + try (FlinkLocalClusterResource resource = + new FlinkLocalClusterResource(flinkConfiguration)) { + JobForShuffleTesting job = new JobForShuffleTesting(resource.getFlinkLocalCluster()); + ScheduledExecutorService scheduledExecutor = + Executors.newSingleThreadScheduledExecutor(); + + try { + job.planOperation( + JobForShuffleTesting.STAGE0_NAME, + 0, + 0, + JobForShuffleTesting.TaskStat.RUNNING, + (processIdAndStat) -> { + localShuffleCluster.killShuffleManager(); + + scheduledExecutor.schedule( + () -> { + try { + localShuffleCluster.recoverShuffleManager(); + } catch (Exception e) { + LOG.error("Failed to recover shuffle manager", e); + } + }, + 10, + TimeUnit.SECONDS); + }); + + try { + job.run(); + checkFlinkResourceReleased(resource.getFlinkLocalCluster()); + checkShuffleResourceReleased(); + } catch (Exception e) { + // Ignored + } + } finally { + scheduledExecutor.shutdownNow(); + } + } + } + + /** + * Tests that the shuffle manager is restarted during a period in which the job happens to + * issue no requests. In that case the job does not fail over with a fatal error. After the + * shuffle manager restarts, it does not ask clients to remove partitions for a while, since + * it needs some time to synchronize with the shuffle workers. Therefore, the job keeps + * running until it succeeds. + */ + @Test + public void testShuffleManagerRecoveredAfterLostWithoutJobFailover() throws Exception { + org.apache.flink.configuration.Configuration flinkConfiguration = + new org.apache.flink.configuration.Configuration(); + + ExecutorService scheduledExecutor = + Executors.newFixedThreadPool(JobForShuffleTesting.PARALLELISM); + try (FlinkLocalClusterResource resource = + new FlinkLocalClusterResource(flinkConfiguration)) { + JobForShuffleTesting job = new JobForShuffleTesting(resource.getFlinkLocalCluster()); + + AtomicInteger stage1LatchStarted = new AtomicInteger(0); + CountDownLatch shuffleManagerRestartedLatch = new CountDownLatch(1); + + for (int i = 0; i < JobForShuffleTesting.PARALLELISM; ++i) { + job.planAsyncOperation( + JobForShuffleTesting.STAGE1_NAME, + i, + 0, + JobForShuffleTesting.TaskStat.RUNNING, + (processIdAndStat, nodeCache) -> { + CompletableFuture resultFuture = new CompletableFuture<>(); + if (processIdAndStat.getStat() + == JobForShuffleTesting.TaskStat.RUNNING) { + scheduledExecutor.execute( + () -> { + try { + int count = stage1LatchStarted.incrementAndGet(); + LOG.info("Increase the count to {}", count); + if (count == JobForShuffleTesting.PARALLELISM) { + LOG.info( + "Increase the count to {} and restart the shuffle manager", + count); + + localShuffleCluster.killShuffleManager(); + Thread.sleep(5000); + + localShuffleCluster.recoverShuffleManager(); + Thread.sleep(5000); + + shuffleManagerRestartedLatch.countDown(); + } + + shuffleManagerRestartedLatch.await(); + resultFuture.complete(null); + } catch (Exception e) { + resultFuture.completeExceptionally(e); + } + }); + } else { + resultFuture.complete(null); + } + return resultFuture; + }); + } + + // The job should finish without a failover. + job.run(); + } finally { + scheduledExecutor.shutdownNow(); + } + } + + @Test + public void testPartitionCleanupOnJobFinished() throws Exception { + 
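+ // Run a job to completion, then verify that the shuffle cluster eventually releases all of its result partitions.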
org.apache.flink.configuration.Configuration flinkConfiguration = + new org.apache.flink.configuration.Configuration(); + try (FlinkLocalClusterResource resource = + new FlinkLocalClusterResource(flinkConfiguration)) { + JobForShuffleTesting job = new JobForShuffleTesting(resource.getFlinkLocalCluster()); + job.run(); + } + + LOG.info("Job finished, now check the cleanup"); + // Keep checking until all jobs are cleaned up. + waitTilAllResultPartitionsReleased(80); + } + + @Test + public void testPartitionCleanupOnJobKilled() throws Exception { + org.apache.flink.configuration.Configuration flinkConfiguration = + new org.apache.flink.configuration.Configuration(); + flinkConfiguration.setString(HeartbeatOptions.HEARTBEAT_JOB_INTERVAL.key(), "5s"); + flinkConfiguration.setString(HeartbeatOptions.HEARTBEAT_JOB_TIMEOUT.key(), "30s"); + FlinkLocalClusterResource resource = new FlinkLocalClusterResource(flinkConfiguration); + JobForShuffleTesting job = new JobForShuffleTesting(resource.getFlinkLocalCluster()); + + job.planOperation( + JobForShuffleTesting.STAGE2_NAME, + 0, + 0, + JobForShuffleTesting.TaskStat.OPENED, + (taskStat) -> resource.getFlinkLocalCluster().shutdown()); + + try { + job.run(); + fail("The job should fail due to being killed"); + } catch (Exception e) { + // ignored + } + + LOG.info("Job killed, now check the cleanup"); + // Keep checking until all jobs are cleaned up. + waitTilAllResultPartitionsReleased(80); + } + + private void waitTilAllResultPartitionsReleased(int timeoutInSeconds) + throws ExecutionException, InterruptedException { + // Keep checking until all jobs are cleaned up. + Deadline deadline = Deadline.fromNow(Duration.ofSeconds(timeoutInSeconds)); + while (true) { + Thread.sleep(5000); + + List<?> jobIds = + localShuffleCluster.shuffleManagerClient.listJobs(false).get(); + LOG.info("Check jobs and get {}", jobIds); + if (jobIds.size() > 0) { + if (!deadline.hasTimeLeft()) { + fail("There are still jobs left: " + jobIds); + } else { + continue; + } + } + + Map<InstanceID, ShuffleWorkerMetrics> shuffleWorkerMetricMap = + localShuffleCluster.shuffleManagerClient.getShuffleWorkerMetrics().get(); + LOG.info("Check workers and get {}", shuffleWorkerMetricMap); + boolean hasRemainingPartitions = + shuffleWorkerMetricMap.entrySet().stream() + .anyMatch( + entry -> + entry.getValue() + .getIntegerMetric( + ShuffleWorkerMetricKeys + .DATA_PARTITION_NUMBERS_KEY) + > 0); + if (hasRemainingPartitions) { + if (!deadline.hasTimeLeft()) { + fail("There are still partitions left: " + shuffleWorkerMetricMap); + } else { + continue; + } + } + + // All result partitions have been released. + break; + } + } + + private void checkShuffleResourceReleased() throws Exception { + localShuffleCluster.checkStorageResourceReleased(); + localShuffleCluster.checkNetworkReleased(); + localShuffleCluster.checkBuffersReleased(); + } + + private void checkFlinkResourceReleased(FlinkLocalCluster flinkLocalCluster) throws Exception { + flinkLocalCluster.checkResourceReleased(); + } + + private class FlinkLocalClusterResource implements AutoCloseable { + + private final FlinkLocalCluster flinkLocalCluster; + + public FlinkLocalClusterResource( + org.apache.flink.configuration.Configuration flinkConfiguration) throws Exception { + flinkLocalCluster = + new FlinkLocalCluster( + logDir.getAbsolutePath(), + JobForShuffleTesting.PARALLELISM, + temporaryFolder, + zkCluster.getConnect(), + flinkConfiguration); + flinkLocalCluster.start(); + } + + public FlinkLocalCluster getFlinkLocalCluster() { + return 
flinkLocalCluster; + } + + @Override + public void close() throws Exception { + flinkLocalCluster.shutdown(); + } + } +} diff --git a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/ShuffleWorkerLostE2ETest.java b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/ShuffleWorkerLostE2ETest.java new file mode 100644 index 00000000..4d8372dc --- /dev/null +++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/ShuffleWorkerLostE2ETest.java @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.e2e; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; +import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionCoordinate; +import com.alibaba.flink.shuffle.coordinator.manager.JobDataPartitionDistribution; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleWorkerRegistration; +import com.alibaba.flink.shuffle.core.config.HeartbeatOptions; +import com.alibaba.flink.shuffle.core.config.WorkerOptions; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.e2e.flinkcluster.FlinkLocalCluster; +import com.alibaba.flink.shuffle.e2e.shufflecluster.LocalShuffleCluster; +import com.alibaba.flink.shuffle.plugin.utils.IdMappingUtils; + +import org.apache.flink.api.common.time.Time; +import org.apache.flink.types.IntValue; + +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.nio.file.Path; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +import static org.apache.flink.api.common.restartstrategy.RestartStrategies.fixedDelayRestart; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** Tests for scenario of ShuffleWorker lost. 
*/ +public class ShuffleWorkerLostE2ETest extends AbstractInstableE2ETest { + + @Test + public void testShuffleWorkerLostOnWrite() throws Exception { + JobForShuffleTesting job = new JobForShuffleTesting(flinkCluster); + job.setExecutionConfigModifier( + conf -> conf.setRestartStrategy(fixedDelayRestart(5, Time.seconds(15)))); + + job.planOperation( + JobForShuffleTesting.STAGE0_NAME, + 0, + 0, + JobForShuffleTesting.TaskStat.RUNNING, + processIdAndStat -> { + int workerIndex = findWritingWorkerIndex(processIdAndStat); + shuffleCluster.killShuffleWorker(workerIndex); + }); + + try { + job.run(); + checkFlinkResourceReleased(); + checkShuffleResourceReleased(); + } catch (Throwable t) { + flinkCluster.printProcessLog(); + shuffleCluster.printProcessLog(); + ExceptionUtils.rethrowAsRuntimeException(t); + } + } + + @Test + public void testShuffleWorkerGetLostOnRead() throws Exception { + JobForShuffleTesting job = new JobForShuffleTesting(flinkCluster); + job.setExecutionConfigModifier( + conf -> conf.setRestartStrategy(fixedDelayRestart(15, Time.seconds(15)))); + + IntValue workerIndexToKill = new IntValue(-1); + job.planOperation( + JobForShuffleTesting.STAGE0_NAME, + 0, + 0, + JobForShuffleTesting.TaskStat.RUNNING, + processIdAndStat -> { + if (processIdAndStat.getStat() == JobForShuffleTesting.TaskStat.RUNNING) { + workerIndexToKill.setValue(findWritingWorkerIndex(processIdAndStat)); + } + }); + + job.planOperation( + JobForShuffleTesting.STAGE1_NAME, + 0, + 0, + JobForShuffleTesting.TaskStat.RUNNING, + processIdAndStat -> { + if (processIdAndStat.getStat() == JobForShuffleTesting.TaskStat.RUNNING) { + assertTrue( + "The worker index written is not recorded", + workerIndexToKill.getValue() >= 0); + shuffleCluster.killShuffleWorker(workerIndexToKill.getValue()); + } + }); + + try { + job.run(); + checkFlinkResourceReleased(); + checkShuffleResourceReleased(); + } catch (Throwable t) { + flinkCluster.printProcessLog(); + shuffleCluster.printProcessLog(); + ExceptionUtils.rethrowAsRuntimeException(t); + } + } + + @Test + public void testShuffleWorkerRecoveredAfterLostOnRead() throws Exception { + JobForShuffleTesting job = new JobForShuffleTesting(flinkCluster); + job.setExecutionConfigModifier( + conf -> conf.setRestartStrategy(fixedDelayRestart(10, Time.seconds(15)))); + + ScheduledExecutorService scheduledExecutor = Executors.newSingleThreadScheduledExecutor(); + + IntValue workerIndexToKill = new IntValue(-1); + job.planOperation( + JobForShuffleTesting.STAGE0_NAME, + 0, + 0, + JobForShuffleTesting.TaskStat.RUNNING, + processIdAndStat -> { + if (processIdAndStat.getStat() == JobForShuffleTesting.TaskStat.RUNNING) { + workerIndexToKill.setValue(findWritingWorkerIndex(processIdAndStat)); + } + }); + + job.planOperation( + JobForShuffleTesting.STAGE1_NAME, + 0, + 0, + JobForShuffleTesting.TaskStat.RUNNING, + processIdAndStat -> { + if (processIdAndStat.getStat() == JobForShuffleTesting.TaskStat.RUNNING) { + assertTrue( + "The worker index written is not recorded", + workerIndexToKill.getValue() >= 0); + shuffleCluster.killShuffleWorker(workerIndexToKill.getValue()); + + scheduledExecutor.schedule( + () -> { + try { + shuffleCluster.recoverShuffleWorker( + workerIndexToKill.getValue()); + } catch (Exception e) { + e.printStackTrace(); + } + }, + 10, + TimeUnit.SECONDS); + } + }); + + try { + job.run(); + + // check that stage0 no restart, check that stage1 restart once. 
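+ // A single recorded attempt (attempt 0) per stage0 task means the upstream stage was never restarted.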
+ for (int i = 0; i < job.getParallelism(); ++i) { + job.checkAttemptsNum(JobForShuffleTesting.STAGE0_NAME, i, 1); + } + checkFlinkResourceReleased(); + checkShuffleResourceReleased(); + } catch (Throwable t) { + flinkCluster.printProcessLog(); + shuffleCluster.printProcessLog(); + ExceptionUtils.rethrowAsRuntimeException(t); + } + } + + private Integer findWritingWorkerIndex( + JobForShuffleTesting.ProcessIDAndTaskStat processIdAndStat) throws Exception { + List jobs = shuffleCluster.shuffleManagerClient.listJobs(false).get(); + + assertEquals(1, jobs.size()); + + JobDataPartitionDistribution distribution = + shuffleCluster + .shuffleManagerClient + .getJobDataPartitionDistribution(jobs.get(0)) + .get(); + + MapPartitionID mapPartitionId = + IdMappingUtils.fromFlinkResultPartitionID(processIdAndStat.resultPartitionID); + Optional> foundWorker = + distribution.getDataPartitionDistribution().entrySet().stream() + .filter(e -> e.getKey().getDataPartitionId().equals(mapPartitionId)) + .findFirst(); + assertTrue( + "The produced data partition is not found in " + + distribution.getDataPartitionDistribution(), + foundWorker.isPresent()); + + Optional workerIndex = + shuffleCluster.findShuffleWorker(foundWorker.get().getValue().getProcessID()); + assertTrue(workerIndex.isPresent()); + + return workerIndex.get(); + } + + private void checkShuffleResourceReleased() throws Exception { + shuffleCluster.checkStorageResourceReleased(); + shuffleCluster.checkNetworkReleased(); + shuffleCluster.checkBuffersReleased(); + } + + private void checkFlinkResourceReleased() throws Exception { + flinkCluster.checkResourceReleased(); + } + + @Override + protected LocalShuffleCluster createLocalShuffleCluster( + String logPath, String zkConnect, Path dataPath) { + Configuration conf = new Configuration(); + conf.setDuration(HeartbeatOptions.HEARTBEAT_WORKER_INTERVAL, Duration.ofSeconds(5)); + conf.setDuration(HeartbeatOptions.HEARTBEAT_WORKER_TIMEOUT, Duration.ofSeconds(20)); + return new LocalShuffleCluster(logPath, 2, zkConnect, dataPath, conf); + } + + @Override + protected FlinkLocalCluster createFlinkCluster( + String logPath, TemporaryFolder tmpFolder, String zkConnect) throws Exception { + org.apache.flink.configuration.Configuration conf = + new org.apache.flink.configuration.Configuration(); + conf.setString(WorkerOptions.MAX_WORKER_RECOVER_TIME.key(), "80s"); + return new FlinkLocalCluster(logPath, 2, tmpFolder, zkConnect, conf); + } +} diff --git a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/TMLostE2ETest.java b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/TMLostE2ETest.java new file mode 100644 index 00000000..6a976918 --- /dev/null +++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/TMLostE2ETest.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.e2e; + +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; +import com.alibaba.flink.shuffle.e2e.JobForShuffleTesting.TaskStat; + +import org.junit.Test; + +import static com.alibaba.flink.shuffle.e2e.JobForShuffleTesting.DataScale.NORMAL; +import static com.alibaba.flink.shuffle.e2e.JobForShuffleTesting.STAGE0_NAME; +import static com.alibaba.flink.shuffle.e2e.JobForShuffleTesting.STAGE1_NAME; +import static com.alibaba.flink.shuffle.e2e.JobForShuffleTesting.STAGE2_NAME; + +/** Test for scenario of TM lost. */ +public class TMLostE2ETest extends AbstractInstableE2ETest { + + @Test + public void testTMLostAndTaskFailureWhenShuffleWrite() throws Exception { + try { + JobForShuffleTesting job = new JobForShuffleTesting(flinkCluster, NORMAL); + planKillingTaskManagerInRunningTask(job, STAGE1_NAME, 0, 0); + job.planFailingARunningTask(STAGE1_NAME, 0, 1); + job.run(); + + flinkCluster.checkResourceReleased(); + shuffleCluster.checkResourceReleased(); + } catch (Throwable t) { + flinkCluster.printProcessLog(); + shuffleCluster.printProcessLog(); + ExceptionUtils.rethrowException(t); + } + } + + @Test + public void testTMLostWithShuffleWrite() throws Exception { + testTMLost(STAGE0_NAME); + } + + @Test + public void testTMLostWithShuffleWriteAndRead() throws Exception { + testTMLost(STAGE1_NAME); + } + + @Test + public void testTMLostWithShuffleRead() throws Exception { + testTMLost(STAGE2_NAME); + } + + private void testTMLost(String stageName) throws Exception { + try { + JobForShuffleTesting job = new JobForShuffleTesting(flinkCluster, NORMAL); + planKillingTaskManagerInRunningTask(job, stageName, 0, 0); + + job.run(); + + // other stages' tasks should not restart. + flinkCluster.checkResourceReleased(); + shuffleCluster.checkResourceReleased(); + } catch (Throwable t) { + flinkCluster.printProcessLog(); + shuffleCluster.printProcessLog(); + ExceptionUtils.rethrowException(t); + } + } + + private void planKillingTaskManagerInRunningTask( + JobForShuffleTesting job, String stageName, int taskID, int attemptID) + throws Exception { + job.planOperation( + stageName, + taskID, + attemptID, + TaskStat.RUNNING, + info -> Runtime.getRuntime().exec("kill -9 " + info.getProcessID())); + } +} diff --git a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/TestJvmProcess.java b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/TestJvmProcess.java new file mode 100644 index 00000000..5f84156e --- /dev/null +++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/TestJvmProcess.java @@ -0,0 +1,383 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.e2e; + +import com.alibaba.flink.shuffle.e2e.utils.CommonTestUtils; + +import org.apache.flink.configuration.MemorySize; +import org.apache.flink.util.ShutdownHookUtil; + +import org.apache.commons.io.FileUtils; +import org.apache.commons.lang3.ArrayUtils; +import org.apache.commons.lang3.RandomStringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.StringWriter; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.util.Arrays; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; +import static org.junit.Assert.fail; + +/** A {@link Process} running a separate JVM. */ +public abstract class TestJvmProcess { + + private static final Logger LOG = LoggerFactory.getLogger(TestJvmProcess.class); + + /** Lock to guard {@link #startProcess()} and {@link #destroy()} calls. */ + private final Object createDestroyLock = new Object(); + + /** The java command path. */ + private final String javaCommandPath; + + /** The log4j configuration path. */ + private final String log4jConfigFilePath; + + private final String logFilePath; + + /** Shutdown hook for resource cleanup. */ + private final Thread shutdownHook; + + /** JVM process memory (set for both '-Xms' and '-Xmx'). */ + private int jvmHeapMemoryInMb = 80; + + private int jvmDirectMemoryInMb = 80; + + /** The JVM process. */ + private volatile Process process; + + private final String name; + + /** Writer for the process output. */ + private volatile StringWriter processOutput; + + /** flag to mark the process as already destroyed. */ + private volatile boolean destroyed; + + public TestJvmProcess(String name, String logDirName) throws Exception { + this( + CommonTestUtils.getJavaCommandPath(), + CommonTestUtils.createTemporaryLog4JProperties().getPath(), + name, + logDirName); + } + + public TestJvmProcess( + String javaCommandPath, String log4jConfigFilePath, String name, String logDir) { + this.javaCommandPath = checkNotNull(javaCommandPath); + this.log4jConfigFilePath = checkNotNull(log4jConfigFilePath); + this.name = name; + this.logFilePath = logDir + "/" + name + "-" + RandomStringUtils.randomAlphabetic(4); + + LOG.info("JVM Process {} will write log into file {}", getName(), logFilePath); + + this.shutdownHook = + new Thread( + new Runnable() { + @Override + public void run() { + try { + destroy(); + } catch (Throwable t) { + LOG.error("Error during process cleanup shutdown hook.", t); + } + } + }); + } + + /** Returns the name of the process. */ + public String getName() { + return name; + } + + /** + * Returns the arguments to the JVM. + * + *

<p>These can be parsed by the main method of the entry point class. + */ + public abstract String[] getJvmArgs(); + + /** + * Returns the name of the class to run. + * + *

<p>Arguments to the main method can be specified via {@link #getJvmArgs()}. + */ + public abstract String getEntryPointClassName(); + + // --------------------------------------------------------------------------------------------- + + /** + * Sets the memory for the process (-Xms and -Xmx flags) (>= 80). + * + * @param jvmHeapMemoryInMb Amount of memory in Megabytes for the JVM (>= 80). + */ + public void setJVMHeapMemory(int jvmHeapMemoryInMb) { + checkArgument(jvmHeapMemoryInMb >= 80, "Process JVM Requires at least 80 MBs of memory."); + checkState(process == null, "Cannot set memory after process was started"); + + this.jvmHeapMemoryInMb = jvmHeapMemoryInMb; + } + + public void setJvmDirectMemory(int jvmDirectMemoryInMb) { + checkState(process == null, "Cannot set memory after process was started"); + + this.jvmDirectMemoryInMb = jvmDirectMemoryInMb; + } + + /** + * Creates and starts the {@link Process}. + * + * <p>

Important: Don't forget to call {@link #destroy()} to prevent resource + * leaks. The created process will be child process and is not guaranteed to terminate when the + * parent process terminates. + */ + public void startProcess() throws IOException { + String[] cmd = + new String[] { + javaCommandPath, + "-Dlog.level=DEBUG", + "-Dlog4j.configurationFile=file:" + log4jConfigFilePath, + "-Dlog.file=" + logFilePath, + "-Xms" + jvmHeapMemoryInMb + "m", + "-Xmx" + jvmHeapMemoryInMb + "m", + "-XX:MaxDirectMemorySize=" + + MemorySize.ofMebiBytes(jvmDirectMemoryInMb).getBytes(), + "-classpath", + CommonTestUtils.getCurrentClasspath(), + getEntryPointClassName() + }; + + String[] jvmArgs = getJvmArgs(); + + if (jvmArgs != null && jvmArgs.length > 0) { + cmd = ArrayUtils.addAll(cmd, jvmArgs); + } + + synchronized (createDestroyLock) { + checkState(process == null, "process already started"); + + LOG.info("Running command '{}'.", Arrays.toString(cmd)); + this.process = new ProcessBuilder(cmd).start(); + + // Forward output + this.processOutput = new StringWriter(); + new CommonTestUtils.PipeForwarder(process.getErrorStream(), processOutput); + + try { + // Add JVM shutdown hook to call shutdown of service + Runtime.getRuntime().addShutdownHook(shutdownHook); + } catch (IllegalStateException ignored) { + // JVM is already shutting down. No need to do this. + } catch (Throwable t) { + LOG.error("Cannot register process cleanup shutdown hook.", t); + } + } + } + + public void printProcessLog() { + checkState(processOutput != null, "not started"); + + System.out.println("-----------------------------------------"); + System.out.println(" BEGIN SPAWNED PROCESS LOG FOR " + getName()); + System.out.println("-----------------------------------------"); + + String out = null; + try { + out = new String(FileUtils.readFileToByteArray(new File(logFilePath))); + } catch (IOException e) { + e.printStackTrace(); + } + + if (out == null || out.length() == 0) { + System.out.println("(EMPTY)"); + } else { + System.out.println(out); + } + + System.out.println("-----------------------------------------"); + System.out.println(" END SPAWNED PROCESS LOG " + getName()); + System.out.println("-----------------------------------------"); + } + + public void destroy() { + synchronized (createDestroyLock) { + checkState(process != null, "process not started"); + + if (destroyed) { + // already done + return; + } + + LOG.info("Destroying " + getName() + " process."); + + try { + // try to call "destroyForcibly()" on Java 8 + boolean destroyed = false; + try { + Method m = process.getClass().getMethod("destroyForcibly"); + m.setAccessible(true); + m.invoke(process); + destroyed = true; + } catch (NoSuchMethodException ignored) { + // happens on Java 7 + } catch (Throwable t) { + LOG.error("Failed to forcibly destroy process", t); + } + + // if it was not destroyed, call the regular destroy method + if (!destroyed) { + try { + process.destroy(); + } catch (Throwable t) { + LOG.error("Error while trying to destroy process.", t); + } + } + } finally { + destroyed = true; + ShutdownHookUtil.removeShutdownHook(shutdownHook, getClass().getSimpleName(), LOG); + } + } + } + + public String getProcessOutput() { + if (processOutput != null) { + return processOutput.toString(); + } else { + return null; + } + } + + /** + * Gets the process ID, if possible. This method currently only work on UNIX-based operating + * systems. On others, it returns {@code -1}. + * + * @return The process ID, or -1, if the ID cannot be determined. 
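+ * + * <p>(On Java 9 and later, {@code Process#pid()} exposes this directly; the + * reflective lookup below exists because these tests target Java 8.)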
+ */ + public long getProcessId() { + checkState(process != null, "process not started"); + + try { + Class clazz = process.getClass(); + if (clazz.getName().equals("java.lang.UNIXProcess")) { + Field pidField = clazz.getDeclaredField("pid"); + pidField.setAccessible(true); + return pidField.getLong(process); + } else if (clazz.getName().equals("java.lang.ProcessImpl")) { + Method pid = clazz.getDeclaredMethod("pid"); + pid.setAccessible(true); + return (long) pid.invoke(process); + } else { + return -1; + } + } catch (Throwable ignored) { + return -1; + } + } + + public boolean isAlive() { + if (destroyed) { + return false; + } else { + try { + // the method throws an exception as long as the + // process is alive + process.exitValue(); + return false; + } catch (IllegalThreadStateException ignored) { + // thi + return true; + } + } + } + + public void waitFor() throws InterruptedException { + Process process = this.process; + if (process != null) { + process.waitFor(); + } else { + throw new IllegalStateException("process not started"); + } + } + + // --------------------------------------------------------------------------------------------- + // File based synchronization utilities + // --------------------------------------------------------------------------------------------- + + public static void touchFile(File file) throws IOException { + if (!file.exists()) { + new FileOutputStream(file).close(); + } + if (!file.setLastModified(System.currentTimeMillis())) { + throw new IOException("Could not touch the file."); + } + } + + public static void waitForMarkerFile(File file, long timeoutMillis) + throws InterruptedException { + final long deadline = System.nanoTime() + timeoutMillis * 1_000_000; + + boolean exists; + while (!(exists = file.exists()) && System.nanoTime() < deadline) { + Thread.sleep(10); + } + + if (!exists) { + fail("The marker file was not found within " + timeoutMillis + " msecs"); + } + } + + public static void waitForMarkerFiles(File basedir, String prefix, int num, long timeout) { + long now = System.currentTimeMillis(); + final long deadline = now + timeout; + + while (now < deadline) { + boolean allFound = true; + + for (int i = 0; i < num; i++) { + File nextToCheck = new File(basedir, prefix + i); + if (!nextToCheck.exists()) { + allFound = false; + break; + } + } + + if (allFound) { + return; + } else { + // not all found, wait for a bit + try { + Thread.sleep(10); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + + now = System.currentTimeMillis(); + } + } + + fail("The tasks were not started within time (" + timeout + "msecs)"); + } +} diff --git a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/TestingListener.java b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/TestingListener.java new file mode 100644 index 00000000..7095ec59 --- /dev/null +++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/TestingListener.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.e2e; + +import org.apache.flink.runtime.leaderelection.LeaderInformation; +import org.apache.flink.runtime.leaderretrieval.LeaderRetrievalListener; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.UUID; + +/** + * Test {@link LeaderRetrievalListener} implementation which offers some convenience functions for + * testing purposes. + */ +public class TestingListener extends TestingRetrievalBase implements LeaderRetrievalListener { + + private static final Logger LOG = LoggerFactory.getLogger(TestingListener.class); + + @Override + public void notifyLeaderAddress(String leaderAddress, UUID leaderSessionID) { + LOG.debug( + "Notified about new leader address {} with session ID {}.", + leaderAddress, + leaderSessionID); + if (leaderAddress == null && leaderSessionID == null) { + offerToLeaderQueue(LeaderInformation.empty()); + } else { + offerToLeaderQueue(LeaderInformation.known(leaderSessionID, leaderAddress)); + } + } + + @Override + public void handleError(Exception exception) { + super.handleError(exception); + } +} diff --git a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/TestingRetrievalBase.java b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/TestingRetrievalBase.java new file mode 100644 index 00000000..a14a4f1c --- /dev/null +++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/TestingRetrievalBase.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.e2e; + +import com.alibaba.flink.shuffle.e2e.utils.CommonTestUtils; + +import org.apache.flink.api.common.time.Deadline; +import org.apache.flink.runtime.leaderelection.LeaderInformation; +import org.apache.flink.runtime.leaderretrieval.LeaderRetrievalListener; +import org.apache.flink.util.ExceptionUtils; + +import javax.annotation.Nullable; + +import java.time.Duration; +import java.util.UUID; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +/** + * Base class which provides some convenience functions for testing purposes of {@link + * LeaderRetrievalListener} and {@link + * org.apache.flink.runtime.leaderretrieval.LeaderRetrievalEventHandler}. 
+ */ +public class TestingRetrievalBase { + + private final BlockingQueue<LeaderInformation> leaderEventQueue = new LinkedBlockingQueue<>(); + private final BlockingQueue<Throwable> errorQueue = new LinkedBlockingQueue<>(); + + private LeaderInformation leader = LeaderInformation.empty(); + private String oldAddress; + private Throwable error; + + public String waitForNewLeader(long timeout) throws Exception { + throwExceptionIfNotNull(); + + final String errorMsg = + "Listener was not notified about a new leader within " + timeout + "ms"; + CommonTestUtils.waitUntilCondition( + () -> { + leader = leaderEventQueue.poll(timeout, TimeUnit.MILLISECONDS); + return leader != null + && !leader.isEmpty() + && !leader.getLeaderAddress().equals(oldAddress); + }, + Deadline.fromNow(Duration.ofMillis(timeout)), + errorMsg); + + oldAddress = leader.getLeaderAddress(); + + return leader.getLeaderAddress(); + } + + public void waitForEmptyLeaderInformation(long timeout) throws Exception { + throwExceptionIfNotNull(); + + final String errorMsg = + "Listener was not notified about an empty leader within " + timeout + "ms"; + CommonTestUtils.waitUntilCondition( + () -> { + leader = leaderEventQueue.poll(timeout, TimeUnit.MILLISECONDS); + return leader != null && leader.isEmpty(); + }, + Deadline.fromNow(Duration.ofMillis(timeout)), + errorMsg); + + oldAddress = null; + } + + public void waitForError(long timeout) throws Exception { + final String errorMsg = "Listener did not see an exception within " + timeout + "ms"; + CommonTestUtils.waitUntilCondition( + () -> { + error = errorQueue.poll(timeout, TimeUnit.MILLISECONDS); + return error != null; + }, + Deadline.fromNow(Duration.ofMillis(timeout)), + errorMsg); + } + + public void handleError(Throwable ex) { + errorQueue.offer(ex); + } + + public LeaderInformation getLeader() { + return leader; + } + + public String getAddress() { + return leader.getLeaderAddress(); + } + + public UUID getLeaderSessionID() { + return leader.getLeaderSessionID(); + } + + public void offerToLeaderQueue(LeaderInformation leaderInformation) { + leaderEventQueue.offer(leaderInformation); + this.leader = leaderInformation; + } + + public int getLeaderEventQueueSize() { + return leaderEventQueue.size(); + } + + /** + * Please use {@link #waitForError} before getting the error. + * + * @return the error that has been handled. + */ + @Nullable + public Throwable getError() { + return this.error; + } + + private void throwExceptionIfNotNull() throws Exception { + if (error != null) { + ExceptionUtils.rethrowException(error); + } + } +} diff --git a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/flinkcluster/DispatcherProcess.java b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/flinkcluster/DispatcherProcess.java new file mode 100644 index 00000000..5fbe1d05 --- /dev/null +++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/flinkcluster/DispatcherProcess.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.e2e.flinkcluster; + +import com.alibaba.flink.shuffle.e2e.TestJvmProcess; + +import org.apache.flink.api.java.utils.ParameterTool; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.JobManagerOptions; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.runtime.dispatcher.Dispatcher; +import org.apache.flink.runtime.entrypoint.ClusterEntrypoint; +import org.apache.flink.runtime.entrypoint.StandaloneSessionClusterEntrypoint; +import org.apache.flink.runtime.jobmanager.JobManagerProcessSpec; +import org.apache.flink.runtime.jobmanager.JobManagerProcessUtils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Map; + +/** A {@link Dispatcher} instance running in a separate JVM. */ +public class DispatcherProcess extends TestJvmProcess { + + private static final Logger LOG = LoggerFactory.getLogger(DispatcherProcess.class); + + private final String[] jvmArgs; + + public DispatcherProcess(String logDirName, Configuration config) throws Exception { + super("Dispatcher", logDirName); + ArrayList<String> args = new ArrayList<>(); + + for (Map.Entry<String, String> entry : config.toMap().entrySet()) { + args.add("--" + entry.getKey()); + args.add(entry.getValue()); + } + + this.jvmArgs = new String[args.size()]; + args.toArray(jvmArgs); + + final JobManagerProcessSpec processSpec = + JobManagerProcessUtils.processSpecFromConfigWithNewOptionToInterpretLegacyHeap( + config, JobManagerOptions.JVM_HEAP_MEMORY); + + setJVMHeapMemory(processSpec.getJvmHeapMemorySize().getMebiBytes()); + setJvmDirectMemory(processSpec.getJvmDirectMemorySize().getMebiBytes()); + LOG.info( + "JVM Process {} with process memory spec {}, heap memory {}, direct memory {}", + getName(), + processSpec, + processSpec.getJvmHeapMemorySize().getMebiBytes(), + processSpec.getJvmDirectMemorySize().getMebiBytes()); + } + + @Override + public String[] getJvmArgs() { + return jvmArgs; + } + + @Override + public String getEntryPointClassName() { + return EntryPoint.class.getName(); + } + + /** Entry point for the Dispatcher process. */ + public static class EntryPoint { + + private static final Logger LOG = LoggerFactory.getLogger(EntryPoint.class); + + /** + * Entrypoint of the DispatcherProcessEntryPoint. + * + * <p>

Other arguments are parsed to a {@link Configuration} and passed to the Dispatcher, + * for instance: --high-availability ZOOKEEPER --high-availability.zookeeper.quorum + * "xyz:123:456". + */ + public static void main(String[] args) { + try { + ParameterTool params = ParameterTool.fromArgs(args); + Configuration config = params.getConfiguration(); + LOG.info("Configuration: {}.", config); + + config.setInteger(JobManagerOptions.PORT, 0); + config.setString(RestOptions.BIND_PORT, "0"); + + final StandaloneSessionClusterEntrypoint clusterEntrypoint = + new StandaloneSessionClusterEntrypoint(config); + + ClusterEntrypoint.runClusterEntrypoint(clusterEntrypoint); + + } catch (Throwable t) { + t.printStackTrace(); + LOG.error("Failed to start Dispatcher process", t); + System.exit(1); + } + } + } +} diff --git a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/flinkcluster/FlinkLocalCluster.java b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/flinkcluster/FlinkLocalCluster.java new file mode 100644 index 00000000..de17692d --- /dev/null +++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/flinkcluster/FlinkLocalCluster.java @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.e2e.flinkcluster; + +import com.alibaba.flink.shuffle.core.config.HeartbeatOptions; +import com.alibaba.flink.shuffle.core.config.TransferOptions; +import com.alibaba.flink.shuffle.core.config.WorkerOptions; +import com.alibaba.flink.shuffle.e2e.TestingListener; +import com.alibaba.flink.shuffle.e2e.utils.LogErrorHandler; +import com.alibaba.flink.shuffle.e2e.zookeeper.ZooKeeperTestUtils; +import com.alibaba.flink.shuffle.plugin.RemoteShuffleServiceFactory; +import com.alibaba.flink.shuffle.plugin.config.PluginOptions; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.time.Deadline; +import org.apache.flink.api.common.time.Time; +import org.apache.flink.configuration.AkkaOptions; +import org.apache.flink.configuration.CheckpointingOptions; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.HeartbeatManagerOptions; +import org.apache.flink.configuration.HighAvailabilityOptions; +import org.apache.flink.configuration.JobManagerOptions; +import org.apache.flink.configuration.MemorySize; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.configuration.RestartStrategyOptions; +import org.apache.flink.configuration.TaskManagerOptions; +import org.apache.flink.configuration.WebOptions; +import org.apache.flink.runtime.dispatcher.DispatcherGateway; +import org.apache.flink.runtime.dispatcher.DispatcherId; +import org.apache.flink.runtime.highavailability.HighAvailabilityServices; +import org.apache.flink.runtime.highavailability.HighAvailabilityServicesUtils; +import org.apache.flink.runtime.leaderretrieval.LeaderRetrievalService; +import org.apache.flink.runtime.messages.webmonitor.ClusterOverview; +import org.apache.flink.runtime.rpc.RpcService; +import org.apache.flink.runtime.rpc.RpcSystem; +import org.apache.flink.runtime.rpc.RpcUtils; +import org.apache.flink.runtime.shuffle.ShuffleServiceOptions; +import org.apache.flink.table.api.config.ExecutionConfigOptions; +import org.apache.flink.util.concurrent.FutureUtils; +import org.apache.flink.util.concurrent.ScheduledExecutorServiceAdapter; + +import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFramework; + +import org.apache.commons.lang3.tuple.Pair; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.time.Duration; +import java.util.Collection; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executors; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; + +/** + * A testing Flink local cluster, which contains a dispatcher process and several task manager + * processes. 
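+ * + * <p>Typical lifecycle in a test (a sketch using the constructor declared below): + * + * <pre>{@code + * FlinkLocalCluster cluster = + *         new FlinkLocalCluster(logDir, 2, tmpFolder, zkConnect, new Configuration()); + * cluster.start(); + * // ... submit jobs, kill or recover task managers ... + * cluster.shutdown(); + * }</pre>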
+ */ +public class FlinkLocalCluster { + + public final TemporaryFolder temporaryFolder; + + private final String logDirName; + + private Configuration conf; + + private final int numTaskManagers; + + private final TaskManagerProcess[] taskManagers; + + private CuratorFramework zkClient; + + private final String zkConnect; + + private final File zkPath; + + private DispatcherProcess dispatcher; + + private DispatcherGateway dispatcherGateway; + + private HighAvailabilityServices highAvailabilityServices; + + private LeaderRetrievalService leaderRetrievalService; + + private RpcService rpcService; + + public FlinkLocalCluster( + String logDirName, + int numTaskManagers, + TemporaryFolder temporaryFolder, + String zkConnect, + Configuration conf) + throws Exception { + this.logDirName = logDirName; + this.numTaskManagers = numTaskManagers; + this.taskManagers = new TaskManagerProcess[numTaskManagers]; + this.temporaryFolder = temporaryFolder; + this.zkConnect = zkConnect; + this.zkPath = temporaryFolder.newFolder(); + initConfig(conf); + } + + public void start() throws Exception { + this.dispatcher = new DispatcherProcess(logDirName, conf); + this.dispatcher.startProcess(); + this.highAvailabilityServices = + HighAvailabilityServicesUtils.createAvailableOrEmbeddedServices( + conf, Executors.newSingleThreadExecutor(), LogErrorHandler.INSTANCE); + RpcSystem rpcSystem = RpcSystem.load(conf); + rpcService = rpcSystem.remoteServiceBuilder(conf, "localhost", "0").createAndStart(); + for (int i = 0; i < numTaskManagers; i++) { + taskManagers[i] = new TaskManagerProcess(logDirName, conf, i); + taskManagers[i].startProcess(); + } + + Deadline deadline = Deadline.fromNow(Duration.ofMinutes(5)); + Pair pair = waitForDispatcher(deadline); + waitForTaskManagers(numTaskManagers, pair.getLeft(), pair.getRight(), deadline.timeLeft()); + } + + private Pair waitForDispatcher(Deadline deadline) throws Exception { + TestingListener listener = new TestingListener(); + leaderRetrievalService = highAvailabilityServices.getDispatcherLeaderRetriever(); + leaderRetrievalService.start(listener); + + listener.waitForNewLeader(deadline.timeLeft().toMillis()); + return Pair.of(listener.getAddress(), DispatcherId.fromUuid(listener.getLeaderSessionID())); + } + + public void shutdown() throws Exception { + for (int i = 0; i < numTaskManagers; i++) { + taskManagers[i].destroy(); + } + + dispatcher.destroy(); + + leaderRetrievalService.stop(); + + RpcUtils.terminateRpcService(rpcService, Time.seconds(30L)); + + highAvailabilityServices.closeAndCleanupAllData(); + + ZooKeeperTestUtils.deleteAll(zkClient); + zkClient.close(); + } + + public Configuration getConfig() { + return conf; + } + + public CuratorFramework getZKClient() { + return checkNotNull(zkClient); + } + + public String getZKConnect() { + return zkConnect; + } + + public String getZKPath() { + return zkPath.getPath(); + } + + private void initConfig(Configuration c) throws Exception { + conf = ZooKeeperTestUtils.configureZooKeeperHAForFlink(c, zkConnect, zkPath.getPath()); + conf.set(AkkaOptions.ASK_TIMEOUT, "600 s"); + conf.set(WebOptions.TIMEOUT, 600000L); + conf.set(JobManagerOptions.ADDRESS, "localhost"); + conf.set(RestOptions.BIND_PORT, "0"); + conf.set(HeartbeatManagerOptions.HEARTBEAT_INTERVAL, 1000L); + conf.set(HeartbeatManagerOptions.HEARTBEAT_TIMEOUT, 20000L); + conf.set(HighAvailabilityOptions.HA_MODE, "zookeeper"); + conf.set(TaskManagerOptions.NUM_TASK_SLOTS, 1); + conf.set(ExecutionConfigOptions.TABLE_EXEC_SHUFFLE_MODE, "ALL_EDGES_BLOCKING"); + 
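+ // (ALL_EDGES_BLOCKING above forces batch-style blocking exchanges so that + // every shuffle goes through the remote shuffle service under test.)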
conf.set(TaskManagerOptions.CPU_CORES, 1.0); + String cpPath = temporaryFolder.newFolder().getAbsoluteFile().toURI().toString(); + conf.set(CheckpointingOptions.CHECKPOINTS_DIRECTORY, cpPath); + conf.setInteger(TransferOptions.CLIENT_CONNECT_TIMEOUT.key(), 3); + conf.setString( + com.alibaba.flink.shuffle.core.config.HighAvailabilityOptions.HA_MODE.key(), + "ZOOKEEPER"); + conf.setString( + com.alibaba.flink.shuffle.core.config.HighAvailabilityOptions.HA_ZOOKEEPER_QUORUM + .key(), + zkConnect); + conf.setLong(HeartbeatOptions.HEARTBEAT_JOB_INTERVAL.key(), 3000); + String shuffleServiceClassName = RemoteShuffleServiceFactory.class.getName(); + conf.set(ShuffleServiceOptions.SHUFFLE_SERVICE_FACTORY_CLASS, shuffleServiceClassName); + conf.set(RestartStrategyOptions.RESTART_STRATEGY, "fixed-delay"); + conf.set(RestartStrategyOptions.RESTART_STRATEGY_FIXED_DELAY_ATTEMPTS, 100); + conf.set(RestartStrategyOptions.RESTART_STRATEGY_FIXED_DELAY_DELAY, Duration.ofSeconds(1)); + conf.setString(WorkerOptions.MAX_WORKER_RECOVER_TIME.key(), "60s"); + + // JM memory config + conf.set(JobManagerOptions.TOTAL_PROCESS_MEMORY, MemorySize.ofMebiBytes(2048)); + conf.set(JobManagerOptions.JVM_METASPACE, MemorySize.ofMebiBytes(32)); + conf.set(JobManagerOptions.JVM_OVERHEAD_MIN, MemorySize.ofMebiBytes(32)); + conf.set(JobManagerOptions.JVM_HEAP_MEMORY, MemorySize.ofMebiBytes(1536)); + + // TM memory config + conf.setString(PluginOptions.MEMORY_PER_INPUT_GATE.key(), "16m"); + conf.setString(PluginOptions.MEMORY_PER_RESULT_PARTITION.key(), "8m"); + conf.set(TaskManagerOptions.TOTAL_PROCESS_MEMORY, MemorySize.ofMebiBytes(1024)); + conf.set(TaskManagerOptions.NETWORK_MEMORY_MIN, MemorySize.ofMebiBytes(64)); + conf.set(TaskManagerOptions.NETWORK_MEMORY_MAX, MemorySize.ofMebiBytes(64)); + conf.set(TaskManagerOptions.JVM_METASPACE, MemorySize.ofMebiBytes(32)); + conf.set(TaskManagerOptions.JVM_OVERHEAD_MIN, MemorySize.ofMebiBytes(32)); + conf.set(TaskManagerOptions.JVM_OVERHEAD_MAX, MemorySize.ofMebiBytes(32)); + + // ZK config + zkClient = ZooKeeperTestUtils.createZKClientForFlink(conf); + } + + private void waitForTaskManagers( + int numTaskManagers, String addr, DispatcherId dispatcherId, Duration timeLeft) + throws ExecutionException, InterruptedException { + dispatcherGateway = rpcService.connect(addr, dispatcherId, DispatcherGateway.class).get(); + FutureUtils.retrySuccessfulWithDelay( + () -> + dispatcherGateway.requestClusterOverview( + Time.milliseconds(timeLeft.toMillis())), + Time.milliseconds(50L), + org.apache.flink.api.common.time.Deadline.fromNow( + Duration.ofMillis(timeLeft.toMillis())), + overview -> overview.getNumTaskManagersConnected() >= numTaskManagers, + new ScheduledExecutorServiceAdapter( + Executors.newSingleThreadScheduledExecutor())) + .get(); + } + + public int getNumTaskManagersConnected() throws ExecutionException, InterruptedException { + ClusterOverview view = dispatcherGateway.requestClusterOverview(Time.seconds(10L)).get(); + return view.getNumTaskManagersConnected(); + } + + public void killTaskManager(int index) throws Exception { + Process kill = Runtime.getRuntime().exec("kill -9 " + taskManagers[index].getProcessId()); + kill.waitFor(); + } + + public boolean isTaskManagerAlive(int index) { + return taskManagers[index].isAlive(); + } + + public void recoverTaskManager(int index) throws Exception { + checkState(!isTaskManagerAlive(index)); + taskManagers[index] = new TaskManagerProcess(logDirName, conf, index); + taskManagers[index].startProcess(); + } + + public void cancelJobs() 
throws Exception { + Time timeout = Time.seconds(10); + Collection jobIDs = dispatcherGateway.listJobs(timeout).get(); + jobIDs.forEach(jobID -> dispatcherGateway.cancelJob(jobID, timeout)); + } + + public void checkResourceReleased() throws Exception { + Thread.sleep(3000); // Wait a moment allowing TM to report stats with delay. + for (int i = 0; i < taskManagers.length; ++i) { + if (!taskManagers[i].isAlive()) { + continue; + } + byte[] bytes = zkClient.getData().forPath("/taskmanager-" + i); + TaskManagerProcess.TaskManagerStats stats = + TaskManagerProcess.TaskManagerStats.fromBytes(bytes); + int actual = stats.availableBuffers; + int expect = 2048; // 128M(network memory size) / 32k(page size) + String msg = String.format("TM memory leak, expect=%d, actual=%d", expect, actual); + if (actual != expect) { + throw new AssertionError(msg); + } + } + } + + public void printProcessLog() { + dispatcher.printProcessLog(); + for (TaskManagerProcess worker : taskManagers) { + worker.printProcessLog(); + } + } +} diff --git a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/flinkcluster/FlinkLocalClusterE2ETest.java b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/flinkcluster/FlinkLocalClusterE2ETest.java new file mode 100644 index 00000000..93086ceb --- /dev/null +++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/flinkcluster/FlinkLocalClusterE2ETest.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.e2e.flinkcluster; + +import com.alibaba.flink.shuffle.e2e.utils.CommonTestUtils; +import com.alibaba.flink.shuffle.e2e.zookeeper.ZooKeeperTestEnvironment; + +import org.apache.flink.api.common.time.Deadline; +import org.apache.flink.configuration.Configuration; + +import org.apache.commons.io.FileUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.rules.TestName; + +import java.io.File; +import java.time.Duration; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +/** Test for {@link FlinkLocalCluster}. 
*/ +public class FlinkLocalClusterE2ETest { + + @Rule public TemporaryFolder temporaryFolder = new TemporaryFolder(); + + @Rule public TestName name = new TestName(); + + private ZooKeeperTestEnvironment zkCluster; + + private FlinkLocalCluster cluster; + + @Before + public void setup() throws Exception { + File logDir = + new File( + System.getProperty("buildDirectory") + + "/" + + getClass().getSimpleName() + + "-" + + name.getMethodName()); + if (logDir.exists()) { + FileUtils.deleteDirectory(logDir); + } + + zkCluster = new ZooKeeperTestEnvironment(1); + cluster = + new FlinkLocalCluster( + logDir.getAbsolutePath(), + 2, + temporaryFolder, + zkCluster.getConnect(), + new Configuration()); + } + + @After + public void cleanup() throws Exception { + cluster.shutdown(); + zkCluster.shutdown(); + } + + @Test + public void testRecoverTaskManager() throws Exception { + cluster.start(); + + assertThat(cluster.getNumTaskManagersConnected(), is(2)); + + cluster.killTaskManager(0); + CommonTestUtils.waitUntilCondition( + () -> cluster.getNumTaskManagersConnected() == 1, + Deadline.fromNow(Duration.ofMinutes(5)), + "timeout."); + assertThat(cluster.isTaskManagerAlive(0), is(false)); + + // recover + cluster.recoverTaskManager(0); + CommonTestUtils.waitUntilCondition( + () -> cluster.getNumTaskManagersConnected() == 2, + Deadline.fromNow(Duration.ofMinutes(5)), + "timeout."); + assertThat(cluster.isTaskManagerAlive(0), is(true)); + } +} diff --git a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/flinkcluster/TaskManagerProcess.java b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/flinkcluster/TaskManagerProcess.java new file mode 100644 index 00000000..389164b8 --- /dev/null +++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/flinkcluster/TaskManagerProcess.java @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.e2e.flinkcluster; + +import com.alibaba.flink.shuffle.e2e.TestJvmProcess; +import com.alibaba.flink.shuffle.e2e.utils.LogErrorHandler; +import com.alibaba.flink.shuffle.plugin.RemoteShuffleEnvironment; + +import org.apache.flink.api.java.utils.ParameterTool; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.TaskManagerOptions; +import org.apache.flink.core.plugin.PluginManager; +import org.apache.flink.core.plugin.PluginUtils; +import org.apache.flink.runtime.blob.BlobCacheService; +import org.apache.flink.runtime.clusterframework.TaskExecutorProcessSpec; +import org.apache.flink.runtime.clusterframework.TaskExecutorProcessUtils; +import org.apache.flink.runtime.clusterframework.types.ResourceID; +import org.apache.flink.runtime.externalresource.ExternalResourceInfoProvider; +import org.apache.flink.runtime.heartbeat.HeartbeatServices; +import org.apache.flink.runtime.highavailability.HighAvailabilityServices; +import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; +import org.apache.flink.runtime.metrics.MetricRegistry; +import org.apache.flink.runtime.rpc.FatalErrorHandler; +import org.apache.flink.runtime.rpc.RpcService; +import org.apache.flink.runtime.taskexecutor.TaskExecutor; +import org.apache.flink.runtime.taskexecutor.TaskExecutorToServiceAdapter; +import org.apache.flink.runtime.taskexecutor.TaskManagerRunner; +import org.apache.flink.runtime.util.ZooKeeperUtils; + +import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFramework; +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.lang.reflect.Field; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Map; + +/** A process of a task manager. 
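+ * + * <p>Started like any other {@link TestJvmProcess} (a sketch): + * + * <pre>{@code + * TaskManagerProcess taskManager = new TaskManagerProcess(logDir, flinkConf, 0); + * taskManager.startProcess(); + * }</pre>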
*/ +public class TaskManagerProcess extends TestJvmProcess { + + private static final Logger LOG = LoggerFactory.getLogger(TaskManagerProcess.class); + + private final String[] jvmArgs; + + private final int index; + + public TaskManagerProcess(String logDirName, Configuration config, int index) throws Exception { + super("TaskExecutor-" + index, logDirName); + Configuration configurationWithFallback = + TaskExecutorProcessUtils.getConfigurationMapLegacyTaskManagerHeapSizeToConfigOption( + config, TaskManagerOptions.TOTAL_FLINK_MEMORY); + + final TaskExecutorProcessSpec processSpec = + TaskExecutorProcessUtils.processSpecFromConfig(configurationWithFallback); + config.set(TaskManagerOptions.MANAGED_MEMORY_SIZE, processSpec.getManagedMemorySize()); + config.set(TaskManagerOptions.NETWORK_MEMORY_MIN, processSpec.getNetworkMemSize()); + config.set(TaskManagerOptions.NETWORK_MEMORY_MAX, processSpec.getNetworkMemSize()); + config.set(TaskManagerOptions.TASK_HEAP_MEMORY, processSpec.getTaskHeapSize()); + config.set(TaskManagerOptions.TASK_OFF_HEAP_MEMORY, processSpec.getTaskOffHeapSize()); + config.set( + TaskManagerOptions.FRAMEWORK_HEAP_MEMORY, + processSpec.getFlinkMemory().getFrameworkHeap()); + config.set( + TaskManagerOptions.FRAMEWORK_OFF_HEAP_MEMORY, + processSpec.getFlinkMemory().getFrameworkOffHeap()); + config.set(TaskManagerOptions.JVM_OVERHEAD_MIN, processSpec.getJvmOverheadSize()); + config.set(TaskManagerOptions.JVM_OVERHEAD_MAX, processSpec.getJvmOverheadSize()); + + ArrayList<String> args = new ArrayList<>(); + + config.setInteger("taskmanager.index", index); + for (Map.Entry<String, String> entry : config.toMap().entrySet()) { + args.add("--" + entry.getKey()); + args.add(entry.getValue()); + } + + this.jvmArgs = new String[args.size()]; + args.toArray(jvmArgs); + this.index = index; + + setJVMHeapMemory(processSpec.getJvmHeapMemorySize().getMebiBytes()); + setJvmDirectMemory(processSpec.getJvmDirectMemorySize().getMebiBytes()); + LOG.info( + "JVM Process {} with process memory spec {}, heap memory {}, direct memory {}", + getName(), + processSpec, + processSpec.getJvmHeapMemorySize().getMebiBytes(), + processSpec.getJvmDirectMemorySize().getMebiBytes()); + } + + @Override + public String[] getJvmArgs() { + return jvmArgs; + } + + @Override + public String getEntryPointClassName() { + return EntryPoint.class.getName(); + } + + /** Entry point for the TaskManager process.
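+ * + * <p>It starts a {@link TaskManagerRunner} with a custom task executor service factory + * so that, besides the normal task executor, a daemon thread can publish the shuffle + * environment's available-buffer count to ZooKeeper for the leak check in + * {@code FlinkLocalCluster#checkResourceReleased()}.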
*/ + public static class EntryPoint { + + private static final Logger LOG = LoggerFactory.getLogger(EntryPoint.class); + + public static void main(String[] args) { + try { + final ParameterTool parameterTool = ParameterTool.fromArgs(args); + Configuration cfg = parameterTool.getConfiguration(); + runTaskManager(cfg, PluginUtils.createPluginManagerFromRootFolder(cfg)); + } catch (Throwable t) { + LOG.error("Failed to run the TaskManager process", t); + System.exit(1); + } + } + + public static void runTaskManager(Configuration cfg, PluginManager pluginManager) + throws Exception { + TaskManagerRunner.TaskExecutorServiceFactory factory = + EntryPoint::createTaskExecutorService; + TaskManagerRunner taskManagerRunner = + new TaskManagerRunner(cfg, pluginManager, factory); + taskManagerRunner.start(); + } + + public static TaskManagerRunner.TaskExecutorService createTaskExecutorService( + Configuration configuration, + ResourceID resourceID, + RpcService rpcService, + HighAvailabilityServices highAvailabilityServices, + HeartbeatServices heartbeatServices, + MetricRegistry metricRegistry, + BlobCacheService blobCacheService, + boolean localCommunicationOnly, + ExternalResourceInfoProvider externalResourceInfoProvider, + FatalErrorHandler fatalErrorHandler) + throws Exception { + + final TaskExecutor taskExecutor = + TaskManagerRunner.startTaskManager( + configuration, + resourceID, + rpcService, + highAvailabilityServices, + heartbeatServices, + metricRegistry, + blobCacheService, + localCommunicationOnly, + externalResourceInfoProvider, + fatalErrorHandler); + startMetricsReportingDaemon(configuration, taskExecutor); + return TaskExecutorToServiceAdapter.createFor(taskExecutor); + } + + public static void startMetricsReportingDaemon( + Configuration conf, TaskExecutor taskExecutor) throws Exception { + Class fooClass = (Class) taskExecutor.getClass(); + Field field = fooClass.getDeclaredField("shuffleEnvironment"); + field.setAccessible(true); + RemoteShuffleEnvironment env = (RemoteShuffleEnvironment) field.get(taskExecutor); + NetworkBufferPool bufferPool = env.getNetworkBufferPool(); + CuratorFramework zkClient = + ZooKeeperUtils.startCuratorFramework(conf, LogErrorHandler.INSTANCE); + int index = conf.getInteger("taskmanager.index", -1); + String zkPath = "/taskmanager-" + index; + if (zkClient.checkExists().forPath(zkPath) != null) { + zkClient.delete().forPath(zkPath); + } + zkClient.create().forPath(zkPath); + Thread thread = + new Thread( + () -> { + try { + while (true) { + int num = bufferPool.getNumberOfAvailableMemorySegments(); + TaskManagerStats tmStats = new TaskManagerStats(num); + zkClient.setData().forPath(zkPath, tmStats.toBytes()); + Thread.sleep(500); + } + } catch (Throwable t) { + LOG.warn("Exception when reporting metrics.", t); + System.exit(-1); + } + }); + thread.setDaemon(true); + thread.start(); + } + } + + static class TaskManagerStats { + + int availableBuffers; + + TaskManagerStats(int availableBuffers) { + this.availableBuffers = availableBuffers; + } + + byte[] toBytes() { + ByteBuf buf = Unpooled.buffer(); + buf.writeInt(availableBuffers); + return buf.array(); + } + + static TaskManagerStats fromBytes(byte[] bytes) { + ByteBuffer wrapped = ByteBuffer.wrap(bytes); + int numAvailableBuffers = wrapped.getInt(); + return new TaskManagerStats(numAvailableBuffers); + } + } +} diff --git a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/shufflecluster/LocalShuffleCluster.java 
b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/shufflecluster/LocalShuffleCluster.java new file mode 100644 index 00000000..4706d23b --- /dev/null +++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/shufflecluster/LocalShuffleCluster.java @@ -0,0 +1,398 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.e2e.shufflecluster; + +import com.alibaba.flink.shuffle.client.ShuffleManagerClient; +import com.alibaba.flink.shuffle.client.ShuffleManagerClientConfiguration; +import com.alibaba.flink.shuffle.client.ShuffleManagerClientImpl; +import com.alibaba.flink.shuffle.client.ShuffleWorkerStatusListener; +import com.alibaba.flink.shuffle.common.config.ConfigOption; +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.config.MemorySize; +import com.alibaba.flink.shuffle.common.functions.RunnableWithException; +import com.alibaba.flink.shuffle.common.handler.FatalErrorHandler; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatServicesUtils; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServiceUtils; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices; +import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionCoordinate; +import com.alibaba.flink.shuffle.coordinator.utils.RandomIDUtils; +import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerMetricKeys; +import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerMetrics; +import com.alibaba.flink.shuffle.core.config.HeartbeatOptions; +import com.alibaba.flink.shuffle.core.config.HighAvailabilityOptions; +import com.alibaba.flink.shuffle.core.config.ManagerOptions; +import com.alibaba.flink.shuffle.core.config.MemoryOptions; +import com.alibaba.flink.shuffle.core.config.WorkerOptions; +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.e2e.TestJvmProcess; +import com.alibaba.flink.shuffle.e2e.utils.CommonTestUtils; +import com.alibaba.flink.shuffle.e2e.zookeeper.ZooKeeperTestUtils; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcService; +import com.alibaba.flink.shuffle.rpc.utils.AkkaRpcServiceUtils; +import com.alibaba.flink.shuffle.rpc.utils.RpcUtils; +import com.alibaba.flink.shuffle.transfer.AbstractNettyTest; + +import org.apache.flink.api.common.time.Deadline; + +import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFramework; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.StringWriter; +import java.nio.file.Files; +import java.nio.file.Path; +import 
java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +/** + * A testing local shuffle cluster, which contains a shuffle manager process and several shuffle + * worker processes. + */ +public class LocalShuffleCluster { + + private static final Logger LOG = LoggerFactory.getLogger(LocalShuffleCluster.class); + + private final String logDirName; + + private final Configuration config; + + private final MemorySize bufferSize; + + private final int numWorkers; + + private final Path parentDataDir; + + private final String zkConnect; + + public ShuffleManagerProcess shuffleManager; + + public ShuffleWorkerProcess[] shuffleWorkers; + + public ShuffleManagerClient shuffleManagerClient; + + private CuratorFramework zkClient; + + private RemoteShuffleRpcService rpcService; + + public LocalShuffleCluster( + String logDirName, + int numWorkers, + String zkConnect, + Path parentDataDir, + Configuration otherConfigurations) { + this.logDirName = logDirName; + this.numWorkers = numWorkers; + this.shuffleWorkers = new ShuffleWorkerProcess[numWorkers]; + this.zkConnect = zkConnect; + this.parentDataDir = parentDataDir; + this.config = new Configuration(otherConfigurations); + + this.bufferSize = config.getMemorySize(MemoryOptions.MEMORY_BUFFER_SIZE); + config.setMemorySize( + MemoryOptions.MEMORY_SIZE_FOR_DATA_READING, MemoryOptions.MIN_VALID_MEMORY_SIZE); + config.setMemorySize( + MemoryOptions.MEMORY_SIZE_FOR_DATA_WRITING, MemoryOptions.MIN_VALID_MEMORY_SIZE); + } + + public void start() throws Exception { + LOG.info( + "Starting the local shuffle cluster with parent data directory {}.", parentDataDir); + initConfig(); + shuffleManager = new ShuffleManagerProcess(logDirName, config); + shuffleManager.startProcess(); + + for (int i = 0; i < numWorkers; i++) { + String workerDataDir = "ShuffleWorker-" + i; + Files.createDirectory(parentDataDir.resolve(workerDataDir)); + shuffleWorkers[i] = + new ShuffleWorkerProcess( + logDirName, + new Configuration(config), + i, + AbstractNettyTest.getAvailablePort(), + getDataDirForWorker(i)); + shuffleWorkers[i].startProcess(); + } + + rpcService = + AkkaRpcServiceUtils.remoteServiceBuilder(config, "localhost", "0").createAndStart(); + + HaServices haServices = HaServiceUtils.createHAServices(config); + + shuffleManagerClient = + new ShuffleManagerClientImpl( + RandomIDUtils.randomJobId(), + new ShuffleWorkerStatusListener() { + @Override + public void notifyIrrelevantWorker(InstanceID workerID) {} + + @Override + public void notifyRelevantWorker( + InstanceID workerID, + Set dataPartitions) {} + }, + rpcService, + (Throwable throwable) -> new ShutDownFatalErrorHandler(), + ShuffleManagerClientConfiguration.fromConfiguration(config), + haServices, + HeartbeatServicesUtils.createManagerJobHeartbeatServices(config)); + shuffleManagerClient.start(); + + waitForShuffleClusterReady(Duration.ofMinutes(5)); + + LOG.info("Local shuffle cluster started successfully."); + } + + public String getDataDirForWorker(int index) { + return new File(parentDataDir.toFile(), "ShuffleWorker-" + index).getAbsolutePath(); + } + + public void printProcessLog() { + shuffleManager.printProcessLog(); + for (TestJvmProcess worker : shuffleWorkers) { + worker.printProcessLog(); + 
} + } + + public void shutdown() throws Exception { + for (int i = 0; i < numWorkers; i++) { + shuffleWorkers[i].destroy(); + } + + shuffleManager.destroy(); + + RpcUtils.terminateRpcService(rpcService, 30000L); + + ZooKeeperTestUtils.deleteAll(zkClient); + zkClient.close(); + } + + public Optional<Integer> findShuffleWorker(int processId) { + for (int i = 0; i < shuffleWorkers.length; ++i) { + if (shuffleWorkers[i].getProcessId() == processId) { + return Optional.of(i); + } + } + + return Optional.empty(); + } + + public void killShuffleWorker(int index) { + shuffleWorkers[index].destroy(); + } + + public void killShuffleWorkerForcibly(int index) throws Exception { + Process kill = Runtime.getRuntime().exec("kill -9 " + shuffleWorkers[index].getProcessId()); + kill.waitFor(); + } + + public boolean isShuffleWorkerAlive(int index) { + return shuffleWorkers[index].isAlive(); + } + + public void recoverShuffleWorker(int index) throws Exception { + checkState( + !isShuffleWorkerAlive(index), + String.format( + "ShuffleWorker %s is alive now, should not be recovered.", + shuffleWorkers[index].getName())); + shuffleWorkers[index] = + new ShuffleWorkerProcess( + logDirName, + config, + index, + shuffleWorkers[index].getDataPort(), + shuffleWorkers[index].getDataDir()); + shuffleWorkers[index].startProcess(); + } + + public void killShuffleManager() { + shuffleManager.destroy(); + } + + public void recoverShuffleManager() throws Exception { + checkState( + !shuffleManager.isAlive(), + "ShuffleManager is still alive, should not be recovered"); + + shuffleManager = new ShuffleManagerProcess(logDirName, config); + shuffleManager.startProcess(); + } + + public Configuration getConfig() { + return config; + } + + public void checkResourceReleased() throws Exception { + CommonTestUtils.delayCheck( + () -> { + checkStorageResourceReleased(); + checkNetworkReleased(); + checkBuffersReleased(); + }, + Deadline.fromNow(Duration.ofMinutes(5)), + "timeout"); + } + + public void checkStorageResourceReleased() { + for (int i = 0; i < shuffleWorkers.length; ++i) { + if (!shuffleWorkers[i].isAlive()) { + continue; + } + String dataDir = getDataDirForWorker(i); + Path dataPath = new File(dataDir).toPath(); + Path metaPath = new File(dataDir + "/_meta").toPath(); + checkUntil( + () -> { + List<Path> paths = Files.list(dataPath).collect(Collectors.toList()); + if (!(paths.size() == 1 && paths.get(0).equals(metaPath))) { + StringBuilder msgBuilder = new StringBuilder("Leaking files:"); + paths.forEach(path -> msgBuilder.append(" " + path)); + throw new AssertionError(msgBuilder.toString()); + } + if (Files.exists(metaPath)) { + List<Path> metas = Files.list(metaPath).collect(Collectors.toList()); + if (!metas.isEmpty()) { + StringBuilder msgBuilder = new StringBuilder("Leaking metas:"); + metas.forEach(path -> msgBuilder.append(" " + path)); + throw new AssertionError(msgBuilder.toString()); + } + } + }); + } + } + + public void checkNetworkReleased() throws Exception { + for (ShuffleWorkerProcess worker : shuffleWorkers) { + if (!worker.isAlive()) { + continue; + } + int dataPort = worker.getDataPort(); + String cmd = "netstat -ant"; + Process process = Runtime.getRuntime().exec(cmd); + StringWriter processOutput = new StringWriter(); + new CommonTestUtils.PipeForwarder(process.getInputStream(), processOutput).join(); + String[] lines = processOutput.toString().split("\n"); + for (String line : lines) { + String pattern = String.format(".*[^0-9]%d[^0-9].*ESTABLISHED.*", dataPort); + if (Pattern.matches(pattern, line)) {
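+ // An ESTABLISHED connection still referencing the worker's data port at + // this point indicates a leaked network connection (e.g. an unclosed channel).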
+                    throw new AssertionError("Residual network connection: " + line);
+                }
+            }
+        }
+    }
+
+    public void checkBuffersReleased() {
+        checkUntil(
+                () -> {
+                    int numReadingBuffers =
+                            calculateNumBuffers(MemoryOptions.MEMORY_SIZE_FOR_DATA_READING);
+                    int numWritingBuffers =
+                            calculateNumBuffers(MemoryOptions.MEMORY_SIZE_FOR_DATA_WRITING);
+
+                    Collection<ShuffleWorkerMetrics> shuffleWorkerMetrics =
+                            shuffleManagerClient.getShuffleWorkerMetrics().get().values();
+                    for (ShuffleWorkerMetrics metric : shuffleWorkerMetrics) {
+                        // check reading buffers.
+                        assertThat(
+                                metric.getIntegerMetric(
+                                        ShuffleWorkerMetricKeys.AVAILABLE_READING_BUFFERS_KEY),
+                                is(numReadingBuffers));
+
+                        // check writing buffers.
+                        assertThat(
+                                metric.getIntegerMetric(
+                                        ShuffleWorkerMetricKeys.AVAILABLE_WRITING_BUFFERS_KEY),
+                                is(numWritingBuffers));
+                    }
+                });
+    }
+
+    private int calculateNumBuffers(ConfigOption<MemorySize> memorySizeOption) {
+        MemorySize memorySize = config.getMemorySize(memorySizeOption);
+        return CommonUtils.checkedDownCast(memorySize.getBytes() / bufferSize.getBytes());
+    }
+
+    private void checkUntil(RunnableWithException runnable) {
+        Throwable lastThrowable = null;
+        for (int i = 0; i < 100; i++) {
+            try {
+                runnable.run();
+                return;
+            } catch (Throwable t) {
+                lastThrowable = t;
+                try {
+                    Thread.sleep(2000);
+                } catch (Exception ignored) {
+                    // ignore interruptions between retries.
+                }
+            }
+        }
+        ExceptionUtils.rethrowAsRuntimeException(lastThrowable);
+    }
+
+    private void initConfig() throws Exception {
+        config.addAll(ZooKeeperTestUtils.createZooKeeperHAConfig(zkConnect));
+        config.setString(HighAvailabilityOptions.HA_MODE, "zookeeper");
+        config.setMemorySize(ManagerOptions.FRAMEWORK_HEAP_MEMORY, MemorySize.ofMebiBytes(128));
+        config.setMemorySize(ManagerOptions.FRAMEWORK_OFF_HEAP_MEMORY, MemorySize.ofMebiBytes(32));
+        config.setMemorySize(WorkerOptions.FRAMEWORK_HEAP_MEMORY, MemorySize.ofMebiBytes(128));
+        config.setMemorySize(WorkerOptions.FRAMEWORK_OFF_HEAP_MEMORY, MemorySize.ofMebiBytes(32));
+        config.setDuration(HeartbeatOptions.HEARTBEAT_JOB_INTERVAL, Duration.ofSeconds(3));
+        config.setDuration(HeartbeatOptions.HEARTBEAT_WORKER_INTERVAL, Duration.ofSeconds(3));
+
+        zkClient = ZooKeeperTestUtils.createZKClientForRemoteShuffle(config);
+    }
+
+    private void waitForShuffleClusterReady(Duration timeout) throws Exception {
+        final String errorMsg =
+                "LocalShuffleCluster is not ready within " + timeout.toMillis() + "ms.";
+        CommonTestUtils.waitUntilCondition(
+                () -> shuffleManagerClient.getNumberOfRegisteredWorkers().get() >= numWorkers,
+                Deadline.fromNow(timeout),
+                errorMsg);
+    }
+
+    private class ShutDownFatalErrorHandler implements FatalErrorHandler {
+
+        @Override
+        public void onFatalError(Throwable exception) {
+            LOG.warn(
+                    "Error in LocalShuffleCluster. Shutting the LocalShuffleCluster down.",
+                    exception);
+            try {
+                shutdown();
+            } catch (Throwable throwable) {
+                LOG.warn("Error while shutting down the LocalShuffleCluster.", throwable);
+            }
+        }
+    }
+}
diff --git a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/shufflecluster/LocalShuffleClusterE2ETest.java b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/shufflecluster/LocalShuffleClusterE2ETest.java
new file mode 100644
index 00000000..0dda9743
--- /dev/null
+++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/shufflecluster/LocalShuffleClusterE2ETest.java
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.e2e.shufflecluster; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.config.MemorySize; +import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerMetricKeys; +import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerMetrics; +import com.alibaba.flink.shuffle.core.config.MemoryOptions; +import com.alibaba.flink.shuffle.e2e.utils.CommonTestUtils; +import com.alibaba.flink.shuffle.e2e.zookeeper.ZooKeeperTestEnvironment; + +import org.apache.flink.api.common.time.Deadline; + +import org.apache.commons.io.FileUtils; +import org.hamcrest.MatcherAssert; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.rules.TestName; + +import java.io.File; +import java.io.IOException; +import java.time.Duration; +import java.util.Collection; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +/** Test for {@link LocalShuffleCluster}. 
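+ *
+ * <p>Each test boots a single-node ZooKeeper quorum and a {@link LocalShuffleCluster} with two
+ * shuffle workers in separate JVMs. A typical interaction, sketched here for illustration, is:
+ *
+ * <pre>{@code
+ * cluster.start();
+ * cluster.killShuffleWorkerForcibly(0);
+ * cluster.recoverShuffleWorker(0);
+ * cluster.shutdown();
+ * }</pre>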
*/ +public class LocalShuffleClusterE2ETest { + + @Rule public TemporaryFolder temporaryFolder = new TemporaryFolder(); + + @Rule public TestName name = new TestName(); + + private ZooKeeperTestEnvironment zkCluster; + + private LocalShuffleCluster cluster; + + @Before + public void setup() throws IOException { + File logDir = + new File( + System.getProperty("buildDirectory") + + "/" + + getClass().getSimpleName() + + "-" + + name.getMethodName()); + if (logDir.exists()) { + FileUtils.deleteDirectory(logDir); + } + + zkCluster = new ZooKeeperTestEnvironment(1); + cluster = + new LocalShuffleCluster( + logDir.getAbsolutePath(), + 2, + zkCluster.getConnect(), + temporaryFolder.newFolder().toPath(), + new Configuration()); + } + + @After + public void cleanup() throws Exception { + cluster.shutdown(); + zkCluster.shutdown(); + } + + @Test(timeout = 300000L) + public void testKillShuffleWorker() throws Exception { + + cluster.start(); + MatcherAssert.assertThat( + cluster.shuffleManagerClient.getNumberOfRegisteredWorkers().get(), is(2)); + + cluster.killShuffleWorker(0); + + CommonTestUtils.waitUntilCondition( + () -> cluster.shuffleManagerClient.getNumberOfRegisteredWorkers().get() == 1, + Deadline.fromNow(Duration.ofMinutes(5)), + "timeout."); + } + + @Test(timeout = 300000L) + public void testKillShuffleWorkerForcibly() throws Exception { + + cluster.start(); + MatcherAssert.assertThat( + cluster.shuffleManagerClient.getNumberOfRegisteredWorkers().get(), is(2)); + + cluster.killShuffleWorkerForcibly(0); + + CommonTestUtils.waitUntilCondition( + () -> cluster.shuffleManagerClient.getNumberOfRegisteredWorkers().get() == 1, + Deadline.fromNow(Duration.ofMinutes(5)), + "timeout."); + } + + @Test(timeout = 300000L) + public void testRecoverShuffleWorker() throws Exception { + cluster.start(); + MatcherAssert.assertThat( + cluster.shuffleManagerClient.getNumberOfRegisteredWorkers().get(), is(2)); + + cluster.killShuffleWorkerForcibly(0); + CommonTestUtils.waitUntilCondition( + () -> cluster.shuffleManagerClient.getNumberOfRegisteredWorkers().get() == 1, + Deadline.fromNow(Duration.ofMinutes(5)), + "timeout."); + assertThat(cluster.isShuffleWorkerAlive(0), is(false)); + + // recover + cluster.recoverShuffleWorker(0); + CommonTestUtils.waitUntilCondition( + () -> cluster.shuffleManagerClient.getNumberOfRegisteredWorkers().get() == 2, + Deadline.fromNow(Duration.ofMinutes(5)), + "timeout."); + assertThat(cluster.isShuffleWorkerAlive(0), is(true)); + } + + @Test(timeout = 300000L) + public void testGetShuffleWorkerMetrics() throws Exception { + cluster.start(); + Collection shuffleWorkerMetrics = + cluster.shuffleManagerClient.getShuffleWorkerMetrics().get().values(); + for (ShuffleWorkerMetrics metric : shuffleWorkerMetrics) { + Configuration configuration = cluster.getConfig(); + MemorySize bufferSize = configuration.getMemorySize(MemoryOptions.MEMORY_BUFFER_SIZE); + + int actualR = + metric.getIntegerMetric(ShuffleWorkerMetricKeys.AVAILABLE_READING_BUFFERS_KEY); + MemorySize rMemorySize = + configuration.getMemorySize(MemoryOptions.MEMORY_SIZE_FOR_DATA_READING); + int expectR = (int) (rMemorySize.getBytes() / bufferSize.getBytes()); + assertThat(actualR, is(expectR)); + + int actualW = + metric.getIntegerMetric(ShuffleWorkerMetricKeys.AVAILABLE_WRITING_BUFFERS_KEY); + MemorySize wMemorySize = + configuration.getMemorySize(MemoryOptions.MEMORY_SIZE_FOR_DATA_WRITING); + int expectW = (int) (wMemorySize.getBytes() / bufferSize.getBytes()); + assertThat(actualW, is(expectW)); + } + } +} diff --git 
a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/shufflecluster/LocalShuffleClusterUtils.java b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/shufflecluster/LocalShuffleClusterUtils.java new file mode 100644 index 00000000..9472fb15 --- /dev/null +++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/shufflecluster/LocalShuffleClusterUtils.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.e2e.shufflecluster; + +import com.alibaba.flink.shuffle.common.config.Configuration; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** Utils for {@link LocalShuffleCluster}. */ +public class LocalShuffleClusterUtils { + static String[] generateDynamicConfigs(Configuration configuration) { + List dynamicConfigs = new ArrayList<>(); + configuration.toMap().entrySet().stream() + .sorted(Map.Entry.comparingByKey()) + .forEach( + kv -> { + dynamicConfigs.add("-D"); + dynamicConfigs.add(String.format("%s=%s", kv.getKey(), kv.getValue())); + }); + + return dynamicConfigs.toArray(new String[dynamicConfigs.size()]); + } +} diff --git a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/shufflecluster/ShuffleManagerProcess.java b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/shufflecluster/ShuffleManagerProcess.java new file mode 100644 index 00000000..7ab647ab --- /dev/null +++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/shufflecluster/ShuffleManagerProcess.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.e2e.shufflecluster; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleManager; +import com.alibaba.flink.shuffle.coordinator.manager.entrypoint.ShuffleManagerEntrypoint; +import com.alibaba.flink.shuffle.coordinator.utils.ClusterEntrypointUtils; +import com.alibaba.flink.shuffle.coordinator.utils.EnvironmentInformation; +import com.alibaba.flink.shuffle.core.config.ManagerOptions; +import com.alibaba.flink.shuffle.core.config.memory.ShuffleManagerProcessSpec; +import com.alibaba.flink.shuffle.e2e.TestJvmProcess; + +import org.apache.flink.runtime.util.JvmShutdownSafeguard; +import org.apache.flink.runtime.util.SignalHandler; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** A {@link ShuffleManager} instance running in a separate JVM. */ +public class ShuffleManagerProcess extends TestJvmProcess { + + private final String[] jvmArgs; + + public ShuffleManagerProcess(String logDirName, Configuration configuration) throws Exception { + super("ShuffleManager", logDirName); + this.jvmArgs = LocalShuffleClusterUtils.generateDynamicConfigs(configuration); + + ShuffleManagerProcessSpec processSpec = new ShuffleManagerProcessSpec(configuration); + setJvmDirectMemory(processSpec.getJvmDirectMemorySize().getMebiBytes()); + setJVMHeapMemory(processSpec.getJvmHeapMemorySize().getMebiBytes()); + } + + @Override + public String[] getJvmArgs() { + return jvmArgs; + } + + @Override + public String getEntryPointClassName() { + return EntryPoint.class.getName(); + } + + /** Entry point for the ShuffleManager process. */ + public static class EntryPoint { + + private static final Logger LOG = LoggerFactory.getLogger(EntryPoint.class); + + public static void main(String[] args) throws Exception { + // startup checks and logging + EnvironmentInformation.logEnvironmentInfo(LOG, "Shuffle Manager", args); + SignalHandler.register(LOG); + JvmShutdownSafeguard.installAsShutdownHook(LOG); + + Configuration configuration = ClusterEntrypointUtils.parseParametersOrExit(args); + + configuration.setString(ManagerOptions.RPC_ADDRESS, "127.0.0.1"); + ShuffleManagerEntrypoint shuffleManagerEntrypoint = + new ShuffleManagerEntrypoint(configuration); + + ShuffleManagerEntrypoint.runShuffleManagerEntrypoint(shuffleManagerEntrypoint); + } + } +} diff --git a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/shufflecluster/ShuffleWorkerProcess.java b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/shufflecluster/ShuffleWorkerProcess.java new file mode 100644 index 00000000..ae871bf8 --- /dev/null +++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/shufflecluster/ShuffleWorkerProcess.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.e2e.shufflecluster; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorker; +import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerRunner; +import com.alibaba.flink.shuffle.core.config.StorageOptions; +import com.alibaba.flink.shuffle.core.config.TransferOptions; +import com.alibaba.flink.shuffle.core.config.memory.ShuffleWorkerProcessSpec; +import com.alibaba.flink.shuffle.e2e.TestJvmProcess; + +/** A {@link ShuffleWorker} instance running in a separate JVM. */ +public class ShuffleWorkerProcess extends TestJvmProcess { + + private final String[] jvmArgs; + + private final int dataPort; + + private final String dataDir; + + public ShuffleWorkerProcess( + String logDirName, Configuration configuration, int index, int dataPort, String dataDir) + throws Exception { + super("ShuffleWorker-" + index, logDirName); + this.dataPort = dataPort; + this.dataDir = dataDir; + configuration.setInteger(TransferOptions.SERVER_DATA_PORT, dataPort); + configuration.setString(StorageOptions.STORAGE_LOCAL_DATA_DIRS, dataDir); + this.jvmArgs = LocalShuffleClusterUtils.generateDynamicConfigs(configuration); + + ShuffleWorkerProcessSpec processSpec = new ShuffleWorkerProcessSpec(configuration); + setJvmDirectMemory(processSpec.getJvmDirectMemorySize().getMebiBytes()); + setJVMHeapMemory(processSpec.getJvmHeapMemorySize().getMebiBytes()); + } + + @Override + public String[] getJvmArgs() { + return jvmArgs; + } + + @Override + public String getEntryPointClassName() { + return ShuffleWorkerRunner.class.getName(); + } + + public int getDataPort() { + return dataPort; + } + + public String getDataDir() { + return dataDir; + } +} diff --git a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/utils/CommonTestUtils.java b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/utils/CommonTestUtils.java new file mode 100644 index 00000000..b695b1bd --- /dev/null +++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/utils/CommonTestUtils.java @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package com.alibaba.flink.shuffle.e2e.utils;
+
+import com.alibaba.flink.shuffle.common.functions.RunnableWithException;
+
+import org.apache.flink.api.common.time.Deadline;
+import org.apache.flink.util.FileUtils;
+import org.apache.flink.util.function.SupplierWithException;
+
+import java.io.BufferedInputStream;
+import java.io.File;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.PrintWriter;
+import java.io.StringWriter;
+import java.lang.management.ManagementFactory;
+import java.lang.management.RuntimeMXBean;
+import java.util.concurrent.TimeoutException;
+
+/** This class contains auxiliary methods for unit tests. */
+public class CommonTestUtils {
+
+    private static final long RETRY_INTERVAL = 100L;
+
+    /**
+     * Gets the classpath with which the current JVM was started.
+     *
+     * @return The classpath with which the current JVM was started.
+     */
+    public static String getCurrentClasspath() {
+        RuntimeMXBean bean = ManagementFactory.getRuntimeMXBean();
+        return bean.getClassPath();
+    }
+
+    /** Creates a temporary log4j configuration file for the test. */
+    public static File createTemporaryLog4JProperties() throws IOException {
+        File log4jProps = File.createTempFile(FileUtils.getRandomFilename(""), "-log4j.properties");
+        log4jProps.deleteOnExit();
+        CommonTestUtils.printLog4jDebugConfig(log4jProps);
+
+        return log4jProps;
+    }
+
+    /** Creates a temporary log file for the test. */
+    public static File createTemporaryLogFile() throws IOException {
+        File logFile = File.createTempFile(FileUtils.getRandomFilename(""), "-log");
+        logFile.deleteOnExit();
+
+        return logFile;
+    }
+
+    /**
+     * Tries to get the java executable command with which the current JVM was started. Returns
+     * null if the command could not be found.
+     *
+     * @return The java executable command.
+ */ + public static String getJavaCommandPath() { + File javaHome = new File(System.getProperty("java.home")); + + String path1 = new File(javaHome, "java").getAbsolutePath(); + String path2 = new File(new File(javaHome, "bin"), "java").getAbsolutePath(); + + try { + ProcessBuilder bld = new ProcessBuilder(path1, "-version"); + Process process = bld.start(); + if (process.waitFor() == 0) { + return path1; + } + } catch (Throwable t) { + // ignore and try the second path + } + + try { + ProcessBuilder bld = new ProcessBuilder(path2, "-version"); + Process process = bld.start(); + if (process.waitFor() == 0) { + return path2; + } + } catch (Throwable tt) { + // no luck + } + return null; + } + + public static void printLog4jDebugConfig(File file) throws IOException { + try (PrintWriter writer = new PrintWriter(new FileWriter(file))) { + writer.println("rootLogger.level = INFO"); + + writer.println("rootLogger.appenderRef.file.ref = MainAppender"); + writer.println("appender.main.name = MainAppender"); + writer.println("appender.main.type = RollingFile"); + writer.println("appender.main.append = true"); + writer.println("appender.main.fileName = ${sys:log.file}"); + writer.println("appender.main.filePattern = ${sys:log.file}.%i"); + writer.println("appender.main.layout.type = PatternLayout"); + writer.println( + "appender.main.layout.pattern = %d{yyyy-MM-dd HH:mm:ss,SSS} %-5p %c{1} - %m%n"); + writer.println("appender.main.policies.type = Policies"); + writer.println("appender.main.policies.size.type = SizeBasedTriggeringPolicy"); + writer.println("appender.main.policies.size.size = 100MB"); + writer.println("appender.main.policies.startup.type = OnStartupTriggeringPolicy"); + writer.println("appender.main.strategy.type = DefaultRolloverStrategy"); + writer.println("appender.main.strategy.max = ${env:MAX_LOG_FILE_NUMBER:-10}"); + + writer.println("logger.jetty.name = org.eclipse.jetty.util.log"); + writer.println("logger.jetty.level = OFF"); + writer.println("logger.zookeeper.name = org.apache.zookeeper"); + writer.println("logger.zookeeper.level = OFF"); + + writer.flush(); + } + } + + public static void delayCheck(RunnableWithException runnable, Deadline timeout, String errorMsg) + throws Exception { + waitUntilCondition( + () -> { + try { + runnable.run(); + } catch (Throwable t) { + return false; + } + return true; + }, + timeout, + errorMsg); + } + + public static void waitUntilCondition( + SupplierWithException condition, Deadline timeout, String errorMsg) + throws Exception { + waitUntilCondition(condition, timeout, RETRY_INTERVAL, errorMsg); + } + + public static void waitUntilCondition( + SupplierWithException condition, + Deadline timeout, + long retryIntervalMillis, + String errorMsg) + throws Exception { + while (timeout.hasTimeLeft() && !condition.get()) { + final long timeLeft = Math.max(0, timeout.timeLeft().toMillis()); + Thread.sleep(Math.min(retryIntervalMillis, timeLeft)); + } + + if (!timeout.hasTimeLeft()) { + throw new TimeoutException(errorMsg); + } + } + + /** Utility class to read the output of a process stream and forward it into a StringWriter. 
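+     *
+     * <p>The forwarder marks itself as a daemon thread and starts in its constructor, so callers
+     * only need to {@code join()} it once the process has terminated.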
*/ + public static class PipeForwarder extends Thread { + + private final StringWriter target; + private final InputStream source; + + public PipeForwarder(InputStream source, StringWriter target) { + super("Pipe Forwarder"); + setDaemon(true); + + this.source = source; + this.target = target; + + start(); + } + + @Override + public void run() { + try { + int next; + while ((next = source.read()) != -1) { + target.write(next); + } + } catch (IOException e) { + // terminate + } + } + } + + public static boolean isStreamContentEqual(InputStream input1, InputStream input2) + throws IOException { + + if (!(input1 instanceof BufferedInputStream)) { + input1 = new BufferedInputStream(input1); + } + if (!(input2 instanceof BufferedInputStream)) { + input2 = new BufferedInputStream(input2); + } + + int ch = input1.read(); + while (-1 != ch) { + int ch2 = input2.read(); + if (ch != ch2) { + return false; + } + ch = input1.read(); + } + + int ch2 = input2.read(); + return (ch2 == -1); + } +} diff --git a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/utils/LogErrorHandler.java b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/utils/LogErrorHandler.java new file mode 100644 index 00000000..ccee7157 --- /dev/null +++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/utils/LogErrorHandler.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.e2e.utils; + +import org.apache.flink.runtime.rpc.FatalErrorHandler; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** A {@link FatalErrorHandler} implementation which just logs the error encountered. */ +public class LogErrorHandler implements FatalErrorHandler { + + private static final Logger LOG = LoggerFactory.getLogger(LogErrorHandler.class); + + public static final LogErrorHandler INSTANCE = new LogErrorHandler(); + + @Override + public void onFatalError(Throwable throwable) { + LOG.error("Error encountered.", throwable); + } +} diff --git a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/zookeeper/ZooKeeperTestEnvironment.java b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/zookeeper/ZooKeeperTestEnvironment.java new file mode 100644 index 00000000..61207e69 --- /dev/null +++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/zookeeper/ZooKeeperTestEnvironment.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.e2e.zookeeper;
+
+import org.apache.curator.test.TestingCluster;
+import org.apache.curator.test.TestingServer;
+
+/** Simple ZooKeeper and CuratorFramework setup for tests. */
+public class ZooKeeperTestEnvironment {
+
+    private final TestingServer zooKeeperServer;
+    private final TestingCluster zooKeeperCluster;
+
+    /**
+     * Starts a ZooKeeper server or cluster with the given number of quorum peers.
+     *
+     * @param numberOfZooKeeperQuorumPeers Starts a {@link TestingServer} if 1, or a {@link
+     *     TestingCluster} if greater than 1.
+     */
+    public ZooKeeperTestEnvironment(int numberOfZooKeeperQuorumPeers) {
+        if (numberOfZooKeeperQuorumPeers <= 0) {
+            throw new IllegalArgumentException("Number of peers needs to be >= 1.");
+        }
+
+        try {
+            if (numberOfZooKeeperQuorumPeers == 1) {
+                zooKeeperServer = new TestingServer(true);
+                zooKeeperCluster = null;
+            } else {
+                zooKeeperServer = null;
+                zooKeeperCluster = new TestingCluster(numberOfZooKeeperQuorumPeers);
+
+                zooKeeperCluster.start();
+            }
+        } catch (Exception e) {
+            throw new RuntimeException("Error setting up ZooKeeperTestEnvironment", e);
+        }
+    }
+
+    /** Shuts down the ZooKeeper server/cluster. */
+    public void shutdown() throws Exception {
+        if (zooKeeperServer != null) {
+            zooKeeperServer.close();
+        }
+
+        if (zooKeeperCluster != null) {
+            zooKeeperCluster.close();
+        }
+    }
+
+    public String getConnect() {
+        if (zooKeeperServer != null) {
+            return zooKeeperServer.getConnectString();
+        } else {
+            return zooKeeperCluster.getConnectString();
+        }
+    }
+}
diff --git a/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/zookeeper/ZooKeeperTestUtils.java b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/zookeeper/ZooKeeperTestUtils.java
new file mode 100644
index 00000000..a6acb03e
--- /dev/null
+++ b/shuffle-e2e-tests/src/test/java/com/alibaba/flink/shuffle/e2e/zookeeper/ZooKeeperTestUtils.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package com.alibaba.flink.shuffle.e2e.zookeeper; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.coordinator.highavailability.zookeeper.ZooKeeperUtils; +import com.alibaba.flink.shuffle.core.config.HighAvailabilityOptions; +import com.alibaba.flink.shuffle.e2e.utils.LogErrorHandler; + +import org.apache.flink.configuration.AkkaOptions; +import org.apache.flink.configuration.CheckpointingOptions; +import org.apache.flink.runtime.jobmanager.HighAvailabilityMode; + +import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFramework; +import org.apache.flink.shaded.curator4.org.apache.curator.utils.ZKPaths; +import org.apache.flink.shaded.zookeeper3.org.apache.zookeeper.KeeperException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** ZooKeeper test utilities. */ +public class ZooKeeperTestUtils { + + private static final Logger LOG = LoggerFactory.getLogger(ZooKeeperTestUtils.class); + + /** + * Creates a configuration to operate in {@link HighAvailabilityMode#ZOOKEEPER}. + * + * @param zooKeeperQuorum ZooKeeper quorum to connect to + * @param fsStateHandlePath Base path for file system state backend (for checkpoints and + * recovery) + * @return A new configuration to operate in {@link HighAvailabilityMode#ZOOKEEPER}. + */ + public static org.apache.flink.configuration.Configuration createZooKeeperHAConfigForFlink( + String zooKeeperQuorum, String fsStateHandlePath) { + + return configureZooKeeperHAForFlink( + new org.apache.flink.configuration.Configuration(), + zooKeeperQuorum, + fsStateHandlePath); + } + + /** + * Sets all necessary configuration keys to operate in {@link HighAvailabilityMode#ZOOKEEPER}. + * + * @param config Configuration to use + * @param zooKeeperQuorum ZooKeeper quorum to connect to + * @param fsStateHandlePath Base path for file system state backend (for checkpoints and + * recovery) + * @return The modified configuration to operate in {@link HighAvailabilityMode#ZOOKEEPER}. 
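+     *     The returned configuration also carries a 60 second ZooKeeper connection/session
+     *     timeout and a filesystem state backend rooted at {@code fsStateHandlePath}.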
+ */ + public static org.apache.flink.configuration.Configuration configureZooKeeperHAForFlink( + org.apache.flink.configuration.Configuration config, + String zooKeeperQuorum, + String fsStateHandlePath) { + + checkNotNull(config); + checkNotNull(zooKeeperQuorum); + checkNotNull(fsStateHandlePath); + + // ZooKeeper recovery mode + config.setString( + org.apache.flink.configuration.HighAvailabilityOptions.HA_MODE, "ZOOKEEPER"); + config.setString( + org.apache.flink.configuration.HighAvailabilityOptions.HA_ZOOKEEPER_QUORUM, + zooKeeperQuorum); + + int connTimeout = 60000; + config.setInteger( + org.apache.flink.configuration.HighAvailabilityOptions.ZOOKEEPER_CONNECTION_TIMEOUT, + connTimeout); + config.setInteger( + org.apache.flink.configuration.HighAvailabilityOptions.ZOOKEEPER_SESSION_TIMEOUT, + connTimeout); + + // File system state backend + config.setString(CheckpointingOptions.STATE_BACKEND, "FILESYSTEM"); + config.setString( + CheckpointingOptions.CHECKPOINTS_DIRECTORY, fsStateHandlePath + "/checkpoints"); + config.setString( + org.apache.flink.configuration.HighAvailabilityOptions.HA_STORAGE_PATH, + fsStateHandlePath + "/recovery"); + + config.setString(AkkaOptions.ASK_TIMEOUT, "100 s"); + + return config; + } + + public static Configuration createZooKeeperHAConfig(String zooKeeperQuorum) { + return configureZooKeeperHA(new Configuration(), zooKeeperQuorum); + } + + public static Configuration configureZooKeeperHA(Configuration config, String zooKeeperQuorum) { + + checkNotNull(config); + checkNotNull(zooKeeperQuorum); + + // ZooKeeper recovery mode + config.setString(HighAvailabilityOptions.HA_MODE, "ZOOKEEPER"); + config.setString(HighAvailabilityOptions.HA_ZOOKEEPER_QUORUM, zooKeeperQuorum); + + int connTimeout = 60; + config.setDuration( + HighAvailabilityOptions.ZOOKEEPER_CONNECTION_TIMEOUT, + Duration.ofSeconds(connTimeout)); + config.setDuration( + HighAvailabilityOptions.ZOOKEEPER_SESSION_TIMEOUT, Duration.ofSeconds(connTimeout)); + + return config; + } + + public static CuratorFramework createZKClientForFlink( + org.apache.flink.configuration.Configuration configuration) { + return org.apache.flink.runtime.util.ZooKeeperUtils.startCuratorFramework( + configuration, LogErrorHandler.INSTANCE); + } + + public static CuratorFramework createZKClientForRemoteShuffle(Configuration configuration) { + return ZooKeeperUtils.startCuratorFramework(configuration); + } + + /** + * Deletes all ZNodes under the root node. + * + * @throws Exception If the ZooKeeper operation fails + */ + public static void deleteAll(CuratorFramework client) throws Exception { + final String path = "/" + client.getNamespace(); + + int maxAttempts = 10; + + for (int i = 0; i < maxAttempts; i++) { + try { + ZKPaths.deleteChildren(client.getZookeeperClient().getZooKeeper(), path, false); + return; + } catch (org.apache.zookeeper.KeeperException.NoNodeException e) { + // that seems all right. if one of the children we want to delete is + // actually already deleted, that's fine. + return; + } catch (KeeperException.ConnectionLossException e) { + // Keep retrying + Thread.sleep(100); + } + } + + throw new Exception( + "Could not clear the ZNodes under " + + path + + ". 
ZooKeeper is not in " + + "a clean state."); + } +} diff --git a/shuffle-e2e-tests/src/test/resources/log4j2-test.properties b/shuffle-e2e-tests/src/test/resources/log4j2-test.properties new file mode 100644 index 00000000..0c41239d --- /dev/null +++ b/shuffle-e2e-tests/src/test/resources/log4j2-test.properties @@ -0,0 +1,31 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +# Set root logger level to OFF to not flood build logs +# set manually to INFO for debugging purposes +rootLogger.level=OFF +rootLogger.appenderRef.test.ref=TestLogger +appender.testlogger.name=TestLogger +appender.testlogger.type=CONSOLE +appender.testlogger.target=SYSTEM_ERR +appender.testlogger.layout.type=PatternLayout +appender.testlogger.layout.pattern=%-4r [%t] %-5p %c %x - %m%n + +logger.zookeeper.name=org.apache.zookeeper +logger.zookeeper.level=OFF +logger.zookeeper_shaded.name=org.apache.flink.shaded.zookeeper3.org.apache.zookeeper +logger.zookeeper_shaded.level=OFF diff --git a/shuffle-examples/pom.xml b/shuffle-examples/pom.xml new file mode 100644 index 00000000..0df1d242 --- /dev/null +++ b/shuffle-examples/pom.xml @@ -0,0 +1,98 @@ + + + + + flink-shuffle-parent + com.alibaba.flink.shuffle + 1.0-SNAPSHOT + + 4.0.0 + + shuffle-examples + + + + org.apache.flink + flink-streaming-java_${scala.binary.version} + ${flink.version} + provided + + + + + + + org.apache.maven.plugins + maven-jar-plugin + 2.2 + + + + test-jar + + + + + BatchJobDemo + package + + jar + + + BatchJobDemo + + + + com.alibaba.flink.shuffle.examples.BatchJobDemo + + + + + com/alibaba/flink/shuffle/examples/BatchJobDemo.class + com/alibaba/flink/shuffle/examples/BatchJobDemo$*.class + + + + + + + + + org.apache.maven.plugins + maven-antrun-plugin + 1.7 + + + rename + package + + run + + + + + + + + + + + + diff --git a/shuffle-examples/src/main/java/com/alibaba/flink/shuffle/examples/BatchJobDemo.java b/shuffle-examples/src/main/java/com/alibaba/flink/shuffle/examples/BatchJobDemo.java new file mode 100644 index 00000000..edb9eba1 --- /dev/null +++ b/shuffle-examples/src/main/java/com/alibaba/flink/shuffle/examples/BatchJobDemo.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.examples; + +import org.apache.flink.runtime.jobgraph.JobType; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.SinkFunction; +import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction; +import org.apache.flink.streaming.api.graph.GlobalStreamExchangeMode; +import org.apache.flink.streaming.api.graph.StreamGraph; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** Batch-shuffle job demo. */ +public class BatchJobDemo { + public static void main(String[] args) throws Exception { + int numRecords = 1024; + int parallelism = 4; + int recordSize = 1024; + int numRecordsToSend = 1024; + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(parallelism); + DataStream source = + env.addSource(new ByteArraySource(numRecordsToSend, recordSize, numRecords)); + source.rebalance() + .addSink( + new SinkFunction() { + @Override + public void invoke(byte[] value) {} + }); + StreamGraph streamGraph = env.getStreamGraph(); + streamGraph.setGlobalStreamExchangeMode(GlobalStreamExchangeMode.ALL_EDGES_BLOCKING); + streamGraph.setJobType(JobType.BATCH); + env.execute(streamGraph); + } + + private static class ByteArraySource implements ParallelSourceFunction { + private final int numRecordsToSend; + private final List records = new ArrayList<>(); + private volatile boolean isRunning = true; + + ByteArraySource(int numRecordsToSend, int recordSize, int numRecords) { + this.numRecordsToSend = numRecordsToSend; + Random random = new Random(); + for (int i = 0; i < numRecords; ++i) { + byte[] record = new byte[recordSize]; + random.nextBytes(record); + records.add(record); + } + } + + @Override + public void run(SourceContext sourceContext) { + int counter = 0; + while (isRunning && counter++ < numRecordsToSend) { + sourceContext.collect(records.get(counter % records.size())); + } + } + + @Override + public void cancel() { + isRunning = false; + } + } +} diff --git a/shuffle-kubernetes-operator/pom.xml b/shuffle-kubernetes-operator/pom.xml new file mode 100644 index 00000000..34885b30 --- /dev/null +++ b/shuffle-kubernetes-operator/pom.xml @@ -0,0 +1,169 @@ + + + + + com.alibaba.flink.shuffle + flink-shuffle-parent + 1.0-SNAPSHOT + + 4.0.0 + + shuffle-kubernetes-operator + + + 5.2.1 + + + + + + com.alibaba.flink.shuffle + shuffle-core + ${project.version} + + + + io.fabric8 + kubernetes-client + ${kubernetes.client.version} + + + + org.apache.commons + commons-lang3 + 3.3.2 + + + + + org.apache.logging.log4j + log4j-slf4j-impl + + + + org.apache.logging.log4j + log4j-api + + + + org.apache.logging.log4j + log4j-core + + + + + org.apache.logging.log4j + log4j-1.2-api + + + + + io.fabric8 + kubernetes-server-mock + ${kubernetes.client.version} + test + + + + + + + + org.apache.maven.plugins + maven-shade-plugin + + + shade-remote-shuffle + package + + shade + + + false + false + 
${project.artifactId}-${project.version} + + + *:* + + *.aut + META-INF/maven/** + META-INF/services/*com.fasterxml* + META-INF/proguard/** + OSGI-INF/** + schema/** + *.vm + *.properties + *.xml + META-INF/jandex.idx + license.header + + + + + + + + + + io.fabric8 + + com.alibaba.flink.shuffle.kubernetes.shaded.io.fabric8 + + + + com.fasterxml.jackson + + com.alibaba.flink.shuffle.kubernetes.shaded.com.fasterxml.jackson + + + + okhttp3 + + com.alibaba.flink.shuffle.kubernetes.shaded.okhttp3 + + + + okio + com.alibaba.flink.shuffle.kubernetes.shaded.okio + + + + org.yaml + + com.alibaba.flink.shuffle.kubernetes.shaded.org.yaml + + + + dk.brics.automaton + + com.alibaba.flink.shuffle.kubernetes.shaded.dk.brics.automaton + + + + + + + + + + + diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/RemoteShuffleApplicationOperatorEntrypoint.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/RemoteShuffleApplicationOperatorEntrypoint.java new file mode 100644 index 00000000..b1f02c63 --- /dev/null +++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/RemoteShuffleApplicationOperatorEntrypoint.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.operator; + +import com.alibaba.flink.shuffle.core.executor.ExecutorThreadFactory; +import com.alibaba.flink.shuffle.kubernetes.operator.controller.RemoteShuffleApplicationController; +import com.alibaba.flink.shuffle.kubernetes.operator.crd.RemoteShuffleApplication; +import com.alibaba.flink.shuffle.kubernetes.operator.util.Constants; + +import io.fabric8.kubernetes.api.model.apiextensions.v1beta1.CustomResourceDefinition; +import io.fabric8.kubernetes.api.model.apiextensions.v1beta1.CustomResourceDefinitionBuilder; +import io.fabric8.kubernetes.client.Config; +import io.fabric8.kubernetes.client.DefaultKubernetesClient; +import io.fabric8.kubernetes.client.KubernetesClient; +import io.fabric8.kubernetes.client.dsl.base.CustomResourceDefinitionContext; +import io.fabric8.kubernetes.client.informers.SharedInformerFactory; +import org.apache.commons.lang3.tuple.Triple; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +/** Entrypoint class of {@link RemoteShuffleApplication} Operator. 
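+ *
+ * <p>On startup the entrypoint creates or replaces the {@code RemoteShuffleApplication} CRD,
+ * registers shared informers for the configured namespace, and then runs the
+ * {@link RemoteShuffleApplicationController} loop until a fatal error occurs.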
+ */
+public class RemoteShuffleApplicationOperatorEntrypoint {
+
+    private static final Logger LOG =
+            LoggerFactory.getLogger(RemoteShuffleApplicationOperatorEntrypoint.class);
+
+    static final List<Triple<String, String, String>> ADDITIONAL_COLUMN =
+            new ArrayList<Triple<String, String, String>>() {
+                {
+                    add(
+                            Triple.of(
+                                    "SHUFFLE_MANAGER_READY",
+                                    "string",
+                                    ".status.readyShuffleManagers"));
+                    add(Triple.of("SHUFFLE_WORKER_READY", "string", ".status.readyShuffleWorkers"));
+                }
+            };
+
+    public static void main(String[] args) {
+        Config config = Config.autoConfigure(null);
+        String namespace = config.getNamespace();
+        if (namespace == null) {
+            LOG.info("No namespace found via config, assuming default.");
+            config.setNamespace("default");
+        }
+        LOG.info("Using namespace {}.", config.getNamespace());
+
+        config.setUserAgent(Constants.REMOTE_SHUFFLE_OPERATOR_USER_AGENT);
+        LOG.info("Setting user agent for Kubernetes client to {}.", config.getUserAgent());
+
+        final KubernetesClient kubeClient = new DefaultKubernetesClient(config);
+        final ExecutorService executorPool =
+                Executors.newFixedThreadPool(10, new ExecutorThreadFactory("informers"));
+        final SharedInformerFactory informerFactory =
+                kubeClient.informers(executorPool).inNamespace(kubeClient.getNamespace());
+        try {
+
+            CustomResourceDefinition shuffleApplicationCRD = createRemoteShuffleApplicationCRD();
+            // create CRD for the flink remote shuffle service.
+            kubeClient
+                    .apiextensions()
+                    .v1beta1()
+                    .customResourceDefinitions()
+                    .createOrReplace(shuffleApplicationCRD);
+
+            RemoteShuffleApplicationController remoteShuffleApplicationController =
+                    RemoteShuffleApplicationController.createRemoteShuffleApplicationController(
+                            kubeClient, informerFactory);
+
+            informerFactory.startAllRegisteredInformers();
+            informerFactory.addSharedInformerEventListener(
+                    exception -> LOG.error("Exception occurred, but caught.", exception));
+            remoteShuffleApplicationController.run();
+        } catch (Throwable throwable) {
+            LOG.error("Remote shuffle application operator terminated.", throwable);
+        } finally {
+            informerFactory.stopAllRegisteredInformers();
+            kubeClient.close();
+        }
+    }
+
+    static CustomResourceDefinition createRemoteShuffleApplicationCRD() {
+        // create CRD builder from context.
+        CustomResourceDefinitionBuilder customCRDBuilder =
+                CustomResourceDefinitionContext.v1beta1CRDFromCustomResourceType(
+                        RemoteShuffleApplication.class);
+
+        // setup additional print columns.
+        ADDITIONAL_COLUMN.forEach(
+                column -> {
+                    customCRDBuilder
+                            .editSpec()
+                            .addNewAdditionalPrinterColumn()
+                            .withName(column.getLeft())
+                            .withType(column.getMiddle())
+                            .withJSONPath(column.getRight())
+                            .endAdditionalPrinterColumn()
+                            .endSpec();
+                });
+
+        // setup status subresource.
+        customCRDBuilder
+                .editOrNewSpec()
+                .editOrNewSubresources()
+                .withNewStatus()
+                .endStatus()
+                .endSubresources()
+                .endSpec();
+
+        return customCRDBuilder.build();
+    }
+}
diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/controller/RemoteShuffleApplicationController.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/controller/RemoteShuffleApplicationController.java
new file mode 100644
index 00000000..4114db31
--- /dev/null
+++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/controller/RemoteShuffleApplicationController.java
@@ -0,0 +1,482 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.operator.controller; + +import com.alibaba.flink.shuffle.kubernetes.operator.crd.RemoteShuffleApplication; +import com.alibaba.flink.shuffle.kubernetes.operator.crd.RemoteShuffleApplicationList; +import com.alibaba.flink.shuffle.kubernetes.operator.crd.RemoteShuffleApplicationSpec; +import com.alibaba.flink.shuffle.kubernetes.operator.crd.RemoteShuffleApplicationStatus; +import com.alibaba.flink.shuffle.kubernetes.operator.reconciler.RemoteShuffleApplicationReconciler; +import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesUtils; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.fabric8.kubernetes.api.model.HasMetadata; +import io.fabric8.kubernetes.api.model.OwnerReference; +import io.fabric8.kubernetes.api.model.apps.DaemonSet; +import io.fabric8.kubernetes.api.model.apps.DaemonSetStatus; +import io.fabric8.kubernetes.api.model.apps.Deployment; +import io.fabric8.kubernetes.api.model.apps.DeploymentStatus; +import io.fabric8.kubernetes.client.KubernetesClient; +import io.fabric8.kubernetes.client.dsl.MixedOperation; +import io.fabric8.kubernetes.client.dsl.Resource; +import io.fabric8.kubernetes.client.informers.ResourceEventHandler; +import io.fabric8.kubernetes.client.informers.SharedIndexInformer; +import io.fabric8.kubernetes.client.informers.SharedInformerFactory; +import io.fabric8.kubernetes.client.informers.cache.Cache; +import io.fabric8.kubernetes.client.informers.cache.Lister; +import io.fabric8.kubernetes.client.utils.Serialization; +import org.apache.commons.lang3.tuple.Pair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Objects; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Controller of {@link RemoteShuffleApplication}. This class is responsible for following things: + * + *
<p>
1. Monitor the addition, modification and deletion requests of {@link + * RemoteShuffleApplication}, and take actions. + * + *
<p>
2. Monitor the status of {@link RemoteShuffleApplication} internal components (ShuffleManager + * and ShuffleWorkers). Once they are not in the desired state, reconciling them by {@link + * RemoteShuffleApplicationReconciler}. + */ +public class RemoteShuffleApplicationController { + + private static final Logger LOG = + LoggerFactory.getLogger(RemoteShuffleApplicationController.class); + private static final long RSYNC_MILLIS = 10 * 60 * 1000; + private static final long UPDATE_STATUS_INTERVAL_MS = 1000L; + + private final BlockingQueue> workQueue; + private final SharedIndexInformer shuffleAppInformer; + private final SharedIndexInformer deploymentInformer; + private final SharedIndexInformer daemonSetInformer; + + private final Lister deploymentLister; + private final Lister daemonSetLister; + private final Lister shuffleAppLister; + + private final MixedOperation< + RemoteShuffleApplication, + RemoteShuffleApplicationList, + Resource> + shuffleAppClient; + + private final RemoteShuffleApplicationReconciler shuffleAppReconciler; + private volatile long lastUpdateStatusTime = 0L; + + AtomicBoolean isRunning = new AtomicBoolean(false); + + private enum ComponentEventType { + ADD, + DELETE, + UPDATE_SPEC, + UPDATE_STATUS + } + + private enum RemoteShuffleAction { + ADD_OR_UPDATE, + UPDATE_STATUS + } + + public RemoteShuffleApplicationController( + SharedIndexInformer shuffleAppInformer, + SharedIndexInformer deploymentInformer, + SharedIndexInformer daemonSetInformer, + KubernetesClient kubeClient, + MixedOperation< + RemoteShuffleApplication, + RemoteShuffleApplicationList, + Resource> + shuffleAppClient) { + + this.shuffleAppInformer = shuffleAppInformer; + this.deploymentInformer = deploymentInformer; + this.daemonSetInformer = daemonSetInformer; + this.shuffleAppClient = shuffleAppClient; + + this.shuffleAppReconciler = new RemoteShuffleApplicationReconciler(kubeClient); + this.shuffleAppLister = new Lister<>(shuffleAppInformer.getIndexer()); + this.deploymentLister = new Lister<>(deploymentInformer.getIndexer()); + this.daemonSetLister = new Lister<>(daemonSetInformer.getIndexer()); + + // TODO: Removing the duplicate key. + this.workQueue = new ArrayBlockingQueue<>(4096); + } + + public void create() { + // Monitor the addition, modification and deletion requests. + shuffleAppInformer.addEventHandler( + new ResourceEventHandler() { + @Override + public void onAdd(RemoteShuffleApplication shuffleApp) { + LOG.info( + "RemoteShuffleApplication {} add.", + KubernetesUtils.getResourceFullName(shuffleApp)); + enqueueRemoteShuffleApplication( + shuffleApp, RemoteShuffleAction.ADD_OR_UPDATE); + } + + @Override + public void onUpdate( + RemoteShuffleApplication oldShuffleApp, + RemoteShuffleApplication newShuffleApp) { + // Only spec changed, we do reconcile. + if (!Objects.equals(oldShuffleApp.getSpec(), newShuffleApp.getSpec())) { + LOG.info( + "RemoteShuffleApplication {} spec update.", + KubernetesUtils.getResourceFullName(newShuffleApp)); + enqueueRemoteShuffleApplication( + newShuffleApp, RemoteShuffleAction.ADD_OR_UPDATE); + } + } + + @Override + public void onDelete(RemoteShuffleApplication shuffleApp, boolean b) { + LOG.info( + "RemoteShuffleApplication {} delete.", + KubernetesUtils.getResourceFullName(shuffleApp)); + // do nothing. + } + }); + + // Monitor the Deployments status. + // Set up an event handler for when Deployment resources change. 
This handler will lookup + // the owner of the given Deployment, and if it is owned by a RemoteShuffleApplication + // resource will enqueue that RemoteShuffleApplication resource for processing. This way, we + // don't need to implement custom logic for handling Deployment resources. More info on this + // pattern: + // https://github.com/kubernetes/community/blob/8cafef897a22026d42f5e5bb3f104febe7e29830/contributors/devel/controllers.md + deploymentInformer.addEventHandler( + new ResourceEventHandler() { + @Override + public void onAdd(Deployment deployment) { + if (LOG.isDebugEnabled()) { + LOG.debug( + "Deployment {} add.", + KubernetesUtils.getResourceFullName(deployment)); + } + handleComponentEvent(deployment, ComponentEventType.ADD); + } + + @Override + public void onUpdate(Deployment oldDeployment, Deployment newDeployment) { + if (LOG.isDebugEnabled()) { + LOG.debug( + "Deployment {} update.", + KubernetesUtils.getResourceFullName(oldDeployment)); + } + if (!Objects.equals(oldDeployment.getSpec(), newDeployment.getSpec())) { + handleComponentEvent(newDeployment, ComponentEventType.UPDATE_SPEC); + } else if (!Objects.equals( + oldDeployment.getStatus(), newDeployment.getStatus())) { + handleComponentEvent(newDeployment, ComponentEventType.UPDATE_STATUS); + } + } + + @Override + public void onDelete(Deployment deployment, boolean b) { + if (LOG.isDebugEnabled()) { + LOG.debug( + "Deployment {} delete.", + KubernetesUtils.getResourceFullName(deployment)); + } + handleComponentEvent(deployment, ComponentEventType.DELETE); + } + }); + + // Monitor the DaemonSets status. + daemonSetInformer.addEventHandler( + new ResourceEventHandler() { + @Override + public void onAdd(DaemonSet daemonSet) { + if (LOG.isDebugEnabled()) { + LOG.debug( + "DaemonSet {} add.", + KubernetesUtils.getResourceFullName(daemonSet)); + } + handleComponentEvent(daemonSet, ComponentEventType.ADD); + } + + @Override + public void onUpdate(DaemonSet oldDaemonSet, DaemonSet newDaemonSet) { + if (LOG.isDebugEnabled()) { + LOG.debug( + "DaemonSet {} update.", + KubernetesUtils.getResourceFullName(oldDaemonSet)); + } + if (!Objects.equals(oldDaemonSet.getSpec(), newDaemonSet.getSpec())) { + handleComponentEvent(newDaemonSet, ComponentEventType.UPDATE_SPEC); + } else if (!Objects.equals( + oldDaemonSet.getStatus(), newDaemonSet.getStatus())) { + handleComponentEvent(newDaemonSet, ComponentEventType.UPDATE_STATUS); + } + } + + @Override + public void onDelete(DaemonSet daemonSet, boolean b) { + if (LOG.isDebugEnabled()) { + LOG.debug( + "DaemonSet {} delete.", + KubernetesUtils.getResourceFullName(daemonSet)); + } + handleComponentEvent(daemonSet, ComponentEventType.DELETE); + } + }); + } + + private void enqueueRemoteShuffleApplication( + RemoteShuffleApplication shuffleApp, RemoteShuffleAction actionType) { + String key = Cache.metaNamespaceKeyFunc(shuffleApp); + if (key != null && !key.isEmpty()) { + workQueue.add(Pair.of(key, actionType)); + if (LOG.isDebugEnabled()) { + LOG.debug( + "Enqueue RemoteShuffleApplication({}) with key {}, action {}", + shuffleApp.getMetadata().getName(), + key, + actionType); + } + } + } + + public void run() { + LOG.info("Starting RemoteShuffleApplication controller"); + + isRunning.set(true); + + while (!shuffleAppInformer.hasSynced() + || !deploymentInformer.hasSynced() + || !daemonSetInformer.hasSynced()) { + // Wait till Informer syncs + } + + while (true) { + try { + Pair shuffleEvent = workQueue.take(); + String key = shuffleEvent.getKey(); + RemoteShuffleAction remoteShuffleAction = 
shuffleEvent.getValue();
+                if (key == null || key.isEmpty() || (!key.contains("/"))) {
+                    LOG.warn("Invalid resource key: {}", key);
+                    continue;
+                }
+
+                // Get the RemoteShuffleApplication resource's name from the key (in format
+                // namespace/name).
+                RemoteShuffleApplication shuffleApp = shuffleAppLister.get(key);
+                if (shuffleApp == null) {
+                    LOG.warn("RemoteShuffleApplication {} in workQueue no longer exists", key);
+                } else {
+                    if (remoteShuffleAction == RemoteShuffleAction.ADD_OR_UPDATE) {
+                        shuffleAppReconciler.reconcile(shuffleApp);
+                    }
+                    updateRemoteShuffleApplicationStatusWithRetry(shuffleApp);
+                }
+
+            } catch (InterruptedException interruptedException) {
+                Thread.currentThread().interrupt();
+                LOG.info("RemoteShuffleApplication controller interrupted.");
+                break;
+            } catch (Throwable throwable) {
+                LOG.error("Error: ", throwable);
+            }
+        }
+    }
+
+    private void handleComponentEvent(HasMetadata resource, ComponentEventType eventType) {
+        OwnerReference ownerReference = KubernetesUtils.getControllerOf(resource);
+        if (ownerReference != null
+                && ownerReference.getKind().equalsIgnoreCase(RemoteShuffleApplication.KIND)) {
+            RemoteShuffleApplication shuffleApp =
+                    shuffleAppLister
+                            .namespace(resource.getMetadata().getNamespace())
+                            .get(ownerReference.getName());
+            if (shuffleApp == null) {
+                LOG.warn(
+                        "Received ShuffleManager/ShuffleWorkers event: {}({}) {} for an unknown shuffle application {}.",
+                        resource.getClass().getSimpleName(),
+                        KubernetesUtils.getResourceFullName(resource),
+                        eventType,
+                        ownerReference.getName());
+            } else {
+                if (eventType == ComponentEventType.UPDATE_STATUS) {
+                    enqueueRemoteShuffleApplication(shuffleApp, RemoteShuffleAction.UPDATE_STATUS);
+                } else {
+                    LOG.info(
+                            "Received ShuffleManager/ShuffleWorkers event: {}({}) {}.",
+                            resource.getClass().getSimpleName(),
+                            KubernetesUtils.getResourceFullName(resource),
+                            eventType);
+                    // Nothing needs to be done for ADD events.
+                    if (eventType != ComponentEventType.ADD) {
+                        enqueueRemoteShuffleApplication(
+                                shuffleApp, RemoteShuffleAction.ADD_OR_UPDATE);
+                    }
+                }
+            }
+        }
+    }
+
+    private void updateRemoteShuffleApplicationStatusWithRetry(
+            RemoteShuffleApplication shuffleApp) {
+        long currentTime = System.currentTimeMillis();
+        if (currentTime - lastUpdateStatusTime > UPDATE_STATUS_INTERVAL_MS) {
+            KubernetesUtils.executeWithRetry(
+                    () -> updateRemoteShuffleApplicationStatus(shuffleApp),
+                    String.format(
+                            "Update %s status", KubernetesUtils.getResourceFullName(shuffleApp)));
+            lastUpdateStatusTime = currentTime;
+        }
+    }
+
+    /**
+     * Query the status of each component of {@link RemoteShuffleApplication}, and update the status
+     * of the remote shuffle application based on the components' statuses.
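+     *
+     * <p>A comparison sketch, illustrative only (the getters are those of {@link
+     * RemoteShuffleApplicationStatus} below): callers can derive overall readiness from the
+     * resulting status, e.g.:
+     *
+     * <pre>{@code
+     * RemoteShuffleApplicationStatus status = shuffleApp.getStatus();
+     * boolean allReady =
+     *         status.getReadyShuffleManagers() == status.getDesiredShuffleManagers()
+     *                 && status.getReadyShuffleWorkers() == status.getDesiredShuffleWorkers();
+     * }</pre>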
+     */
+    private void updateRemoteShuffleApplicationStatus(RemoteShuffleApplication shuffleApp) {
+
+        String namespace = shuffleApp.getMetadata().getNamespace();
+        String clusterId = shuffleApp.getMetadata().getName();
+
+        String shuffleManagerName =
+                String.format(
+                        "%s/%s",
+                        namespace, KubernetesUtils.getShuffleManagerNameWithClusterId(clusterId));
+        String shuffleWorkersName =
+                String.format(
+                        "%s/%s",
+                        namespace, KubernetesUtils.getShuffleWorkersNameWithClusterId(clusterId));
+
+        Deployment deployment = deploymentLister.get(shuffleManagerName);
+        DaemonSet daemonSet = daemonSetLister.get(shuffleWorkersName);
+
+        RemoteShuffleApplicationStatus currentStatus = new RemoteShuffleApplicationStatus();
+
+        if (deployment != null && deployment.getStatus() != null) {
+            DeploymentStatus status = deployment.getStatus();
+            if (status.getReplicas() != null) {
+                currentStatus.setDesiredShuffleManagers(status.getReplicas());
+            }
+            if (status.getReadyReplicas() != null) {
+                currentStatus.setReadyShuffleManagers(status.getReadyReplicas());
+            }
+        }
+
+        if (daemonSet != null && daemonSet.getStatus() != null) {
+            DaemonSetStatus status = daemonSet.getStatus();
+            if (status.getDesiredNumberScheduled() != null) {
+                currentStatus.setDesiredShuffleWorkers(status.getDesiredNumberScheduled());
+            }
+            if (status.getNumberReady() != null) {
+                currentStatus.setReadyShuffleWorkers(status.getNumberReady());
+            }
+        }
+
+        RemoteShuffleApplication cloneShuffleApp = cloneRemoteShuffleApplication(shuffleApp);
+        cloneShuffleApp.setStatus(currentStatus);
+        shuffleAppClient.inNamespace(namespace).withName(clusterId).updateStatus(cloneShuffleApp);
+    }
+
+    private static RemoteShuffleApplication cloneRemoteShuffleApplication(
+            RemoteShuffleApplication shuffleApp) {
+        RemoteShuffleApplication cloneShuffleApp = new RemoteShuffleApplication();
+        RemoteShuffleApplicationSpec cloneShuffleSpec =
+                new RemoteShuffleApplicationSpec(
+                        shuffleApp.getSpec().getShuffleDynamicConfigs(),
+                        shuffleApp.getSpec().getShuffleFileConfigs());
+
+        cloneShuffleApp.setSpec(cloneShuffleSpec);
+        cloneShuffleApp.setMetadata(shuffleApp.getMetadata());
+
+        return cloneShuffleApp;
+    }
+
+    static Deployment cloneResource(Deployment deployment) throws IOException {
+        final ObjectMapper jsonMapper = Serialization.jsonMapper();
+        byte[] bytes = jsonMapper.writeValueAsBytes(deployment);
+        return jsonMapper.readValue(bytes, Deployment.class);
+    }
+
+    static DaemonSet cloneResource(DaemonSet daemonSet) throws IOException {
+        final ObjectMapper jsonMapper = Serialization.jsonMapper();
+        byte[] bytes = jsonMapper.writeValueAsBytes(daemonSet);
+        return jsonMapper.readValue(bytes, DaemonSet.class);
+    }
+
+    public static RemoteShuffleApplicationController createRemoteShuffleApplicationController(
+            KubernetesClient kubeClient, SharedInformerFactory informerFactory) {
+        // create informers.
+        final SharedIndexInformer<RemoteShuffleApplication> shuffleAppInformer =
+                informerFactory.sharedIndexInformerForCustomResource(
+                        RemoteShuffleApplication.class, RSYNC_MILLIS);
+        final SharedIndexInformer<Deployment> deploymentInformer =
+                informerFactory.sharedIndexInformerFor(Deployment.class, RSYNC_MILLIS);
+        final SharedIndexInformer<DaemonSet> daemonSetInformer =
+                informerFactory.sharedIndexInformerFor(DaemonSet.class, RSYNC_MILLIS);
+
+        MixedOperation<
+                        RemoteShuffleApplication,
+                        RemoteShuffleApplicationList,
+                        Resource<RemoteShuffleApplication>>
+                shuffleAppClient =
+                        kubeClient.customResources(
+                                RemoteShuffleApplication.class, RemoteShuffleApplicationList.class);
+
+        // create shuffle application controller.
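+        // Note (an assumption about the surrounding wiring, not shown in this class): the
+        // informers registered above only start delivering events once the shared informer
+        // factory is started, e.g. via informerFactory.startAllRegisteredInformers(); run()
+        // blocks until their caches have synced.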
+        RemoteShuffleApplicationController remoteShuffleApplicationController =
+                new RemoteShuffleApplicationController(
+                        shuffleAppInformer,
+                        deploymentInformer,
+                        daemonSetInformer,
+                        kubeClient,
+                        shuffleAppClient);
+        remoteShuffleApplicationController.create();
+
+        return remoteShuffleApplicationController;
+    }
+
+    // ---------------------------------------------------------------------------------------------
+    // For testing
+    // ---------------------------------------------------------------------------------------------
+
+    SharedIndexInformer<Deployment> getDeploymentInformer() {
+        return deploymentInformer;
+    }
+
+    SharedIndexInformer<DaemonSet> getDaemonSetInformer() {
+        return daemonSetInformer;
+    }
+
+    Lister<RemoteShuffleApplication> getShuffleAppLister() {
+        return shuffleAppLister;
+    }
+
+    Lister<DaemonSet> getDaemonSetLister() {
+        return daemonSetLister;
+    }
+
+    Lister<Deployment> getDeploymentLister() {
+        return deploymentLister;
+    }
+}
diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/crd/RemoteShuffleApplication.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/crd/RemoteShuffleApplication.java
new file mode 100644
index 00000000..7c5ec2ce
--- /dev/null
+++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/crd/RemoteShuffleApplication.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.kubernetes.operator.crd;
+
+import io.fabric8.kubernetes.api.model.Namespaced;
+import io.fabric8.kubernetes.client.CustomResource;
+import io.fabric8.kubernetes.model.annotation.Group;
+import io.fabric8.kubernetes.model.annotation.Kind;
+import io.fabric8.kubernetes.model.annotation.Plural;
+import io.fabric8.kubernetes.model.annotation.Singular;
+import io.fabric8.kubernetes.model.annotation.Version;
+
+/**
+ * {@link RemoteShuffleApplication} is an implementation of Kubernetes {@link CustomResource}, which
+ * represents a remote shuffle cluster. {@link #spec} contains the configurations of the cluster,
+ * and {@link #status} represents the current status of the cluster.
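+ *
+ * <p>A construction sketch, illustrative only ({@code ObjectMetaBuilder} comes from the fabric8
+ * model; the namespace, name, and {@code shuffleDynamicConfigs} values are assumptions):
+ *
+ * <pre>{@code
+ * RemoteShuffleApplication app = new RemoteShuffleApplication();
+ * app.setMetadata(
+ *         new ObjectMetaBuilder().withNamespace("default").withName("my-shuffle").build());
+ * app.setSpec(new RemoteShuffleApplicationSpec(shuffleDynamicConfigs, null));
+ * }</pre>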
+ */
+@Version("v1")
+@Group("shuffleoperator.alibaba.com")
+@Kind(RemoteShuffleApplication.KIND)
+@Singular("remoteshuffle")
+@Plural("remoteshuffles")
+public class RemoteShuffleApplication
+        extends CustomResource<RemoteShuffleApplicationSpec, RemoteShuffleApplicationStatus>
+        implements Namespaced {
+
+    private static final long serialVersionUID = -7093257940691211895L;
+
+    public static final String KIND = "RemoteShuffle";
+}
diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/crd/RemoteShuffleApplicationList.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/crd/RemoteShuffleApplicationList.java
new file mode 100644
index 00000000..fcded822
--- /dev/null
+++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/crd/RemoteShuffleApplicationList.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.kubernetes.operator.crd;
+
+import io.fabric8.kubernetes.client.CustomResourceList;
+
+/** The list variant of the {@link RemoteShuffleApplication}. */
+public class RemoteShuffleApplicationList extends CustomResourceList<RemoteShuffleApplication> {
+
+    private static final long serialVersionUID = -8862204525168889881L;
+}
diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/crd/RemoteShuffleApplicationSpec.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/crd/RemoteShuffleApplicationSpec.java
new file mode 100644
index 00000000..5d0feba3
--- /dev/null
+++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/crd/RemoteShuffleApplicationSpec.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.kubernetes.operator.crd;
+
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
+
+import javax.annotation.Nullable;
+
+import java.util.Map;
+import java.util.Objects;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull;
+
+/** Specifications for {@link RemoteShuffleApplication}. */
+@JsonInclude(JsonInclude.Include.NON_NULL)
+@JsonDeserialize()
+public class RemoteShuffleApplicationSpec {
+
+    private Map<String, String> shuffleDynamicConfigs;
+
+    @Nullable private Map<String, String> shuffleFileConfigs;
+
+    public Map<String, String> getShuffleDynamicConfigs() {
+        checkNotNull(shuffleDynamicConfigs);
+        return shuffleDynamicConfigs;
+    }
+
+    public RemoteShuffleApplicationSpec() {}
+
+    public RemoteShuffleApplicationSpec(
+            Map<String, String> shuffleDynamicConfigs,
+            @Nullable Map<String, String> shuffleFileConfigs) {
+        this.shuffleFileConfigs = shuffleFileConfigs;
+        this.shuffleDynamicConfigs = checkNotNull(shuffleDynamicConfigs);
+    }
+
+    public void setShuffleDynamicConfigs(Map<String, String> shuffleDynamicConfigs) {
+        this.shuffleDynamicConfigs = checkNotNull(shuffleDynamicConfigs);
+    }
+
+    @Nullable
+    public Map<String, String> getShuffleFileConfigs() {
+        return shuffleFileConfigs;
+    }
+
+    public void setShuffleFileConfigs(@Nullable Map<String, String> shuffleFileConfigs) {
+        this.shuffleFileConfigs = shuffleFileConfigs;
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        if (obj instanceof RemoteShuffleApplicationSpec) {
+            RemoteShuffleApplicationSpec spec = (RemoteShuffleApplicationSpec) obj;
+            return shuffleDynamicConfigs.equals(spec.shuffleDynamicConfigs)
+                    && Objects.equals(shuffleFileConfigs, spec.shuffleFileConfigs);
+        } else {
+            return false;
+        }
+    }
+
+    @Override
+    public String toString() {
+        return "RemoteShuffleApplicationSpec("
+                + "shuffleDynamicConfigs="
+                + shuffleDynamicConfigs
+                + ", shuffleFileConfigs="
+                + shuffleFileConfigs
+                + ")";
+    }
+}
diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/crd/RemoteShuffleApplicationStatus.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/crd/RemoteShuffleApplicationStatus.java
new file mode 100644
index 00000000..0ec0628c
--- /dev/null
+++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/crd/RemoteShuffleApplicationStatus.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.kubernetes.operator.crd;
+
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
+
+/** Status of the {@link RemoteShuffleApplication}.
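+ *
+ * <p>A filling sketch mirroring the controller above (illustrative; the source values come from
+ * the Deployment and DaemonSet statuses queried by the controller):
+ *
+ * <pre>{@code
+ * RemoteShuffleApplicationStatus status = new RemoteShuffleApplicationStatus();
+ * status.setDesiredShuffleManagers(deploymentStatus.getReplicas());
+ * status.setReadyShuffleManagers(deploymentStatus.getReadyReplicas());
+ * status.setDesiredShuffleWorkers(daemonSetStatus.getDesiredNumberScheduled());
+ * status.setReadyShuffleWorkers(daemonSetStatus.getNumberReady());
+ * }</pre>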
*/ +@JsonInclude(JsonInclude.Include.NON_NULL) +@JsonDeserialize() +public class RemoteShuffleApplicationStatus { + + private int readyShuffleManagers; + private int readyShuffleWorkers; + private int desiredShuffleManagers; + private int desiredShuffleWorkers; + + public RemoteShuffleApplicationStatus() {} + + RemoteShuffleApplicationStatus( + int readyShuffleManagers, + int readyShuffleWorkers, + int desiredShuffleManagers, + int desiredShuffleWorkers) { + + this.readyShuffleManagers = readyShuffleManagers; + this.readyShuffleWorkers = readyShuffleWorkers; + this.desiredShuffleManagers = desiredShuffleManagers; + this.desiredShuffleWorkers = desiredShuffleWorkers; + } + + public int getReadyShuffleManagers() { + return readyShuffleManagers; + } + + public void setReadyShuffleManagers(int readyShuffleManagers) { + this.readyShuffleManagers = readyShuffleManagers; + } + + public int getReadyShuffleWorkers() { + return readyShuffleWorkers; + } + + public void setReadyShuffleWorkers(int readyShuffleWorkers) { + this.readyShuffleWorkers = readyShuffleWorkers; + } + + public int getDesiredShuffleManagers() { + return desiredShuffleManagers; + } + + public void setDesiredShuffleManagers(int desiredShuffleManagers) { + this.desiredShuffleManagers = desiredShuffleManagers; + } + + public int getDesiredShuffleWorkers() { + return desiredShuffleWorkers; + } + + public void setDesiredShuffleWorkers(int desiredShuffleWorkers) { + this.desiredShuffleWorkers = desiredShuffleWorkers; + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof RemoteShuffleApplicationStatus) { + RemoteShuffleApplicationStatus status = (RemoteShuffleApplicationStatus) obj; + return this.readyShuffleManagers == status.readyShuffleManagers + && this.readyShuffleWorkers == status.readyShuffleWorkers + && this.desiredShuffleManagers == status.desiredShuffleManagers + && this.desiredShuffleWorkers == status.desiredShuffleWorkers; + } else { + return false; + } + } + + @Override + public String toString() { + return "RemoteShuffleApplicationStatus(" + + "readyShuffleManagers=" + + readyShuffleManagers + + ", readyShuffleWorkers=" + + readyShuffleWorkers + + ", desiredShuffleManagers=" + + desiredShuffleManagers + + ", desiredShuffleWorkers=" + + desiredShuffleWorkers + + ")"; + } +} diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/AbstractKubernetesParameters.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/AbstractKubernetesParameters.java new file mode 100644 index 00000000..6f724fb1 --- /dev/null +++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/AbstractKubernetesParameters.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.kubernetes.operator.parameters;
+
+import com.alibaba.flink.shuffle.common.config.Configuration;
+import com.alibaba.flink.shuffle.core.config.ClusterOptions;
+import com.alibaba.flink.shuffle.core.config.KubernetesOptions;
+import com.alibaba.flink.shuffle.kubernetes.operator.parameters.util.ConfigMapVolume;
+import com.alibaba.flink.shuffle.kubernetes.operator.util.Constants;
+import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesInternalOptions;
+
+import org.apache.commons.lang3.StringUtils;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument;
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull;
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState;
+
+/**
+ * Abstract base class for {@link KubernetesContainerParameters} and {@link
+ * KubernetesPodParameters} implementations.
+ */
+public abstract class AbstractKubernetesParameters
+        implements KubernetesContainerParameters, KubernetesPodParameters {
+
+    protected final Configuration conf;
+
+    public AbstractKubernetesParameters(Configuration conf) {
+        this.conf = checkNotNull(conf);
+    }
+
+    @Override
+    public String getNamespace() {
+        final String namespace = conf.getString(KubernetesInternalOptions.NAMESPACE);
+        checkArgument(
+                !namespace.trim().isEmpty(),
+                "Invalid " + KubernetesInternalOptions.NAMESPACE + ".");
+
+        return namespace;
+    }
+
+    @Override
+    public boolean enablePodHostNetwork() {
+        return conf.getBoolean(KubernetesOptions.POD_HOST_NETWORK_ENABLED);
+    }
+
+    @Override
+    public String getContainerImage() {
+        final String containerImage = conf.getString(KubernetesOptions.CONTAINER_IMAGE);
+        checkArgument(
+                !containerImage.trim().isEmpty(),
+                "Invalid " + KubernetesOptions.CONTAINER_IMAGE + ".");
+        return containerImage;
+    }
+
+    @Override
+    public String getContainerImagePullPolicy() {
+        return conf.getString(KubernetesOptions.CONTAINER_IMAGE_PULL_POLICY);
+    }
+
+    @Override
+    public List<ConfigMapVolume> getConfigMapVolumes() {
+        Optional<String> volumeName =
+                Optional.ofNullable(conf.getString(KubernetesInternalOptions.CONFIG_VOLUME_NAME));
+        Optional<String> configMapName =
+                Optional.ofNullable(
+                        conf.getString(KubernetesInternalOptions.CONFIG_VOLUME_CONFIG_MAP_NAME));
+        Optional<String> mountPath =
+                Optional.ofNullable(
+                        conf.getString(KubernetesInternalOptions.CONFIG_VOLUME_MOUNT_PATH));
+        Map<String, String> items = conf.getMap(KubernetesInternalOptions.CONFIG_VOLUME_ITEMS);
+        if (volumeName.isPresent()) {
+            checkState(configMapName.isPresent());
+            checkState(mountPath.isPresent());
+            return Collections.singletonList(
+                    new ConfigMapVolume(
+                            volumeName.get(),
+                            configMapName.get(),
+                            checkNotNull(items),
+                            mountPath.get()));
+        } else {
+            checkState(items.isEmpty());
+            return Collections.emptyList();
+        }
+    }
+
+    @Override
+    public KubernetesContainerParameters getContainerParameters() {
+        return this;
+    }
+
+    public String getClusterId() {
+        final String clusterId = conf.getString(ClusterOptions.REMOTE_SHUFFLE_CLUSTER_ID);
+
+        if (StringUtils.isBlank(clusterId)) {
+            throw new IllegalArgumentException(
+                    ClusterOptions.REMOTE_SHUFFLE_CLUSTER_ID.key() + " must not be blank.");
+        } else if (clusterId.length() > Constants.MAXIMUM_CHARACTERS_OF_CLUSTER_ID) {
+            throw new IllegalArgumentException(
+                    ClusterOptions.REMOTE_SHUFFLE_CLUSTER_ID.key()
+                            + " must be no more than "
+                            +
Constants.MAXIMUM_CHARACTERS_OF_CLUSTER_ID
+                            + " characters.");
+        }
+
+        return clusterId;
+    }
+}
diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/K8sRemoteShuffleFileConfigsParameters.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/K8sRemoteShuffleFileConfigsParameters.java
new file mode 100644
index 00000000..80831153
--- /dev/null
+++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/K8sRemoteShuffleFileConfigsParameters.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.kubernetes.operator.parameters;
+
+import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesUtils;
+
+import io.fabric8.kubernetes.api.model.ConfigMap;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * The configuration files will be deployed as a kubernetes {@link ConfigMap} and mounted into the
+ * pods. This class is responsible for converting the config files into kubernetes {@link ConfigMap}
+ * configurations.
+ */
+public class K8sRemoteShuffleFileConfigsParameters implements KubernetesConfigMapParameters {
+    private final String namespace;
+    private final String clusterId;
+    private final Map<String, String> data;
+
+    public K8sRemoteShuffleFileConfigsParameters(
+            String namespace, String clusterId, Map<String, String> data) {
+        this.namespace = namespace;
+        this.clusterId = clusterId;
+        this.data = data;
+    }
+
+    @Override
+    public String getConfigMapName() {
+        return clusterId + "-configmap";
+    }
+
+    @Override
+    public Map<String, String> getData() {
+        return data;
+    }
+
+    @Override
+    public String getNamespace() {
+        return namespace;
+    }
+
+    @Override
+    public Map<String, String> getLabels() {
+        final Map<String, String> labels = new HashMap<>();
+        labels.putAll(KubernetesUtils.getCommonLabels(clusterId));
+        return Collections.unmodifiableMap(labels);
+    }
+}
diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesCommonParameters.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesCommonParameters.java
new file mode 100644
index 00000000..fcd9ace8
--- /dev/null
+++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesCommonParameters.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.kubernetes.operator.parameters;
+
+import java.util.Map;
+
+/** The common parameters for resources (Pod, Deployment, DaemonSet, ConfigMap). */
+public interface KubernetesCommonParameters {
+
+    String getNamespace();
+
+    Map<String, String> getLabels();
+}
diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesConfigMapParameters.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesConfigMapParameters.java
new file mode 100644
index 00000000..f83ce451
--- /dev/null
+++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesConfigMapParameters.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.kubernetes.operator.parameters;
+
+import java.util.Map;
+
+/** The parameters that are used to construct a kubernetes ConfigMap. */
+public interface KubernetesConfigMapParameters extends KubernetesCommonParameters {
+
+    String getConfigMapName();
+
+    Map<String, String> getData();
+}
diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesContainerParameters.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesContainerParameters.java
new file mode 100644
index 00000000..d08838e2
--- /dev/null
+++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesContainerParameters.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.kubernetes.operator.parameters;
+
+import com.alibaba.flink.shuffle.kubernetes.operator.parameters.util.ContainerCommandAndArgs;
+
+import java.util.List;
+import java.util.Map;
+
+/** The parameters that are used to construct a kubernetes Container. */
+public interface KubernetesContainerParameters {
+
+    String getContainerName();
+
+    String getContainerImage();
+
+    String getContainerImagePullPolicy();
+
+    List<Map<String, String>> getContainerVolumeMounts();
+
+    Integer getContainerMemoryMB();
+
+    Double getContainerCPU();
+
+    Map<String, String> getResourceLimitFactors();
+
+    ContainerCommandAndArgs getContainerCommandAndArgs();
+
+    Map<String, String> getEnvironmentVars();
+}
diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesDaemonSetParameters.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesDaemonSetParameters.java
new file mode 100644
index 00000000..112aa2e0
--- /dev/null
+++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesDaemonSetParameters.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.kubernetes.operator.parameters;
+
+/** The parameters that are used to construct a kubernetes DaemonSet. */
+public interface KubernetesDaemonSetParameters extends KubernetesCommonParameters {
+
+    String getDaemonSetName();
+
+    KubernetesPodParameters getPodTemplateParameters();
+}
diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesDeploymentParameters.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesDeploymentParameters.java
new file mode 100644
index 00000000..b6b90edb
--- /dev/null
+++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesDeploymentParameters.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.kubernetes.operator.parameters;
+
+/** The parameters that are used to construct a kubernetes Deployment. */
+public interface KubernetesDeploymentParameters extends KubernetesCommonParameters {
+
+    String getDeploymentName();
+
+    int getReplicas();
+
+    KubernetesPodParameters getPodTemplateParameters();
+}
diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesPodParameters.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesPodParameters.java
new file mode 100644
index 00000000..56c32b37
--- /dev/null
+++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesPodParameters.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.kubernetes.operator.parameters;
+
+import com.alibaba.flink.shuffle.kubernetes.operator.parameters.util.ConfigMapVolume;
+
+import java.util.List;
+import java.util.Map;
+
+/** The parameters that are used to construct a kubernetes Pod. */
+public interface KubernetesPodParameters extends KubernetesCommonParameters {
+
+    Map<String, String> getNodeSelector();
+
+    boolean enablePodHostNetwork();
+
+    List<Map<String, String>> getEmptyDirVolumes();
+
+    List<Map<String, String>> getHostPathVolumes();
+
+    // Currently, only one ConfigMap volume is supported.
+    List<ConfigMapVolume> getConfigMapVolumes();
+
+    List<Map<String, String>> getTolerations();
+
+    KubernetesContainerParameters getContainerParameters();
+}
diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesShuffleManagerParameters.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesShuffleManagerParameters.java
new file mode 100644
index 00000000..af96843c
--- /dev/null
+++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesShuffleManagerParameters.java
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.kubernetes.operator.parameters;
+
+import com.alibaba.flink.shuffle.common.config.Configuration;
+import com.alibaba.flink.shuffle.core.config.KubernetesOptions;
+import com.alibaba.flink.shuffle.core.config.memory.ShuffleManagerProcessSpec;
+import com.alibaba.flink.shuffle.kubernetes.operator.parameters.util.ConfigMapVolume;
+import com.alibaba.flink.shuffle.kubernetes.operator.parameters.util.ContainerCommandAndArgs;
+import com.alibaba.flink.shuffle.kubernetes.operator.util.Constants;
+import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesUtils;
+
+import io.fabric8.kubernetes.api.model.apps.Deployment;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState;
+
+/**
+ * The ShuffleManager will be deployed as a {@link Deployment} in Kubernetes. This class helps to
+ * translate the ShuffleManager configuration into the kubernetes {@link Deployment} configuration.
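+ *
+ * <p>A rough usage sketch (the wiring mirrors this module's reconciler; {@code conf} is assumed to
+ * carry the dynamic configs):
+ *
+ * <pre>{@code
+ * KubernetesShuffleManagerParameters params = new KubernetesShuffleManagerParameters(conf);
+ * Deployment deployment = new KubernetesDeploymentBuilder().buildKubernetesResourceFrom(params);
+ * }</pre>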
+ */
+public class KubernetesShuffleManagerParameters extends AbstractKubernetesParameters
+        implements KubernetesDeploymentParameters {
+
+    private final ShuffleManagerProcessSpec shuffleManagerProcessSpec;
+
+    public KubernetesShuffleManagerParameters(Configuration conf) {
+        super(conf);
+        this.shuffleManagerProcessSpec = new ShuffleManagerProcessSpec(conf);
+    }
+
+    @Override
+    public Map<String, String> getLabels() {
+        final Map<String, String> labels = new HashMap<>();
+        labels.putAll(
+                Optional.ofNullable(conf.getMap(KubernetesOptions.SHUFFLE_MANAGER_LABELS))
+                        .orElse(Collections.emptyMap()));
+        labels.putAll(KubernetesUtils.getCommonLabels(getClusterId()));
+        labels.put(Constants.LABEL_COMPONENT_KEY, Constants.LABEL_COMPONENT_SHUFFLE_MANAGER);
+        return Collections.unmodifiableMap(labels);
+    }
+
+    @Override
+    public Map<String, String> getNodeSelector() {
+        return Optional.ofNullable(conf.getMap(KubernetesOptions.SHUFFLE_MANAGER_NODE_SELECTOR))
+                .orElse(Collections.emptyMap());
+    }
+
+    @Override
+    public String getContainerName() {
+        return KubernetesUtils.SHUFFLE_MANAGER_CONTAINER_NAME;
+    }
+
+    @Override
+    public List<Map<String, String>> getEmptyDirVolumes() {
+        return KubernetesUtils.filterVolumesConfigs(
+                conf,
+                KubernetesOptions.SHUFFLE_MANAGER_EMPTY_DIR_VOLUMES,
+                KubernetesUtils::filterEmptyDirVolumeConfigs);
+    }
+
+    @Override
+    public List<Map<String, String>> getHostPathVolumes() {
+        return KubernetesUtils.filterVolumesConfigs(
+                conf,
+                KubernetesOptions.SHUFFLE_MANAGER_HOST_PATH_VOLUMES,
+                KubernetesUtils::filterHostPathVolumeConfigs);
+    }
+
+    @Override
+    public List<Map<String, String>> getTolerations() {
+        return Optional.ofNullable(
+                        conf.getList(KubernetesOptions.SHUFFLE_MANAGER_TOLERATIONS, Map.class))
+                .orElse(Collections.emptyList());
+    }
+
+    @Override
+    public Map<String, String> getEnvironmentVars() {
+        Map<String, String> envVars = new HashMap<>();
+        envVars.put(Constants.LABEL_APPTYPE_KEY.toUpperCase(), Constants.LABEL_APPTYPE_VALUE);
+        envVars.put(Constants.LABEL_APP_KEY.toUpperCase(), getClusterId());
+        envVars.put(
+                Constants.LABEL_COMPONENT_KEY.toUpperCase(),
+                Constants.LABEL_COMPONENT_SHUFFLE_MANAGER);
+        envVars.putAll(conf.getMap(KubernetesOptions.SHUFFLE_MANAGER_ENV_VARS));
+        return envVars;
+    }
+
+    @Override
+    public List<Map<String, String>> getContainerVolumeMounts() {
+
+        List<Map<String, String>> volumeMountsConfigs = new ArrayList<>();
+
+        // empty dir volume mounts
+        volumeMountsConfigs.addAll(
+                KubernetesUtils.filterVolumesConfigs(
+                        conf,
+                        KubernetesOptions.SHUFFLE_MANAGER_EMPTY_DIR_VOLUMES,
+                        KubernetesUtils::filterVolumeMountsConfigs));
+
+        // host path volume mounts
+        volumeMountsConfigs.addAll(
+                KubernetesUtils.filterVolumesConfigs(
+                        conf,
+                        KubernetesOptions.SHUFFLE_MANAGER_HOST_PATH_VOLUMES,
+                        KubernetesUtils::filterVolumeMountsConfigs));
+
+        // configmap volume mounts
+        Map<String, String> configmapVolumeMounts = new HashMap<>();
+        List<ConfigMapVolume> configMapVolumes = getConfigMapVolumes();
+        if (!configMapVolumes.isEmpty()) {
+            checkState(configMapVolumes.size() == 1);
+            ConfigMapVolume configMapVolume = configMapVolumes.get(0);
+            configmapVolumeMounts.put(Constants.VOLUME_NAME, configMapVolume.getVolumeName());
+            configmapVolumeMounts.put(Constants.VOLUME_MOUNT_PATH, configMapVolume.getMountPath());
+            volumeMountsConfigs.add(configmapVolumeMounts);
+        }
+
+        return volumeMountsConfigs;
+    }
+
+    @Override
+    public Integer getContainerMemoryMB() {
+        return shuffleManagerProcessSpec.getTotalProcessMemorySize().getMebiBytes();
+    }
+
+    @Override
+    public Double getContainerCPU() {
+        return conf.getDouble(KubernetesOptions.SHUFFLE_MANAGER_CPU);
+    }
+
+    @Override
+    public Map<String, String> getResourceLimitFactors() {
+        return
KubernetesUtils.getPrefixedKeyValuePairs(
+                KubernetesOptions.SHUFFLE_MANAGER_RESOURCE_LIMIT_FACTOR_PREFIX, conf);
+    }
+
+    @Override
+    public ContainerCommandAndArgs getContainerCommandAndArgs() {
+        String command = "bash";
+        return new ContainerCommandAndArgs(
+                command,
+                Arrays.asList(
+                        "-c",
+                        Constants.SHUFFLE_MANAGER_SCRIPT_PATH + " " + getShuffleManagerConfigs()));
+    }
+
+    @Override
+    public String getDeploymentName() {
+        // The ShuffleManager deployment name pattern is {clusterId}-shufflemanager.
        return KubernetesUtils.getShuffleManagerNameWithClusterId(getClusterId());
+    }
+
+    @Override
+    public int getReplicas() {
+        return 1;
+    }
+
+    @Override
+    public KubernetesContainerParameters getContainerParameters() {
+        return this;
+    }
+
+    @Override
+    public KubernetesPodParameters getPodTemplateParameters() {
+        return this;
+    }
+
+    private String getShuffleManagerConfigs() {
+        List<String> dynamicConfigs = new ArrayList<>();
+        conf.toMap().entrySet().stream()
+                .sorted(Map.Entry.comparingByKey())
+                .forEach(
+                        kv -> {
+                            String configKey = kv.getKey();
+                            // The following configs should not be passed to the ShuffleManager:
+                            // (1) those starting with "remote-shuffle.kubernetes.";
+                            // (2) those starting with "remote-shuffle.worker.".
+                            if (!configKey.startsWith("remote-shuffle.kubernetes.")
+                                    && !configKey.startsWith("remote-shuffle.worker.")) {
+                                dynamicConfigs.add(
+                                        String.format("-D '%s=%s'", kv.getKey(), kv.getValue()));
+                            }
+                        });
+        return String.join(" ", dynamicConfigs);
+    }
+}
diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesShuffleWorkerParameters.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesShuffleWorkerParameters.java
new file mode 100644
index 00000000..d232146b
--- /dev/null
+++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesShuffleWorkerParameters.java
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.kubernetes.operator.parameters;
+
+import com.alibaba.flink.shuffle.common.config.Configuration;
+import com.alibaba.flink.shuffle.core.config.KubernetesOptions;
+import com.alibaba.flink.shuffle.core.config.memory.ShuffleWorkerProcessSpec;
+import com.alibaba.flink.shuffle.kubernetes.operator.parameters.util.ConfigMapVolume;
+import com.alibaba.flink.shuffle.kubernetes.operator.parameters.util.ContainerCommandAndArgs;
+import com.alibaba.flink.shuffle.kubernetes.operator.util.Constants;
+import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesUtils;
+
+import io.fabric8.kubernetes.api.model.apps.DaemonSet;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState;
+
+/**
+ * The ShuffleWorkers will be deployed as a {@link DaemonSet} in Kubernetes. This class helps to
+ * translate the ShuffleWorker configuration into the kubernetes {@link DaemonSet} configuration.
+ */
+public class KubernetesShuffleWorkerParameters extends AbstractKubernetesParameters
+        implements KubernetesDaemonSetParameters {
+
+    private final ShuffleWorkerProcessSpec shuffleWorkerProcessSpec;
+
+    public KubernetesShuffleWorkerParameters(Configuration conf) {
+        super(conf);
+        this.shuffleWorkerProcessSpec = new ShuffleWorkerProcessSpec(conf);
+    }
+
+    @Override
+    public Map<String, String> getLabels() {
+        final Map<String, String> labels = new HashMap<>();
+        labels.putAll(
+                Optional.ofNullable(conf.getMap(KubernetesOptions.SHUFFLE_WORKER_LABELS))
+                        .orElse(Collections.emptyMap()));
+        labels.putAll(KubernetesUtils.getCommonLabels(getClusterId()));
+        labels.put(Constants.LABEL_COMPONENT_KEY, Constants.LABEL_COMPONENT_SHUFFLE_WORKER);
+        return Collections.unmodifiableMap(labels);
+    }
+
+    @Override
+    public Map<String, String> getNodeSelector() {
+        return Optional.ofNullable(conf.getMap(KubernetesOptions.SHUFFLE_WORKER_NODE_SELECTOR))
+                .orElse(Collections.emptyMap());
+    }
+
+    @Override
+    public String getContainerName() {
+        return KubernetesUtils.SHUFFLE_WORKER_CONTAINER_NAME;
+    }
+
+    @Override
+    public Integer getContainerMemoryMB() {
+        return shuffleWorkerProcessSpec.getTotalProcessMemorySize().getMebiBytes();
+    }
+
+    @Override
+    public Double getContainerCPU() {
+        return conf.getDouble(KubernetesOptions.SHUFFLE_WORKER_CPU);
+    }
+
+    @Override
+    public Map<String, String> getResourceLimitFactors() {
+        return KubernetesUtils.getPrefixedKeyValuePairs(
+                KubernetesOptions.SHUFFLE_WORKER_RESOURCE_LIMIT_FACTOR_PREFIX, conf);
+    }
+
+    @Override
+    public ContainerCommandAndArgs getContainerCommandAndArgs() {
+        String command = "bash";
+        return new ContainerCommandAndArgs(
+                command,
+                Arrays.asList(
+                        "-c",
+                        Constants.SHUFFLE_WORKER_SCRIPT_PATH + " " + getShuffleWorkerConfigs()));
+    }
+
+    @Override
+    public List<Map<String, String>> getContainerVolumeMounts() {
+        List<Map<String, String>> volumeMountsConfigs = new ArrayList<>();
+
+        // empty dir volume mounts
+        volumeMountsConfigs.addAll(
+                KubernetesUtils.filterVolumesConfigs(
+                        conf,
+                        KubernetesOptions.SHUFFLE_WORKER_EMPTY_DIR_VOLUMES,
+                        KubernetesUtils::filterVolumeMountsConfigs));
+
+        // host path volume mounts
+        volumeMountsConfigs.addAll(
+                KubernetesUtils.filterVolumesConfigs(
+                        conf,
+                        KubernetesOptions.SHUFFLE_WORKER_HOST_PATH_VOLUMES,
+                        KubernetesUtils::filterVolumeMountsConfigs));
+
+        // configmap volume mounts
+        Map<String, String> configmapVolumeMounts = new HashMap<>();
+        List<ConfigMapVolume> configMapVolumes = getConfigMapVolumes();
+        if
(!configMapVolumes.isEmpty()) {
+            checkState(configMapVolumes.size() == 1);
+            ConfigMapVolume configMapVolume = configMapVolumes.get(0);
+            configmapVolumeMounts.put(Constants.VOLUME_NAME, configMapVolume.getVolumeName());
+            configmapVolumeMounts.put(Constants.VOLUME_MOUNT_PATH, configMapVolume.getMountPath());
+            volumeMountsConfigs.add(configmapVolumeMounts);
+        }
+
+        return volumeMountsConfigs;
+    }
+
+    @Override
+    public List<Map<String, String>> getEmptyDirVolumes() {
+        return KubernetesUtils.filterVolumesConfigs(
+                conf,
+                KubernetesOptions.SHUFFLE_WORKER_EMPTY_DIR_VOLUMES,
+                KubernetesUtils::filterEmptyDirVolumeConfigs);
+    }
+
+    @Override
+    public List<Map<String, String>> getHostPathVolumes() {
+        return KubernetesUtils.filterVolumesConfigs(
+                conf,
+                KubernetesOptions.SHUFFLE_WORKER_HOST_PATH_VOLUMES,
+                KubernetesUtils::filterHostPathVolumeConfigs);
+    }
+
+    @Override
+    public List<Map<String, String>> getTolerations() {
+        return Optional.ofNullable(
+                        conf.getList(KubernetesOptions.SHUFFLE_WORKER_TOLERATIONS, Map.class))
+                .orElse(Collections.emptyList());
+    }
+
+    @Override
+    public Map<String, String> getEnvironmentVars() {
+        Map<String, String> envVars = new HashMap<>();
+        envVars.put(Constants.LABEL_APPTYPE_KEY.toUpperCase(), Constants.LABEL_APPTYPE_VALUE);
+        envVars.put(Constants.LABEL_APP_KEY.toUpperCase(), getClusterId());
+        envVars.put(
+                Constants.LABEL_COMPONENT_KEY.toUpperCase(),
+                Constants.LABEL_COMPONENT_SHUFFLE_WORKER);
+        envVars.putAll(conf.getMap(KubernetesOptions.SHUFFLE_WORKER_ENV_VARS));
+        return envVars;
+    }
+
+    private String getShuffleWorkerConfigs() {
+        List<String> dynamicConfigs = new ArrayList<>();
+        conf.toMap().entrySet().stream()
+                .sorted(Map.Entry.comparingByKey())
+                .forEachOrdered(
+                        kv -> {
+                            String configKey = kv.getKey();
+                            // The following configs should not be passed to the ShuffleWorkers:
+                            // (1) those starting with "remote-shuffle.kubernetes.";
+                            // (2) those starting with "remote-shuffle.manager.".
+                            if (!configKey.startsWith("remote-shuffle.kubernetes.")
+                                    && !configKey.startsWith("remote-shuffle.manager.")) {
+                                dynamicConfigs.add(
+                                        String.format("-D '%s=%s'", kv.getKey(), kv.getValue()));
+                            }
+                        });
+        return String.join(" ", dynamicConfigs);
+    }
+
+    @Override
+    public String getDaemonSetName() {
+        return KubernetesUtils.getShuffleWorkersNameWithClusterId(getClusterId());
+    }
+
+    @Override
+    public KubernetesPodParameters getPodTemplateParameters() {
+        return this;
+    }
+}
diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/util/ConfigMapVolume.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/util/ConfigMapVolume.java
new file mode 100644
index 00000000..13b9c5c9
--- /dev/null
+++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/util/ConfigMapVolume.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.kubernetes.operator.parameters.util;
+
+import java.util.Map;
+
+/** A utility class that holds the parameters for constructing a kubernetes ConfigMap volume. */
+public class ConfigMapVolume {
+
+    private final String configMapName;
+    private final String volumeName;
+    private final Map<String, String> items;
+    private final String mountPath;
+
+    public ConfigMapVolume(
+            String volumeName, String configMapName, Map<String, String> items, String mountPath) {
+        this.volumeName = volumeName;
+        this.configMapName = configMapName;
+        this.items = items;
+        this.mountPath = mountPath;
+    }
+
+    public String getConfigMapName() {
+        return configMapName;
+    }
+
+    public Map<String, String> getItems() {
+        return items;
+    }
+
+    public String getVolumeName() {
+        return volumeName;
+    }
+
+    public String getMountPath() {
+        return mountPath;
+    }
+}
diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/util/ContainerCommandAndArgs.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/util/ContainerCommandAndArgs.java
new file mode 100644
index 00000000..1a7d5a7c
--- /dev/null
+++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/util/ContainerCommandAndArgs.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.kubernetes.operator.parameters.util;
+
+import java.util.List;
+
+/** A utility class that contains a kubernetes container command and args. */
+public class ContainerCommandAndArgs {
+    private final String command;
+    private final List<String> args;
+
+    public ContainerCommandAndArgs(String command, List<String> args) {
+        this.command = command;
+        this.args = args;
+    }
+
+    public String getCommand() {
+        return command;
+    }
+
+    public List<String> getArgs() {
+        return args;
+    }
+}
diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/reconciler/RemoteShuffleApplicationReconciler.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/reconciler/RemoteShuffleApplicationReconciler.java
new file mode 100644
index 00000000..8841b6cb
--- /dev/null
+++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/reconciler/RemoteShuffleApplicationReconciler.java
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.operator.reconciler; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.functions.RunnableWithException; +import com.alibaba.flink.shuffle.core.config.ClusterOptions; +import com.alibaba.flink.shuffle.kubernetes.operator.crd.RemoteShuffleApplication; +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.K8sRemoteShuffleFileConfigsParameters; +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.KubernetesConfigMapParameters; +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.KubernetesDaemonSetParameters; +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.KubernetesShuffleManagerParameters; +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.KubernetesShuffleWorkerParameters; +import com.alibaba.flink.shuffle.kubernetes.operator.resources.KubernetesConfigMapBuilder; +import com.alibaba.flink.shuffle.kubernetes.operator.resources.KubernetesDaemonSetBuilder; +import com.alibaba.flink.shuffle.kubernetes.operator.resources.KubernetesDeploymentBuilder; +import com.alibaba.flink.shuffle.kubernetes.operator.util.Constants; +import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesInternalOptions; +import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesUtils; + +import io.fabric8.kubernetes.api.model.ConfigMap; +import io.fabric8.kubernetes.api.model.HasMetadata; +import io.fabric8.kubernetes.api.model.apps.DaemonSet; +import io.fabric8.kubernetes.api.model.apps.Deployment; +import io.fabric8.kubernetes.client.KubernetesClient; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.Map; +import java.util.stream.Collectors; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** + * The {@link RemoteShuffleApplicationReconciler} is responsible for reconciling {@link + * RemoteShuffleApplication} to the desired state which is described by {@link + * RemoteShuffleApplication#spec}. + */ +public class RemoteShuffleApplicationReconciler { + + private static final Logger LOG = + LoggerFactory.getLogger(RemoteShuffleApplicationReconciler.class); + + private final KubernetesClient kubeClient; + + public RemoteShuffleApplicationReconciler(KubernetesClient kubeClient) { + this.kubeClient = kubeClient; + } + + /** + * This method will trigger the reconciliation of RemoteShuffleApplication to the desired state. + * + * @param shuffleApp the RemoteShuffleApplication. 
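+     *
+     * <p>An illustrative call sketch (the client construction is an assumption; any fabric8
+     * {@code KubernetesClient} works):
+     *
+     * <pre>{@code
+     * RemoteShuffleApplicationReconciler reconciler =
+     *         new RemoteShuffleApplicationReconciler(new DefaultKubernetesClient());
+     * reconciler.reconcile(shuffleApp);
+     * }</pre>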
+ */ + public void reconcile(RemoteShuffleApplication shuffleApp) { + final String namespace = shuffleApp.getMetadata().getNamespace(); + final String clusterId = shuffleApp.getMetadata().getName(); + + LOG.info("Reconciling RemoteShuffleApplication {}/{}", namespace, clusterId); + + final Configuration dynamicConfigs = + Configuration.fromMap(shuffleApp.getSpec().getShuffleDynamicConfigs()); + dynamicConfigs.setString(KubernetesInternalOptions.NAMESPACE, namespace); + dynamicConfigs.setString(ClusterOptions.REMOTE_SHUFFLE_CLUSTER_ID, clusterId); + + final Map<String, String> fileConfigs = shuffleApp.getSpec().getShuffleFileConfigs(); + + Configuration updatedDynamicConfigs = dynamicConfigs; + if (fileConfigs != null && !fileConfigs.isEmpty()) { + updatedDynamicConfigs = reconcileFileConfigs(dynamicConfigs, fileConfigs, shuffleApp); + } + reconcileShuffleManager(updatedDynamicConfigs, shuffleApp); + reconcileShuffleWorkers(updatedDynamicConfigs, shuffleApp); + } + + /** + * This method will trigger the deployment of ShuffleManager. ShuffleManager will be deployed in + * the form of a Kubernetes {@link Deployment}. + * + * @param configuration configuration of the ShuffleManager + * @param owner The owner of ShuffleManager ({@link Deployment}). The owner is set for garbage + * collection. When the owner is deleted, the ShuffleManager will be deleted automatically. + */ + private void reconcileShuffleManager(Configuration configuration, HasMetadata owner) { + KubernetesShuffleManagerParameters shuffleManagerParameters = + new KubernetesShuffleManagerParameters(configuration); + Deployment deployment = + new KubernetesDeploymentBuilder() + .buildKubernetesResourceFrom(shuffleManagerParameters); + KubernetesUtils.setOwnerReference(deployment, owner); + + LOG.info("Reconcile shuffle manager {}.", deployment.getMetadata().getName()); + executeReconcileWithRetry( + () -> { + LOG.debug("Try to create or update Deployment {}.", deployment.toString()); + this.kubeClient + .apps() + .deployments() + .inNamespace( + checkNotNull( + configuration.getString( + KubernetesInternalOptions.NAMESPACE))) + .createOrReplace(deployment); + }, + deployment.getMetadata().getName()); + } + + /** + * This method will trigger the deployment of ShuffleWorkers. ShuffleWorkers will be deployed in + * the form of a Kubernetes {@link DaemonSet}. + * + * @param configuration configuration of the ShuffleWorkers + * @param owner The owner of ShuffleWorkers ({@link DaemonSet}). The owner is set for garbage + * collection. When the owner is deleted, the ShuffleWorkers will be deleted automatically. + */ + private void reconcileShuffleWorkers(Configuration configuration, HasMetadata owner) { + + KubernetesDaemonSetParameters shuffleWorkerParameters = + new KubernetesShuffleWorkerParameters(configuration); + DaemonSet daemonSet = + new KubernetesDaemonSetBuilder() + .buildKubernetesResourceFrom(shuffleWorkerParameters); + KubernetesUtils.setOwnerReference(daemonSet, owner); + + LOG.info("Reconcile shuffle workers {}.", daemonSet.getMetadata().getName()); + executeReconcileWithRetry( + () -> { + LOG.debug("Try to create or update DaemonSet {}.", daemonSet.toString()); + this.kubeClient + .apps() + .daemonSets() + .inNamespace( + checkNotNull( + configuration.getString( + KubernetesInternalOptions.NAMESPACE))) + .createOrReplace(daemonSet); + }, + daemonSet.getMetadata().getName()); + } + + /** + * This method will trigger the deployment of shuffle config files.
The shuffle config files + * will be deployed in the form of a Kubernetes {@link ConfigMap}. The ConfigMap is then mounted + * into the ShuffleManager and ShuffleWorker pods. + * + * @param dynamicParameters The current dynamic configuration. + * @param fileConfigs Config file contents, keyed by file name. + * @param owner The owner of the config files ({@link ConfigMap}). The owner is set for garbage + * collection. When the owner is deleted, the {@link ConfigMap} will be deleted + * automatically. + * @return the updated configuration. + */ + private Configuration reconcileFileConfigs( + Configuration dynamicParameters, Map<String, String> fileConfigs, HasMetadata owner) { + + // TODO: Currently, pods will not restart automatically when the content of the file is + // changed. This will be supported later. + + KubernetesConfigMapParameters configMapParameters = + new K8sRemoteShuffleFileConfigsParameters( + dynamicParameters.getString(KubernetesInternalOptions.NAMESPACE), + dynamicParameters.getString(ClusterOptions.REMOTE_SHUFFLE_CLUSTER_ID), + fileConfigs); + + ConfigMap configMap = + new KubernetesConfigMapBuilder().buildKubernetesResourceFrom(configMapParameters); + + String volumeName = Constants.REMOTE_SHUFFLE_CONF_VOLUME; + String configMapName = configMap.getMetadata().getName(); + + Configuration configuration = dynamicParameters; + + // setup configmap + configuration = + setupConfigMapVolume( + configuration, volumeName, configMapName, fileConfigs.keySet()); + + // set owner + KubernetesUtils.setOwnerReference(configMap, owner); + + final Configuration finalConfiguration = configuration; + // deploy configmap + LOG.info("Reconcile configmap {}.", configMap.getMetadata().getName()); + executeReconcileWithRetry( + () -> { + LOG.debug("Try to create or update ConfigMap {}.", configMap.toString()); + this.kubeClient + .configMaps() + .inNamespace( + finalConfiguration.getString( + KubernetesInternalOptions.NAMESPACE)) + .withName(configMapName) + .createOrReplace(configMap); + }, + configMap.getMetadata().getName()); + + return configuration; + } + + private void executeReconcileWithRetry(RunnableWithException action, String component) { + KubernetesUtils.executeWithRetry(action, String.format("Reconcile %s", component)); + } + + private Configuration setupConfigMapVolume( + Configuration config, + String volumeName, + String configMapName, + Collection<String> fileNames) { + + Configuration configuration = new Configuration(config); + + // set volume + configuration.setString(KubernetesInternalOptions.CONFIG_VOLUME_NAME, volumeName); + configuration.setString( + KubernetesInternalOptions.CONFIG_VOLUME_CONFIG_MAP_NAME, configMapName); + configuration.setMap( + KubernetesInternalOptions.CONFIG_VOLUME_ITEMS, + fileNames.stream().collect(Collectors.toMap(file -> file, file -> file))); + configuration.setString( + KubernetesInternalOptions.CONFIG_VOLUME_MOUNT_PATH, + Constants.REMOTE_SHUFFLE_CONF_DIR); + + return configuration; + } +} diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesConfigMapBuilder.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesConfigMapBuilder.java new file mode 100644 index 00000000..795483ae --- /dev/null +++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesConfigMapBuilder.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.operator.resources; + +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.KubernetesConfigMapParameters; +import com.alibaba.flink.shuffle.kubernetes.operator.util.Constants; + +import io.fabric8.kubernetes.api.model.ConfigMap; +import io.fabric8.kubernetes.api.model.ConfigMapBuilder; + +/** Kubernetes ConfigMap builder. */ +public class KubernetesConfigMapBuilder + implements KubernetesResourceBuilder<ConfigMap, KubernetesConfigMapParameters> { + + @Override + public ConfigMap buildKubernetesResourceFrom( + KubernetesConfigMapParameters configMapParameters) { + return new ConfigMapBuilder() + .withApiVersion(Constants.API_VERSION) + .editOrNewMetadata() + .withName(configMapParameters.getConfigMapName()) + .withNamespace(configMapParameters.getNamespace()) + .withLabels(configMapParameters.getLabels()) + .endMetadata() + .withData(configMapParameters.getData()) + .build(); + } +} diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesContainerBuilder.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesContainerBuilder.java new file mode 100644 index 00000000..6c84b851 --- /dev/null +++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesContainerBuilder.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package com.alibaba.flink.shuffle.kubernetes.operator.resources; + +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.KubernetesContainerParameters; +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.util.ContainerCommandAndArgs; +import com.alibaba.flink.shuffle.kubernetes.operator.util.Constants; +import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesUtils; + +import io.fabric8.kubernetes.api.model.Container; +import io.fabric8.kubernetes.api.model.ContainerBuilder; +import io.fabric8.kubernetes.api.model.EnvVar; +import io.fabric8.kubernetes.api.model.EnvVarBuilder; +import io.fabric8.kubernetes.api.model.EnvVarSourceBuilder; +import io.fabric8.kubernetes.api.model.ResourceRequirements; +import io.fabric8.kubernetes.api.model.VolumeMount; +import io.fabric8.kubernetes.api.model.VolumeMountBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; + +/** Kubernetes Container builder. */ +public class KubernetesContainerBuilder + implements KubernetesResourceBuilder<Container, KubernetesContainerParameters> { + + private static final Logger LOG = LoggerFactory.getLogger(KubernetesContainerBuilder.class); + + @Override + public Container buildKubernetesResourceFrom( + KubernetesContainerParameters containerParameters) { + + final ResourceRequirements requirements = + KubernetesUtils.getResourceRequirements( + containerParameters.getContainerMemoryMB(), + containerParameters.getContainerCPU()); + KubernetesUtils.updateResourceRequirements( + requirements, containerParameters.getResourceLimitFactors()); + + ContainerCommandAndArgs commandAndArgs = containerParameters.getContainerCommandAndArgs(); + + return new ContainerBuilder() + .withName(containerParameters.getContainerName()) + .withImage(containerParameters.getContainerImage()) + .withImagePullPolicy(containerParameters.getContainerImagePullPolicy()) + .withCommand(commandAndArgs.getCommand()) + .withArgs(commandAndArgs.getArgs()) + .withResources(requirements) + .withVolumeMounts(getVolumeMounts(containerParameters)) + .withEnv(getEnvironmentVars(containerParameters)) + .build(); + } + + private List<VolumeMount> getVolumeMounts(KubernetesContainerParameters containerParameters) { + return containerParameters.getContainerVolumeMounts().stream() + .map(this::getVolumeMount) + .collect(Collectors.toList()); + } + + private VolumeMount getVolumeMount(Map<String, String> stringMap) { + checkState(stringMap.containsKey(Constants.VOLUME_NAME)); + checkState(stringMap.containsKey(Constants.VOLUME_MOUNT_PATH)); + + final VolumeMountBuilder volumeMountBuilder = new VolumeMountBuilder(); + stringMap.forEach( + (k, v) -> { + switch (k) { + case Constants.VOLUME_NAME: + volumeMountBuilder.withName(v); + break; + case Constants.VOLUME_MOUNT_PATH: + volumeMountBuilder.withMountPath(v); + break; + default: + LOG.warn("Unrecognized key({}) of volume mount, will ignore.", k); + break; + } + }); + + return volumeMountBuilder.build(); + } + + private List<EnvVar> getEnvironmentVars(KubernetesContainerParameters containerParameters) { + List<EnvVar> envVars = + containerParameters.getEnvironmentVars().entrySet().stream() + .map( + kv -> + new EnvVarBuilder() + .withName(kv.getKey()) + .withValue(kv.getValue()) + .build()) + .collect(Collectors.toList()); + + Map<String, String> fieldRefEnvVars = new HashMap<>(); + fieldRefEnvVars.put( + Constants.ENV_REMOTE_SHUFFLE_POD_IP_ADDRESS,
Constants.POD_IP_FIELD_PATH); + fieldRefEnvVars.put(Constants.ENV_REMOTE_SHUFFLE_POD_NAME, Constants.POD_NAME_FIELD_PATH); + envVars.addAll( + fieldRefEnvVars.entrySet().stream() + .map( + kv -> + new EnvVarBuilder() + .withName(kv.getKey()) + .withValueFrom( + new EnvVarSourceBuilder() + .withNewFieldRef( + Constants.API_VERSION, + kv.getValue()) + .build()) + .build()) + .collect(Collectors.toList())); + + return envVars; + } +} diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesDaemonSetBuilder.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesDaemonSetBuilder.java new file mode 100644 index 00000000..38b90c2c --- /dev/null +++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesDaemonSetBuilder.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.operator.resources; + +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.KubernetesDaemonSetParameters; +import com.alibaba.flink.shuffle.kubernetes.operator.util.Constants; + +import io.fabric8.kubernetes.api.model.Pod; +import io.fabric8.kubernetes.api.model.PodTemplateSpec; +import io.fabric8.kubernetes.api.model.PodTemplateSpecBuilder; +import io.fabric8.kubernetes.api.model.apps.DaemonSet; +import io.fabric8.kubernetes.api.model.apps.DaemonSetBuilder; +import io.fabric8.kubernetes.api.model.apps.DaemonSetUpdateStrategyBuilder; + +import java.util.Map; + +/** Kubernetes DaemonSet builder. 
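 + * + * <p>ShuffleWorkers run as a DaemonSet, i.e. one worker pod per eligible node. A minimal usage + * sketch (assuming a {@code KubernetesDaemonSetParameters} instance named {@code params} has + * already been derived from the shuffle configuration): + * + * <pre>{@code + * DaemonSet daemonSet = new KubernetesDaemonSetBuilder().buildKubernetesResourceFrom(params); + * }</pre>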
*/ +public class KubernetesDaemonSetBuilder + implements KubernetesResourceBuilder<DaemonSet, KubernetesDaemonSetParameters> { + + @Override + public DaemonSet buildKubernetesResourceFrom( + KubernetesDaemonSetParameters daemonSetParameters) { + + final Pod resolvedPod = + new KubernetesPodBuilder() + .buildKubernetesResourceFrom( + daemonSetParameters.getPodTemplateParameters()); + + final Map<String, String> labels = resolvedPod.getMetadata().getLabels(); + + return new DaemonSetBuilder() + .withApiVersion(Constants.APPS_API_VERSION) + .editOrNewMetadata() + .withName(daemonSetParameters.getDaemonSetName()) + .withNamespace(daemonSetParameters.getNamespace()) + .withLabels(daemonSetParameters.getLabels()) + .endMetadata() + .editOrNewSpec() + .withTemplate(getPodTemplate(resolvedPod)) + .withUpdateStrategy( + new DaemonSetUpdateStrategyBuilder() + .withType(Constants.ROLLING_UPDATE) + .build()) + .editOrNewSelector() + .addToMatchLabels(labels) + .endSelector() + .endSpec() + .build(); + } + + public PodTemplateSpec getPodTemplate(Pod resolvedPod) { + return new PodTemplateSpecBuilder() + .withMetadata(resolvedPod.getMetadata()) + .withSpec(resolvedPod.getSpec()) + .build(); + } +} diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesDeploymentBuilder.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesDeploymentBuilder.java new file mode 100644 index 00000000..fca0c785 --- /dev/null +++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesDeploymentBuilder.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.operator.resources; + +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.KubernetesDeploymentParameters; +import com.alibaba.flink.shuffle.kubernetes.operator.util.Constants; + +import io.fabric8.kubernetes.api.model.Pod; +import io.fabric8.kubernetes.api.model.PodTemplateSpec; +import io.fabric8.kubernetes.api.model.PodTemplateSpecBuilder; +import io.fabric8.kubernetes.api.model.apps.Deployment; +import io.fabric8.kubernetes.api.model.apps.DeploymentBuilder; + +import java.util.Map; + +/** Kubernetes Deployment builder.
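 + * + * <p>The ShuffleManager is deployed through a Deployment with a configurable replica count. A + * minimal usage sketch (assuming a {@code KubernetesDeploymentParameters} instance named + * {@code params} has already been derived from the shuffle configuration): + * + * <pre>{@code + * Deployment deployment = new KubernetesDeploymentBuilder().buildKubernetesResourceFrom(params); + * }</pre>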
*/ +public class KubernetesDeploymentBuilder + implements KubernetesResourceBuilder<Deployment, KubernetesDeploymentParameters> { + @Override + public Deployment buildKubernetesResourceFrom( + KubernetesDeploymentParameters deploymentParameters) { + + final Pod resolvedPod = + new KubernetesPodBuilder() + .buildKubernetesResourceFrom( + deploymentParameters.getPodTemplateParameters()); + + final Map<String, String> labels = resolvedPod.getMetadata().getLabels(); + + return new DeploymentBuilder() + .withApiVersion(Constants.APPS_API_VERSION) + .editOrNewMetadata() + .withName(deploymentParameters.getDeploymentName()) + .withNamespace(deploymentParameters.getNamespace()) + .withLabels(deploymentParameters.getLabels()) + .endMetadata() + .editOrNewSpec() + .withReplicas(deploymentParameters.getReplicas()) + .withTemplate(getPodTemplate(resolvedPod)) + .editOrNewSelector() + .addToMatchLabels(labels) + .endSelector() + .endSpec() + .build(); + } + + public PodTemplateSpec getPodTemplate(Pod resolvedPod) { + return new PodTemplateSpecBuilder() + .withMetadata(resolvedPod.getMetadata()) + .withSpec(resolvedPod.getSpec()) + .build(); + } +} diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesPodBuilder.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesPodBuilder.java new file mode 100644 index 00000000..a76db7c7 --- /dev/null +++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesPodBuilder.java @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.operator.resources; + +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.KubernetesPodParameters; +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.util.ConfigMapVolume; +import com.alibaba.flink.shuffle.kubernetes.operator.util.Constants; + +import io.fabric8.kubernetes.api.model.Container; +import io.fabric8.kubernetes.api.model.KeyToPath; +import io.fabric8.kubernetes.api.model.KeyToPathBuilder; +import io.fabric8.kubernetes.api.model.Pod; +import io.fabric8.kubernetes.api.model.PodBuilder; +import io.fabric8.kubernetes.api.model.Quantity; +import io.fabric8.kubernetes.api.model.Toleration; +import io.fabric8.kubernetes.api.model.TolerationBuilder; +import io.fabric8.kubernetes.api.model.Volume; +import io.fabric8.kubernetes.api.model.VolumeBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; + +/** Kubernetes Pod builder.
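 + * + * <p>Builds a single-container pod template with the volumes (emptyDir, hostPath and ConfigMap), + * node selector and tolerations taken from {@link KubernetesPodParameters}. The resulting pod is + * not created directly; it is embedded into the Deployment/DaemonSet specs by the corresponding + * builders.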
*/ +public class KubernetesPodBuilder + implements KubernetesResourceBuilder<Pod, KubernetesPodParameters> { + + private static final Logger LOG = LoggerFactory.getLogger(KubernetesPodBuilder.class); + + @Override + public Pod buildKubernetesResourceFrom(KubernetesPodParameters podParameters) { + + final Container singleContainer = + new KubernetesContainerBuilder() + .buildKubernetesResourceFrom(podParameters.getContainerParameters()); + + PodBuilder podBuilder = + new PodBuilder() + .withApiVersion(Constants.API_VERSION) + .editOrNewMetadata() + .withLabels(podParameters.getLabels()) + .endMetadata() + .editOrNewSpec() + .withHostNetwork(podParameters.enablePodHostNetwork()) + .withContainers(singleContainer) + .withVolumes(getVolumes(podParameters)) + .withNodeSelector(podParameters.getNodeSelector()) + .withTolerations( + podParameters.getTolerations().stream() + .map(this::getToleration) + .collect(Collectors.toList())) + .endSpec(); + + return podBuilder.build(); + } + + private List<Volume> getVolumes(KubernetesPodParameters podParameters) { + List<Volume> volumes = new ArrayList<>(); + // empty dir + volumes.addAll( + podParameters.getEmptyDirVolumes().stream() + .map(this::getEmptyDirVolume) + .collect(Collectors.toList())); + // host path + volumes.addAll( + podParameters.getHostPathVolumes().stream() + .map(this::getHostPathVolume) + .collect(Collectors.toList())); + // config map + for (ConfigMapVolume volume : podParameters.getConfigMapVolumes()) { + volumes.add(getConfigMapVolume(volume)); + } + + return volumes; + } + + private Volume getEmptyDirVolume(Map<String, String> stringMap) { + checkState(stringMap.containsKey(Constants.VOLUME_NAME)); + + final VolumeBuilder volumeBuilder = new VolumeBuilder(); + stringMap.forEach( + (k, v) -> { + switch (k) { + case Constants.VOLUME_NAME: + volumeBuilder.withName(v); + break; + case Constants.EMPTY_DIR_VOLUME_MEDIUM: + volumeBuilder.editOrNewEmptyDir().withMedium(v).endEmptyDir(); + break; + case Constants.EMPTY_DIR_VOLUME_SIZE_LIMIT: + volumeBuilder + .editOrNewEmptyDir() + .withSizeLimit(new Quantity(v)) + .endEmptyDir(); + break; + default: + LOG.warn("Unrecognized key({}) of emptyDir config, will ignore.", k); + break; + } + }); + + return volumeBuilder.build(); + } + + private Volume getHostPathVolume(Map<String, String> stringMap) { + checkState(stringMap.containsKey(Constants.VOLUME_NAME)); + checkState(stringMap.containsKey(Constants.HOST_PATH_VOLUME_PATH)); + + final VolumeBuilder volumeBuilder = new VolumeBuilder(); + stringMap.forEach( + (k, v) -> { + switch (k) { + case Constants.VOLUME_NAME: + volumeBuilder.withName(v); + break; + case Constants.HOST_PATH_VOLUME_TYPE: + volumeBuilder.editOrNewHostPath().withType(v).endHostPath(); + break; + case Constants.HOST_PATH_VOLUME_PATH: + volumeBuilder.editOrNewHostPath().withPath(v).endHostPath(); + break; + default: + LOG.warn("Unrecognized key({}) of hostPath config, will ignore.", k); + break; + } + }); + + return volumeBuilder.build(); + } + + private Volume getConfigMapVolume(ConfigMapVolume volume) { + + final List<KeyToPath> keyToPaths = + volume.getItems().entrySet().stream() + .map( + kv -> + new KeyToPathBuilder() + .withKey(kv.getKey()) + .withPath(kv.getValue()) + .build()) + .collect(Collectors.toList()); + + return new VolumeBuilder() + .withName(volume.getVolumeName()) + .editOrNewConfigMap() + .withName(volume.getConfigMapName()) + .withDefaultMode(420) + .withItems(keyToPaths) + .endConfigMap() + .build(); + } + + private Toleration getToleration(Map<String, String> stringMap) { + final TolerationBuilder tolerationBuilder = new TolerationBuilder();
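 + // Tolerations are supplied as flat string maps; keys are matched case-insensitively and + // tolerationSeconds is parsed into a long, so a non-numeric value would fail here.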
stringMap.forEach( + (k, v) -> { + switch (k.toLowerCase()) { + case "effect": + tolerationBuilder.withEffect(v); + break; + case "key": + tolerationBuilder.withKey(v); + break; + case "operator": + tolerationBuilder.withOperator(v); + break; + case "tolerationseconds": + tolerationBuilder.withTolerationSeconds(Long.valueOf(v)); + break; + case "value": + tolerationBuilder.withValue(v); + break; + default: + LOG.warn("Unrecognized key({}) of toleration, will ignore.", k); + break; + } + }); + return tolerationBuilder.build(); + } +} diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesResourceBuilder.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesResourceBuilder.java new file mode 100644 index 00000000..6b12a8dd --- /dev/null +++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesResourceBuilder.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.operator.resources; + +import io.fabric8.kubernetes.api.model.KubernetesResource; + +/** Builder of Kubernetes resources. */ +public interface KubernetesResourceBuilder<R extends KubernetesResource, P> { + + /** + * Build the Kubernetes resource. + * + * @param parameters The parameters for constructing the Kubernetes resource. + * @return The constructed Kubernetes resource. + */ + R buildKubernetesResourceFrom(P parameters); +} diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/util/Constants.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/util/Constants.java new file mode 100644 index 00000000..1f49b3eb --- /dev/null +++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/util/Constants.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package com.alibaba.flink.shuffle.kubernetes.operator.util; + +/** Constants for kubernetes. */ +public class Constants { + + public static final String API_VERSION = "v1"; + public static final String APPS_API_VERSION = "apps/v1"; + + public static final String RESOURCE_NAME_CPU = "cpu"; + public static final String RESOURCE_NAME_MEMORY = "memory"; + public static final String RESOURCE_UNIT_MB = "Mi"; + + public static final int MAXIMUM_CHARACTERS_OF_CLUSTER_ID = 45; + + // env + public static final String ENV_REMOTE_SHUFFLE_POD_IP_ADDRESS = "_POD_IP_ADDRESS"; + + // labels + public static final String LABEL_APPTYPE_KEY = "apptype"; + public static final String LABEL_APPTYPE_VALUE = "remoteshuffle"; + public static final String LABEL_APP_KEY = "app"; + public static final String LABEL_COMPONENT_KEY = "component"; + public static final String LABEL_COMPONENT_SHUFFLE_MANAGER = "shufflemanager"; + public static final String LABEL_COMPONENT_SHUFFLE_WORKER = "shuffleworker"; + + public static final String ROLLING_UPDATE = "RollingUpdate"; + + // volume + public static final String VOLUME_NAME = "name"; + public static final String VOLUME_MOUNT_PATH = "mountPath"; + // empty dir + public static final String EMPTY_DIR_VOLUME_MEDIUM = "medium"; + public static final String EMPTY_DIR_VOLUME_SIZE_LIMIT = "sizeLimit"; + // host path + public static final String HOST_PATH_VOLUME_PATH = "path"; + public static final String HOST_PATH_VOLUME_TYPE = "type"; + + // pod ip + public static final String POD_IP_FIELD_PATH = "status.podIP"; + // pod name + public static final String POD_NAME_FIELD_PATH = "metadata.name"; + public static final String ENV_REMOTE_SHUFFLE_POD_NAME = "_POD_NAME"; + + // conf dir. + public static final String REMOTE_SHUFFLE_CONF_VOLUME = "shuffle-config-file-volume"; + public static final String REMOTE_SHUFFLE_CONF_DIR = "/flink-remote-shuffle/conf"; + + // user agent + public static final String REMOTE_SHUFFLE_OPERATOR_USER_AGENT = "flink-remote-shuffle-operator"; + + // Kubernetes start scripts + public static final String SHUFFLE_MANAGER_SCRIPT_PATH = + "/flink-remote-shuffle/bin/kubernetes-shufflemanager.sh"; + public static final String SHUFFLE_WORKER_SCRIPT_PATH = + "/flink-remote-shuffle/bin/kubernetes-shuffleworker.sh"; +} diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/util/KubernetesInternalOptions.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/util/KubernetesInternalOptions.java new file mode 100644 index 00000000..2c5742ef --- /dev/null +++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/util/KubernetesInternalOptions.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.operator.util; + +import com.alibaba.flink.shuffle.common.config.ConfigOption; + +import java.util.Collections; +import java.util.Map; + +/** Kubernetes configuration options that are not meant to be set by the user. */ +public class KubernetesInternalOptions { + + // -------------------------------------------------------------------------------------------- + // Internal configurations. + // -------------------------------------------------------------------------------------------- + public static final ConfigOption<String> NAMESPACE = + new ConfigOption<String>("remote-shuffle.kubernetes.namespace") + .defaultValue("default") + .description( + "The namespace that will be used for running the shuffle manager and " + + "worker pods."); + + public static final ConfigOption<String> CONFIG_VOLUME_NAME = + new ConfigOption<String>("remote-shuffle.kubernetes.config-volume.name") + .defaultValue(null) + .description("Config volume name."); + + public static final ConfigOption<String> CONFIG_VOLUME_CONFIG_MAP_NAME = + new ConfigOption<String>("remote-shuffle.kubernetes.config-volume.config-map-name") + .defaultValue(null) + .description("Config map name."); + + public static final ConfigOption<Map<String, String>> CONFIG_VOLUME_ITEMS = + new ConfigOption<Map<String, String>>("remote-shuffle.kubernetes.config-volume.items") + .defaultValue(Collections.emptyMap()) + .description("Config volume items."); + + public static final ConfigOption<String> CONFIG_VOLUME_MOUNT_PATH = + new ConfigOption<String>("remote-shuffle.kubernetes.config-volume.mount-path") + .defaultValue(null) + .description("Config volume mount path."); +} diff --git a/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/util/KubernetesUtils.java b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/util/KubernetesUtils.java new file mode 100644 index 00000000..5a2bccb8 --- /dev/null +++ b/shuffle-kubernetes-operator/src/main/java/com/alibaba/flink/shuffle/kubernetes/operator/util/KubernetesUtils.java @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */ + +package com.alibaba.flink.shuffle.kubernetes.operator.util; + +import com.alibaba.flink.shuffle.common.config.ConfigOption; +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.functions.RunnableWithException; +import com.alibaba.flink.shuffle.common.functions.SupplierWithException; + +import io.fabric8.kubernetes.api.model.HasMetadata; +import io.fabric8.kubernetes.api.model.OwnerReference; +import io.fabric8.kubernetes.api.model.OwnerReferenceBuilder; +import io.fabric8.kubernetes.api.model.Quantity; +import io.fabric8.kubernetes.api.model.ResourceRequirements; +import io.fabric8.kubernetes.api.model.ResourceRequirementsBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** Common utils for Kubernetes. */ +public class KubernetesUtils { + + public static final String SHUFFLE_WORKER_CONTAINER_NAME = "shuffleworker"; + public static final String SHUFFLE_MANAGER_CONTAINER_NAME = "shufflemanager"; + private static final int MAX_RETRY_TIMES = 10; + private static final long RETRY_INTERVAL = 100L; + private static final Logger LOG = LoggerFactory.getLogger(KubernetesUtils.class); + + /** + * Get resource requirements from memory and cpu. + * + * @param mem Memory in mb. + * @param cpu cpu. + * @return KubernetesResource requirements. + */ + public static ResourceRequirements getResourceRequirements(int mem, double cpu) { + final Quantity cpuQuantity = new Quantity(String.valueOf(cpu)); + final Quantity memQuantity = new Quantity(mem + Constants.RESOURCE_UNIT_MB); + + ResourceRequirementsBuilder resourceRequirementsBuilder = + new ResourceRequirementsBuilder() + .addToRequests(Constants.RESOURCE_NAME_MEMORY, memQuantity) + .addToRequests(Constants.RESOURCE_NAME_CPU, cpuQuantity) + .addToLimits(Constants.RESOURCE_NAME_MEMORY, memQuantity) + .addToLimits(Constants.RESOURCE_NAME_CPU, cpuQuantity); + + return resourceRequirementsBuilder.build(); + } + + /** + * Get the common labels for remote shuffle service clusters. All the Kubernetes resources will + * be set with these labels. 
+ * + * @param clusterId The cluster id. + * @return The common labels map. + */ + public static Map<String, String> getCommonLabels(String clusterId) { + final Map<String, String> commonLabels = new HashMap<>(); + commonLabels.put(Constants.LABEL_APPTYPE_KEY, Constants.LABEL_APPTYPE_VALUE); + commonLabels.put(Constants.LABEL_APP_KEY, clusterId); + + return commonLabels; + } + + public static String getShuffleManagerNameWithClusterId(String clusterId) { + return clusterId + "-" + SHUFFLE_MANAGER_CONTAINER_NAME; + } + + public static String getShuffleWorkersNameWithClusterId(String clusterId) { + return clusterId + "-" + SHUFFLE_WORKER_CONTAINER_NAME; + } + + public static void setOwnerReference(HasMetadata resource, HasMetadata owner) { + final OwnerReference ownerReference = + new OwnerReferenceBuilder() + .withName(owner.getMetadata().getName()) + .withApiVersion(owner.getApiVersion()) + .withUid(owner.getMetadata().getUid()) + .withKind(owner.getKind()) + .withController(true) + .build(); + + resource.getMetadata().setOwnerReferences(Collections.singletonList(ownerReference)); + } + + public static OwnerReference getControllerOf(HasMetadata resource) { + List<OwnerReference> ownerReferences = resource.getMetadata().getOwnerReferences(); + for (OwnerReference ownerReference : ownerReferences) { + if (ownerReference.getController().equals(Boolean.TRUE)) { + return ownerReference; + } + } + return null; + } + + public static String getResourceFullName(HasMetadata resource) { + return getNameWithNameSpace( + resource.getMetadata().getNamespace(), resource.getMetadata().getName()); + } + + public static String getNameWithNameSpace(String namespace, String name) { + return namespace + "/" + name; + } + + public static void executeWithRetry(RunnableWithException action, String actionName) { + for (int retry = 0; retry < MAX_RETRY_TIMES; retry++) { + try { + action.run(); + return; + } catch (Throwable throwable) { + if (LOG.isDebugEnabled()) { + LOG.debug( + String.format("%s error, retry times = %d.", actionName, retry), + throwable); + } + if (retry >= MAX_RETRY_TIMES - 1) { + throw new RuntimeException(String.format("%s failed.", actionName), throwable); + } + try { + TimeUnit.SECONDS.sleep(1); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } + } + + /** + * Extract and parse configuration properties with a given name prefix and return the result as
*/ + public static Map<String, String> getPrefixedKeyValuePairs( + String prefix, Configuration configuration) { + Map<String, String> result = new HashMap<>(); + for (Map.Entry<String, String> entry : configuration.toMap().entrySet()) { + if (entry.getKey().startsWith(prefix) && entry.getKey().length() > prefix.length()) { + String key = entry.getKey().substring(prefix.length()); + result.put(key, entry.getValue()); + } + } + return result; + } + + public static void updateResourceRequirements( + ResourceRequirements resourceRequirements, Map<String, String> limitFactors) { + + limitFactors.forEach( + (type, factor) -> { + final Quantity quantity = resourceRequirements.getRequests().get(type); + if (quantity != null) { + final double limit = + Double.parseDouble(quantity.getAmount()) + * Double.parseDouble(factor); + LOG.info("Updating the {} limit to {}", type, limit); + resourceRequirements + .getLimits() + .put( + type, + new Quantity(String.valueOf(limit), quantity.getFormat())); + } else { + LOG.warn( + "Could not find the request for {}, ignoring the factor {}.", + type, + factor); + } + }); + } + + public static Map<String, String> filterVolumeMountsConfigs(Map<String, String> configs) { + return filterConfigsWithSpecifiedKeys( + configs, Arrays.asList(Constants.VOLUME_NAME, Constants.VOLUME_MOUNT_PATH)); + } + + public static Map<String, String> filterEmptyDirVolumeConfigs(Map<String, String> configs) { + return filterConfigsWithSpecifiedKeys( + configs, + Arrays.asList( + Constants.VOLUME_NAME, + Constants.EMPTY_DIR_VOLUME_MEDIUM, + Constants.EMPTY_DIR_VOLUME_SIZE_LIMIT)); + } + + public static Map<String, String> filterHostPathVolumeConfigs(Map<String, String> configs) { + return filterConfigsWithSpecifiedKeys( + configs, + Arrays.asList( + Constants.VOLUME_NAME, + Constants.HOST_PATH_VOLUME_PATH, + Constants.HOST_PATH_VOLUME_TYPE)); + } + + public static List<Map<String, String>> filterVolumesConfigs( + Configuration configuration, + ConfigOption<List<Map<String, String>>> option, + Function<Map<String, String>, Map<String, String>> configFilter) { + List<Map<String, String>> volumes = + Optional.ofNullable(configuration.getList(option, Map.class)) + .orElse(Collections.emptyList()); + + return volumes.stream().map(configFilter).collect(Collectors.toList()); + } + + private static Map<String, String> filterConfigsWithSpecifiedKeys( + Map<String, String> configs, List<String> specifiedKeys) { + return configs.entrySet().stream() + .filter(entry -> specifiedKeys.contains(entry.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } + + public static void waitUntilCondition( + SupplierWithException<Boolean, Exception> condition, Duration timeout) + throws Exception { + waitUntilCondition( + condition, timeout, RETRY_INTERVAL, "Condition was not met in given timeout."); + } + + public static void waitUntilCondition( + SupplierWithException<Boolean, Exception> condition, + Duration timeout, + long retryIntervalMillis, + String errorMsg) + throws Exception { + long timeLeft = timeout.toMillis(); + long endTime = System.currentTimeMillis() + timeLeft; + + while (timeLeft > 0 && !condition.get()) { + Thread.sleep(Math.min(retryIntervalMillis, timeLeft)); + timeLeft = endTime - System.currentTimeMillis(); + } + + if (timeLeft <= 0) { + throw new TimeoutException(errorMsg); + } + } +} diff --git a/shuffle-kubernetes-operator/src/main/resources/log4j2.properties b/shuffle-kubernetes-operator/src/main/resources/log4j2.properties new file mode 100644 index 00000000..5baee61c --- /dev/null +++ b/shuffle-kubernetes-operator/src/main/resources/log4j2.properties @@ -0,0 +1,25 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +rootLogger.level = OFF +rootLogger.appenderRef.console.ref = ConsoleAppender + +appender.console.name = ConsoleAppender +appender.console.type = CONSOLE +appender.console.layout.type = PatternLayout +appender.console.layout.pattern = %d{yyyy-MM-dd HH:mm:ss,SSS} %-5p [%t] %-60c %x - %m%n diff --git a/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/RemoteShuffleApplicationOperatorEntrypointTest.java b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/RemoteShuffleApplicationOperatorEntrypointTest.java new file mode 100644 index 00000000..22fa58bd --- /dev/null +++ b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/RemoteShuffleApplicationOperatorEntrypointTest.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.operator; + +import com.alibaba.flink.shuffle.kubernetes.operator.crd.RemoteShuffleApplication; + +import io.fabric8.kubernetes.api.model.apiextensions.v1beta1.CustomResourceColumnDefinition; +import io.fabric8.kubernetes.api.model.apiextensions.v1beta1.CustomResourceDefinition; +import org.apache.commons.lang3.tuple.Triple; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; + +/** Test for {@link RemoteShuffleApplicationOperatorEntrypoint}. 
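 + * + * <p>Verifies that the generated CRD manifest matches the metadata (group, kind, scope and + * additional printer columns) declared on the {@link RemoteShuffleApplication} custom resource + * class.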
*/ +public class RemoteShuffleApplicationOperatorEntrypointTest { + + @Test + public void testCreateRemoteShuffleApplicationCRD() { + CustomResourceDefinition crd = + RemoteShuffleApplicationOperatorEntrypoint.createRemoteShuffleApplicationCRD(); + RemoteShuffleApplication instance = new RemoteShuffleApplication(); + + assertThat(crd.getKind(), is("CustomResourceDefinition")); + assertThat(crd.getApiVersion(), is("apiextensions.k8s.io/v1beta1")); + assertThat(crd.getMetadata().getName(), is(instance.getCRDName())); + assertThat(crd.getSpec().getGroup(), is(instance.getGroup())); + assertThat(crd.getSpec().getNames().getKind(), is(instance.getKind())); + assertThat(crd.getSpec().getNames().getSingular(), is(instance.getSingular())); + assertThat(crd.getSpec().getNames().getPlural(), is(instance.getPlural())); + assertThat(crd.getSpec().getScope(), is(instance.getScope())); + + assertThat(crd.getSpec().getAdditionalPrinterColumns().size(), is(2)); + + List<Triple<String, String, String>> additionalColumns = new ArrayList<>(); + for (CustomResourceColumnDefinition definition : + crd.getSpec().getAdditionalPrinterColumns()) { + additionalColumns.add( + Triple.of( + definition.getName(), definition.getType(), definition.getJSONPath())); + } + + assertEquals( + RemoteShuffleApplicationOperatorEntrypoint.ADDITIONAL_COLUMN, additionalColumns); + } +} diff --git a/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/controller/RemoteShuffleApplicationControllerTest.java b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/controller/RemoteShuffleApplicationControllerTest.java new file mode 100644 index 00000000..ff25fdc2 --- /dev/null +++ b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/controller/RemoteShuffleApplicationControllerTest.java @@ -0,0 +1,643 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package com.alibaba.flink.shuffle.kubernetes.operator.controller; + +import com.alibaba.flink.shuffle.core.config.ClusterOptions; +import com.alibaba.flink.shuffle.core.config.HighAvailabilityOptions; +import com.alibaba.flink.shuffle.core.config.KubernetesOptions; +import com.alibaba.flink.shuffle.core.config.ManagerOptions; +import com.alibaba.flink.shuffle.core.config.StorageOptions; +import com.alibaba.flink.shuffle.core.config.TransferOptions; +import com.alibaba.flink.shuffle.core.config.WorkerOptions; +import com.alibaba.flink.shuffle.core.executor.ExecutorThreadFactory; +import com.alibaba.flink.shuffle.kubernetes.operator.crd.RemoteShuffleApplication; +import com.alibaba.flink.shuffle.kubernetes.operator.crd.RemoteShuffleApplicationSpec; +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.KubernetesDaemonSetParameters; +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.KubernetesDeploymentParameters; +import com.alibaba.flink.shuffle.kubernetes.operator.resources.KubernetesDaemonSetBuilder; +import com.alibaba.flink.shuffle.kubernetes.operator.resources.KubernetesDaemonSetBuilderTest; +import com.alibaba.flink.shuffle.kubernetes.operator.resources.KubernetesDeploymentBuilder; +import com.alibaba.flink.shuffle.kubernetes.operator.resources.KubernetesDeploymentBuilderTest; +import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesInternalOptions; +import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesTestBase; +import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesUtils; + +import io.fabric8.kubernetes.api.model.HasMetadata; +import io.fabric8.kubernetes.api.model.ObjectMeta; +import io.fabric8.kubernetes.api.model.ObjectMetaBuilder; +import io.fabric8.kubernetes.api.model.apps.DaemonSet; +import io.fabric8.kubernetes.api.model.apps.Deployment; +import io.fabric8.kubernetes.client.KubernetesClient; +import io.fabric8.kubernetes.client.informers.ResourceEventHandler; +import io.fabric8.kubernetes.client.informers.SharedInformerFactory; +import io.fabric8.kubernetes.client.server.mock.KubernetesServer; +import org.junit.After; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.InetAddress; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** Test for {@link RemoteShuffleApplicationController}. 
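 + * + * <p>Runs the controller against an in-process {@code KubernetesServer} mock and checks that + * out-of-band updates or deletions of the ShuffleManager/ShuffleWorker resources are reconciled + * back to the state described by the {@link RemoteShuffleApplication} spec.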
*/ +public class RemoteShuffleApplicationControllerTest extends KubernetesTestBase { + + private static final Logger LOG = + LoggerFactory.getLogger(RemoteShuffleApplicationControllerTest.class); + + private RemoteShuffleApplicationController shuffleAppController; + private ScheduledExecutorService shuffleAppControllerExecutor; + private KubernetesClient kubeClient; + + static Map<String, List<HasMetadata>> deleteEvents = new HashMap<>(); + static Map<String, List<HasMetadata>> addEvents = new HashMap<>(); + static Map<String, List<HasMetadata>> updateEvents = new HashMap<>(); + + @Before + public void setUp() throws Exception { + addEvents.clear(); + updateEvents.clear(); + deleteEvents.clear(); + + // setup client and resource event handler. + KubernetesServer server = + new KubernetesServer( + false, + true, + InetAddress.getByName("127.0.0.1"), + 0, + Collections.emptyList()); + server.before(); + kubeClient = server.getClient(); + ControllerRunner controllerRunner = new ControllerRunner(kubeClient); + controllerRunner.registerShuffleManagerResourceEventHandler( + new RecordResourceEventHandler<>()); + controllerRunner.registerShuffleWorkersResourceEventHandler( + new RecordResourceEventHandler<>()); + shuffleAppController = controllerRunner.getRemoteShuffleApplicationController(); + shuffleAppControllerExecutor = Executors.newSingleThreadScheduledExecutor(); + shuffleAppControllerExecutor.schedule(controllerRunner, 0, TimeUnit.MILLISECONDS); + waitForControllerReady(); + } + + @After + public void cleanUp() throws Exception { + shuffleAppControllerExecutor.shutdownNow(); + } + + private static class RecordResourceEventHandler<T> implements ResourceEventHandler<T> { + + @Override + public void onAdd(T t) { + checkState(t instanceof HasMetadata); + HasMetadata resource = (HasMetadata) t; + recordEvents(addEvents, resource); + } + + @Override + public void onUpdate(T t, T t1) { + checkState(t instanceof HasMetadata); + HasMetadata resource = (HasMetadata) t1; + recordEvents(updateEvents, resource); + } + + @Override + public void onDelete(T t, boolean b) { + checkState(t instanceof HasMetadata); + HasMetadata resource = (HasMetadata) t; + recordEvents(deleteEvents, resource); + } + } + + private static void recordEvents(Map<String, List<HasMetadata>> events, HasMetadata resource) { + events.compute( + resource.getKind(), + (kind, list) -> { + final List<HasMetadata> hasMetadataList; + if (list != null) { + hasMetadataList = list; + } else { + hasMetadataList = new ArrayList<>(); + } + hasMetadataList.add(resource); + return hasMetadataList; + }); + } + + private void waitUntilResourceReady(Class<?> clazz, String resourceName) throws Exception { + KubernetesUtils.waitUntilCondition( + () -> checkResourceReady(clazz, resourceName), Duration.ofMinutes(2)); + } + + private boolean checkResourceReady(Class<?> clazz, String resourceName) { + if (clazz == Deployment.class) { + return getDeploymentList().stream() + .map(deployment -> deployment.getMetadata().getName()) + .collect(Collectors.toList()) + .contains(resourceName); + } else if (clazz == DaemonSet.class) { + return getDaemonSetList().stream() + .map(daemonSet -> daemonSet.getMetadata().getName()) + .collect(Collectors.toList()) + .contains(resourceName); + } else if (clazz == RemoteShuffleApplication.class) { + return getRemoteShuffleApplicationList().stream() + .map(shuffleApplication -> shuffleApplication.getMetadata().getName()) + .collect(Collectors.toList()) + .contains(resourceName); + } else { + throw new UnsupportedOperationException(); + } + } + + private void waitForControllerReady() throws Exception { + KubernetesUtils.waitUntilCondition( + () ->
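 + // wait until the controller thread has started and marked itself as running.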
shuffleAppController.isRunning.get(), Duration.ofSeconds(30)); + } + + private void deployRemoteShuffleApplication(RemoteShuffleApplication shuffleApp) { + kubeClient + .customResources(RemoteShuffleApplication.class) + .inNamespace(NAMESPACE) + .createOrReplace(shuffleApp); + } + + private void deleteRemoteShuffleApplication(RemoteShuffleApplication shuffleApp) { + kubeClient + .customResources(RemoteShuffleApplication.class) + .inNamespace(NAMESPACE) + .delete(shuffleApp); + } + + private List<RemoteShuffleApplication> getRemoteShuffleApplicationList() { + return shuffleAppController.getShuffleAppLister().namespace(NAMESPACE).list(); + } + + private RemoteShuffleApplication getRemoteShuffleApplication(String name) { + return shuffleAppController.getShuffleAppLister().namespace(NAMESPACE).get(name); + } + + private List<Deployment> getDeploymentList() { + return shuffleAppController.getDeploymentLister().namespace(NAMESPACE).list(); + } + + private List<DaemonSet> getDaemonSetList() { + return shuffleAppController.getDaemonSetLister().namespace(NAMESPACE).list(); + } + + private Deployment getDeployment(String name) { + return shuffleAppController.getDeploymentLister().namespace(NAMESPACE).get(name); + } + + private DaemonSet getDaemonSet(String name) { + return shuffleAppController.getDaemonSetLister().namespace(NAMESPACE).get(name); + } + + private void deleteDeployment(Deployment deployment) { + kubeClient.apps().deployments().inNamespace(NAMESPACE).delete(deployment); + } + + private void deleteDaemonSet(DaemonSet daemonSet) { + kubeClient.apps().daemonSets().inNamespace(NAMESPACE).delete(daemonSet); + } + + private void deployDeployment(Deployment deployment) { + kubeClient.apps().deployments().inNamespace(NAMESPACE).createOrReplace(deployment); + } + + private void deployDaemonSet(DaemonSet daemonSet) { + kubeClient.apps().daemonSets().inNamespace(NAMESPACE).createOrReplace(daemonSet); + } + + private String getShuffleManagerName(RemoteShuffleApplication remoteShuffleApplication) { + return KubernetesUtils.getShuffleManagerNameWithClusterId( + remoteShuffleApplication.getMetadata().getName()); + } + + private String getShuffleWorkersName(RemoteShuffleApplication remoteShuffleApplication) { + return KubernetesUtils.getShuffleWorkersNameWithClusterId( + remoteShuffleApplication.getMetadata().getName()); + } + + private RemoteShuffleApplication createRemoteShuffleApplication() throws Exception { + Map<String, String> dynamicConfigs = createDynamicConfigs(); + Map<String, String> fileConfigs = createFileConfigs(); + return createRemoteShuffleApplication(dynamicConfigs, fileConfigs); + } + + private RemoteShuffleApplication createRemoteShuffleApplication( + Map<String, String> dynamicConfigs, Map<String, String> fileConfigs) throws Exception { + + RemoteShuffleApplicationSpec remoteShuffleApplicationSpec = + new RemoteShuffleApplicationSpec(); + remoteShuffleApplicationSpec.setShuffleDynamicConfigs(dynamicConfigs); + remoteShuffleApplicationSpec.setShuffleFileConfigs(fileConfigs); + + RemoteShuffleApplication remoteShuffleApplication = new RemoteShuffleApplication(); + ObjectMeta metadata = + new ObjectMetaBuilder().withNamespace(NAMESPACE).withName(CLUSTER_ID).build(); + remoteShuffleApplication.setMetadata(metadata); + remoteShuffleApplication.setSpec(remoteShuffleApplicationSpec); + + deployRemoteShuffleApplication(remoteShuffleApplication); + waitUntilResourceReady( + RemoteShuffleApplication.class, remoteShuffleApplication.getMetadata().getName()); + // check shuffle manager and shuffle worker.
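 + // both are resolved through the controller's informer listers, so this also verifies + // that the informers observed the resources the reconciler created.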
+ waitUntilResourceReady(Deployment.class, getShuffleManagerName(remoteShuffleApplication)); + waitUntilResourceReady(DaemonSet.class, getShuffleWorkersName(remoteShuffleApplication)); + + List<RemoteShuffleApplication> shuffleApps = getRemoteShuffleApplicationList(); + assertThat(shuffleApps.size(), is(1)); + assertThat(shuffleApps.get(0).getMetadata(), is(remoteShuffleApplication.getMetadata())); + assertThat(shuffleApps.get(0).getSpec(), is(remoteShuffleApplication.getSpec())); + + return remoteShuffleApplication; + } + + private Map<String, String> createDynamicConfigs() { + Map<String, String> dynamicConfigs = new HashMap<>(); + dynamicConfigs.put(ClusterOptions.REMOTE_SHUFFLE_CLUSTER_ID.key(), CLUSTER_ID); + dynamicConfigs.put(KubernetesInternalOptions.NAMESPACE.key(), NAMESPACE); + dynamicConfigs.put(KubernetesOptions.CONTAINER_IMAGE.key(), CONTAINER_IMAGE); + dynamicConfigs.put(KubernetesOptions.POD_HOST_NETWORK_ENABLED.key(), "true"); + dynamicConfigs.put(StorageOptions.STORAGE_LOCAL_DATA_DIRS.key(), "/data"); + dynamicConfigs.put(TransferOptions.SERVER_DATA_PORT.key(), "10085"); + dynamicConfigs.put(HighAvailabilityOptions.HA_MODE.key(), "ZOOKEEPER"); + dynamicConfigs.put(HighAvailabilityOptions.HA_ZOOKEEPER_QUORUM.key(), "localhost"); + // shuffle manager configs + dynamicConfigs.put(KubernetesOptions.SHUFFLE_MANAGER_CPU.key(), "2"); + dynamicConfigs.put(ManagerOptions.FRAMEWORK_HEAP_MEMORY.key(), "256mb"); + // shuffle worker configs + dynamicConfigs.put(KubernetesOptions.SHUFFLE_WORKER_CPU.key(), "2"); + dynamicConfigs.put(WorkerOptions.FRAMEWORK_HEAP_MEMORY.key(), "256mb"); + + return dynamicConfigs; + } + + private Map<String, String> createFileConfigs() { + Map<String, String> fileConfigs = new HashMap<>(); + fileConfigs.put("file1", "This is a test for config file."); + fileConfigs.put("file2", "This is a test for config file."); + return fileConfigs; + } + + @Ignore("Temporarily ignore.") + @Test(timeout = 60000L) + public void testShuffleApplicationAddAndDelete() throws Exception { + // create a new shuffle application + RemoteShuffleApplication remoteShuffleApplication = createRemoteShuffleApplication(); + + // delete shuffle application + deleteRemoteShuffleApplication(remoteShuffleApplication); + KubernetesUtils.waitUntilCondition( + () -> getRemoteShuffleApplicationList().size() == 0, Duration.ofSeconds(30)); + } + + @Ignore("Temporarily ignore.") + @Test(timeout = 60000L) + public void testShuffleApplicationUpdate() throws Exception { + // create a new shuffle application + RemoteShuffleApplication remoteShuffleApplication = createRemoteShuffleApplication(); + + remoteShuffleApplication + .getSpec() + .getShuffleDynamicConfigs() + .put(KubernetesOptions.CONTAINER_IMAGE.key(), "shuffleApp:666"); + deployRemoteShuffleApplication(remoteShuffleApplication); + + KubernetesUtils.waitUntilCondition( + () -> { + RemoteShuffleApplication shuffleApp = + getRemoteShuffleApplication( + remoteShuffleApplication.getMetadata().getName()); + assertNotNull(shuffleApp); + return shuffleApp + .getSpec() + .getShuffleDynamicConfigs() + .get(KubernetesOptions.CONTAINER_IMAGE.key()) + .equals("shuffleApp:666"); + }, + Duration.ofSeconds(30)); + // check that the image is updated in the shuffle manager. + KubernetesUtils.waitUntilCondition( + () -> { + Deployment shuffleManager = + getDeployment(getShuffleManagerName(remoteShuffleApplication)); + assertNotNull(shuffleManager); + return shuffleManager + .getSpec() + .getTemplate() + .getSpec() + .getContainers() + .get(0) + .getImage() + .equals("shuffleApp:666"); + }, + Duration.ofSeconds(30)); + // check that the image is updated in the shuffle workers.
+ KubernetesUtils.waitUntilCondition( + () -> { + DaemonSet shuffleWorkers = + getDaemonSet(getShuffleWorkersName(remoteShuffleApplication)); + assertNotNull(shuffleWorkers); + return shuffleWorkers + .getSpec() + .getTemplate() + .getSpec() + .getContainers() + .get(0) + .getImage() + .equals("shuffleApp:666"); + }, + Duration.ofSeconds(30)); + } + + @Test(timeout = 60000L) + public void testShuffleManagerUpdate() throws Exception { + // create a new shuffle application + RemoteShuffleApplication remoteShuffleApplication = createRemoteShuffleApplication(); + + // update shuffle manager. + Deployment newDeployment = + RemoteShuffleApplicationController.cloneResource( + getDeployment(getShuffleManagerName(remoteShuffleApplication))); + newDeployment.getSpec().setReplicas(10); + deployDeployment(newDeployment); + String newDeploymentName = newDeployment.getMetadata().getName(); + + // check that the replicas have been updated to 10. + KubernetesUtils.waitUntilCondition( + () -> { + List<HasMetadata> resources = updateEvents.get("Deployment"); + if (resources == null) { + return false; + } + for (HasMetadata resource : resources) { + Deployment deployment = (Deployment) resource; + if (deployment.getMetadata().getName().equals(newDeploymentName) + && deployment.getSpec().getReplicas().intValue() == 10) { + return true; + } + } + return false; + }, + Duration.ofSeconds(30)); + + // check that the replicas are reconciled back to 1. + KubernetesUtils.waitUntilCondition( + () -> { + Deployment deployment = getDeployment(newDeploymentName); + assertNotNull(deployment); + return deployment.getSpec().getReplicas().intValue() == 1; + }, + Duration.ofSeconds(30)); + } + + @Test(timeout = 60000L) + public void testShuffleManagerDelete() throws Exception { + // create a new shuffle application + RemoteShuffleApplication remoteShuffleApplication = createRemoteShuffleApplication(); + + // delete shuffle manager. + Deployment deployment = getDeployment(getShuffleManagerName(remoteShuffleApplication)); + deleteDeployment(deployment); + + // check that the shuffle manager has been deleted. + KubernetesUtils.waitUntilCondition( + () -> { + List<HasMetadata> resources = deleteEvents.get("Deployment"); + if (resources == null) { + return false; + } + + return resources.stream() + .map(resource -> resource.getMetadata().getName()) + .collect(Collectors.toList()) + .contains(deployment.getMetadata().getName()); + }, + Duration.ofSeconds(30)); + + // check that the shuffle manager is reconciled (recreated). + KubernetesUtils.waitUntilCondition( + () -> + getDeploymentList().size() == 1 + && getDeploymentList() + .get(0) + .getMetadata() + .getName() + .equals(deployment.getMetadata().getName()), + Duration.ofSeconds(30)); + } + + @Ignore("Temporarily ignore.") + @Test(timeout = 60000L) + public void testShuffleWorkersUpdate() throws Exception { + // create a new shuffle application + RemoteShuffleApplication remoteShuffleApplication = createRemoteShuffleApplication(); + + // update shuffle workers. + DaemonSet newDaemonSet = + RemoteShuffleApplicationController.cloneResource( + getDaemonSet(getShuffleWorkersName(remoteShuffleApplication))); + newDaemonSet + .getSpec() + .getTemplate() + .getSpec() + .getContainers() + .get(0) + .setImage("shuffleApp:666"); + deployDaemonSet(newDaemonSet); + String newDaemonSetName = newDaemonSet.getMetadata().getName(); + + // check that the image has been updated to shuffleApp:666.
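+ // (the RecordResourceEventHandler registered in setUp records every update event, so the + // static updateEvents map is the signal polled below.)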
+ KubernetesUtils.waitUntilCondition( + () -> { + List<HasMetadata> resources = updateEvents.get("DaemonSet"); + if (resources == null) { + return false; + } + for (HasMetadata resource : resources) { + DaemonSet daemonSet = (DaemonSet) resource; + if (daemonSet.getMetadata().getName().equals(newDaemonSetName) + && daemonSet + .getSpec() + .getTemplate() + .getSpec() + .getContainers() + .get(0) + .getImage() + .equals("shuffleApp:666")) { + return true; + } + } + return false; + }, + Duration.ofSeconds(30)); + + // check that the image is reconciled back to flink-remote-shuffle-k8s-test:latest. + KubernetesUtils.waitUntilCondition( + () -> { + DaemonSet daemonSet = getDaemonSet(newDaemonSetName); + assertNotNull(daemonSet); + return daemonSet + .getSpec() + .getTemplate() + .getSpec() + .getContainers() + .get(0) + .getImage() + .equals(CONTAINER_IMAGE); + }, + Duration.ofSeconds(30)); + } + + @Test(timeout = 60000L) + public void testShuffleWorkersDelete() throws Exception { + // create a new shuffle application + RemoteShuffleApplication remoteShuffleApplication = createRemoteShuffleApplication(); + + // delete shuffle workers. + DaemonSet daemonSet = getDaemonSet(getShuffleWorkersName(remoteShuffleApplication)); + deleteDaemonSet(daemonSet); + + // check that the shuffle workers have been deleted. + KubernetesUtils.waitUntilCondition( + () -> { + List<HasMetadata> resources = deleteEvents.get("DaemonSet"); + if (resources == null) { + return false; + } + + return resources.stream() + .map(resource -> resource.getMetadata().getName()) + .collect(Collectors.toList()) + .contains(daemonSet.getMetadata().getName()); + }, + Duration.ofSeconds(30)); + + // check that the shuffle workers are reconciled. + KubernetesUtils.waitUntilCondition( + () -> + getDaemonSetList().size() == 1 + && getDaemonSetList() + .get(0) + .getMetadata() + .getName() + .equals(daemonSet.getMetadata().getName()), + Duration.ofSeconds(30)); + } + + @Ignore + @Test(timeout = 60000L) + public void testNullFileConfigs() throws Exception { + // create a new shuffle application + RemoteShuffleApplication remoteShuffleApplication = + createRemoteShuffleApplication(createDynamicConfigs(), null); + // no ConfigMap should have been created. + assertThat(kubeClient.configMaps().inNamespace(NAMESPACE).list().getItems().size(), is(0)); + } + + @Test + public void testCloneDeployment() throws IOException { + KubernetesDeploymentParameters deploymentParameters = + new KubernetesDeploymentBuilderTest.TestingDeploymentParameters(); + Deployment deployment = + new KubernetesDeploymentBuilder().buildKubernetesResourceFrom(deploymentParameters); + Deployment cloneDeployment = RemoteShuffleApplicationController.cloneResource(deployment); + assertEquals(cloneDeployment, deployment); + } + + @Test + public void testCloneDaemonSet() throws IOException { + KubernetesDaemonSetParameters daemonSetParameters = + new KubernetesDaemonSetBuilderTest.TestingDaemonSetParameters(); + DaemonSet daemonSet = + new KubernetesDaemonSetBuilder().buildKubernetesResourceFrom(daemonSetParameters); + DaemonSet cloneDaemonSet = RemoteShuffleApplicationController.cloneResource(daemonSet); + assertEquals(cloneDaemonSet, daemonSet); + } + + /** Runner class for running a shuffle application controller.
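+ * + * <p>The runner starts all informers registered with the shared informer factory and then + * blocks in the controller loop; once the owning executor is shut down, the executor pool and + * all informers are stopped in the finally block.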
*/ +private static class ControllerRunner implements Runnable { + + public final RemoteShuffleApplicationController remoteShuffleApplicationController; + public final SharedInformerFactory informerFactory; + public final KubernetesClient kubeClient; + public final ExecutorService executorPool = + Executors.newFixedThreadPool(5, new ExecutorThreadFactory("informers")); + + public ControllerRunner(KubernetesClient kubeClient) { + this.kubeClient = kubeClient; + this.informerFactory = kubeClient.informers(executorPool); + this.remoteShuffleApplicationController = + RemoteShuffleApplicationController.createRemoteShuffleApplicationController( + kubeClient, informerFactory); + } + + public void registerShuffleManagerResourceEventHandler( + ResourceEventHandler<Deployment> eventHandler) { + remoteShuffleApplicationController + .getDeploymentInformer() + .addEventHandler(eventHandler); + } + + public void registerShuffleWorkersResourceEventHandler( + ResourceEventHandler<DaemonSet> eventHandler) { + remoteShuffleApplicationController.getDaemonSetInformer().addEventHandler(eventHandler); + } + + @Override + public void run() { + try { + informerFactory.startAllRegisteredInformers(); + informerFactory.addSharedInformerEventListener( + exception -> LOG.error("Exception occurred but was caught.", exception)); + remoteShuffleApplicationController.run(); + } catch (Throwable throwable) { + LOG.error("Shuffle application operator failed.", throwable); + } finally { + executorPool.shutdownNow(); + informerFactory.stopAllRegisteredInformers(); + } + } + + public RemoteShuffleApplicationController getRemoteShuffleApplicationController() { + return remoteShuffleApplicationController; + } + } +} diff --git a/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/crd/RemoteShuffleApplicationSpecTest.java b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/crd/RemoteShuffleApplicationSpecTest.java new file mode 100644 index 00000000..02c532f0 --- /dev/null +++ b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/crd/RemoteShuffleApplicationSpecTest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.operator.crd; + +import org.junit.Before; +import org.junit.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; + +/** Test for {@link RemoteShuffleApplicationSpec}.
*/ +public class RemoteShuffleApplicationSpecTest { + + private final Map<String, String> dynamicConfigs = new HashMap<>(); + + private final Map<String, String> fileConfigs = new HashMap<>(); + + @Before + public void setUp() { + dynamicConfigs.put("key1", "value1"); + dynamicConfigs.put("key2", "value2"); + dynamicConfigs.put("key3", "value3"); + + fileConfigs.put("file1", "This is a test for config file."); + fileConfigs.put("file2", "This is a test for config file."); + } + + @Test + public void testToString() { + RemoteShuffleApplicationSpec spec = + new RemoteShuffleApplicationSpec(dynamicConfigs, fileConfigs); + assertEquals( + spec.toString(), + "RemoteShuffleApplicationSpec(" + + "shuffleDynamicConfigs={key1=value1, key2=value2, key3=value3}, " + + "shuffleFileConfigs={file2=This is a test for config file., " + + "file1=This is a test for config file.})"); + } + + @Test + public void testEquals() { + RemoteShuffleApplicationSpec spec1 = + new RemoteShuffleApplicationSpec(dynamicConfigs, fileConfigs); + RemoteShuffleApplicationSpec spec2 = + new RemoteShuffleApplicationSpec(dynamicConfigs, fileConfigs); + RemoteShuffleApplicationSpec spec3 = + new RemoteShuffleApplicationSpec(Collections.emptyMap(), fileConfigs); + + assertEquals(spec1, spec2); + assertNotEquals(spec1, spec3); + } +} diff --git a/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/crd/RemoteShuffleApplicationStatusTest.java b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/crd/RemoteShuffleApplicationStatusTest.java new file mode 100644 index 00000000..f9432bd0 --- /dev/null +++ b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/crd/RemoteShuffleApplicationStatusTest.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.operator.crd; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; + +/** Test for {@link RemoteShuffleApplicationStatus}.
*/ +public class RemoteShuffleApplicationStatusTest { + + @Test + public void testToString() { + RemoteShuffleApplicationStatus status = new RemoteShuffleApplicationStatus(1, 2, 3, 4); + assertEquals( + status.toString(), + "RemoteShuffleApplicationStatus(readyShuffleManagers=1, readyShuffleWorkers=2, desiredShuffleManagers=3, desiredShuffleWorkers=4)"); + } + + @Test + public void testEquals() { + + RemoteShuffleApplicationStatus status1 = new RemoteShuffleApplicationStatus(1, 2, 3, 4); + RemoteShuffleApplicationStatus status2 = new RemoteShuffleApplicationStatus(1, 2, 3, 4); + RemoteShuffleApplicationStatus status3 = new RemoteShuffleApplicationStatus(0, 2, 3, 4); + RemoteShuffleApplicationStatus status4 = new RemoteShuffleApplicationStatus(1, 0, 3, 4); + RemoteShuffleApplicationStatus status5 = new RemoteShuffleApplicationStatus(1, 2, 0, 4); + RemoteShuffleApplicationStatus status6 = new RemoteShuffleApplicationStatus(1, 2, 3, 0); + + assertEquals(status1, status2); + assertNotEquals(status1, status3); + assertNotEquals(status1, status4); + assertNotEquals(status1, status5); + assertNotEquals(status1, status6); + } +} diff --git a/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/crd/RemoteShuffleApplicationTest.java b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/crd/RemoteShuffleApplicationTest.java new file mode 100644 index 00000000..a47900d5 --- /dev/null +++ b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/crd/RemoteShuffleApplicationTest.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.operator.crd; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** Test for {@link RemoteShuffleApplication}. 
*/ +public class RemoteShuffleApplicationTest { + + private final RemoteShuffleApplication shuffleApp = new RemoteShuffleApplication(); + + @Test + public void testVersion() { + assertEquals(shuffleApp.getVersion(), "v1"); + } + + @Test + public void testGroup() { + assertEquals(shuffleApp.getGroup(), "shuffleoperator.alibaba.com"); + } + + @Test + public void testSingular() { + assertEquals(shuffleApp.getSingular(), "remoteshuffle"); + } + + @Test + public void testPlural() { + assertEquals(shuffleApp.getPlural(), "remoteshuffles"); + } + + @Test + public void testCRDName() { + assertEquals(shuffleApp.getCRDName(), "remoteshuffles.shuffleoperator.alibaba.com"); + } + + @Test + public void testScope() { + assertEquals(shuffleApp.getScope(), "Namespaced"); + } + + @Test + public void testKind() { + assertEquals(shuffleApp.getKind(), "RemoteShuffle"); + } +} diff --git a/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/AbstractKubernetesParametersTest.java b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/AbstractKubernetesParametersTest.java new file mode 100644 index 00000000..79be44a7 --- /dev/null +++ b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/AbstractKubernetesParametersTest.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.operator.parameters; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.config.ClusterOptions; +import com.alibaba.flink.shuffle.core.config.KubernetesOptions; +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.util.ContainerCommandAndArgs; +import com.alibaba.flink.shuffle.kubernetes.operator.util.Constants; +import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesInternalOptions; +import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesTestBase; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.Callable; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +/** Test for {@link AbstractKubernetesParameters}. 
*/ +public class AbstractKubernetesParametersTest extends KubernetesTestBase { + + private final Configuration conf = new Configuration(); + private final TestingKubernetesParameters testingKubernetesParameters = + new TestingKubernetesParameters(conf); + + @Test + public void testGetNamespace() { + conf.setString(KubernetesInternalOptions.NAMESPACE, NAMESPACE); + assertThat(testingKubernetesParameters.getNamespace(), is(NAMESPACE)); + } + + @Test + public void testEnablePodHostNetwork() { + conf.setBoolean(KubernetesOptions.POD_HOST_NETWORK_ENABLED, true); + assertThat(testingKubernetesParameters.enablePodHostNetwork(), is(true)); + } + + @Test + public void testGetContainerImage() { + conf.setString(KubernetesOptions.CONTAINER_IMAGE, CONTAINER_IMAGE); + assertThat(testingKubernetesParameters.getContainerImage(), is(CONTAINER_IMAGE)); + } + + @Test + public void testGetContainerImagePullPolicy() { + conf.setString(KubernetesOptions.CONTAINER_IMAGE_PULL_POLICY, "Always"); + + assertThat(testingKubernetesParameters.getContainerImagePullPolicy(), is("Always")); + } + + @Test + public void testGetClusterId() { + conf.setString(ClusterOptions.REMOTE_SHUFFLE_CLUSTER_ID, CLUSTER_ID); + assertThat(testingKubernetesParameters.getClusterId(), is(CLUSTER_ID)); + } + + @Test + public void testClusterIdMustNotBeBlank() { + conf.setString(ClusterOptions.REMOTE_SHUFFLE_CLUSTER_ID, " "); + assertThrows( + "must not be blank", + IllegalArgumentException.class, + testingKubernetesParameters::getClusterId); + } + + @Test + public void testClusterIdLengthLimitation() { + final String stringWithIllegalLength = + CommonUtils.randomHexString(Constants.MAXIMUM_CHARACTERS_OF_CLUSTER_ID + 1); + conf.setString(ClusterOptions.REMOTE_SHUFFLE_CLUSTER_ID, stringWithIllegalLength); + assertThrows( + "must be no more than " + + Constants.MAXIMUM_CHARACTERS_OF_CLUSTER_ID + + " characters", + IllegalArgumentException.class, + testingKubernetesParameters::getClusterId); + } + + @Test + public void testInvalidContainerImage() { + conf.setString(KubernetesOptions.CONTAINER_IMAGE, " "); + assertThrows( + "Invalid " + KubernetesOptions.CONTAINER_IMAGE + ".", + IllegalArgumentException.class, + testingKubernetesParameters::getContainerImage); + } + + /** Checks whether an exception with a message occurs when running a piece of code. */ + public static void assertThrows( + String msg, Class<? extends Exception> expected, Callable<?> code) { + try { + Object result = code.call(); + Assert.fail("Previous method call should have failed but it returned: " + result); + } catch (Exception e) { + assertThat(e, instanceOf(expected)); + assertThat(e.getMessage(), containsString(msg)); + } + } + + /** Simple subclass of {@link AbstractKubernetesParameters} for testing purposes.
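+ * + * <p>Every accessor that the tests above do not exercise throws an + * {@link UnsupportedOperationException}, so any accidental use fails fast.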
*/ +public static class TestingKubernetesParameters extends AbstractKubernetesParameters { + + public TestingKubernetesParameters(Configuration conf) { + super(conf); + } + + @Override + public Map<String, String> getLabels() { + throw new UnsupportedOperationException("NOT supported"); + } + + @Override + public Map<String, String> getNodeSelector() { + throw new UnsupportedOperationException("NOT supported"); + } + + @Override + public List<Map<String, String>> getEmptyDirVolumes() { + throw new UnsupportedOperationException("NOT supported"); + } + + @Override + public List<Map<String, String>> getHostPathVolumes() { + throw new UnsupportedOperationException("NOT supported"); + } + + @Override + public List<Map<String, String>> getTolerations() { + throw new UnsupportedOperationException("NOT supported"); + } + + @Override + public String getContainerName() { + throw new UnsupportedOperationException("NOT supported"); + } + + @Override + public List<Map<String, String>> getContainerVolumeMounts() { + throw new UnsupportedOperationException("NOT supported"); + } + + @Override + public Integer getContainerMemoryMB() { + throw new UnsupportedOperationException("NOT supported"); + } + + @Override + public Double getContainerCPU() { + throw new UnsupportedOperationException("NOT supported"); + } + + @Override + public Map<String, String> getResourceLimitFactors() { + throw new UnsupportedOperationException("NOT supported"); + } + + @Override + public ContainerCommandAndArgs getContainerCommandAndArgs() { + throw new UnsupportedOperationException("NOT supported"); + } + + @Override + public Map<String, String> getEnvironmentVars() { + throw new UnsupportedOperationException("NOT supported"); + } + } +} diff --git a/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/K8sRemoteShuffleFileConfigsParametersTest.java b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/K8sRemoteShuffleFileConfigsParametersTest.java new file mode 100644 index 00000000..d46f3675 --- /dev/null +++ b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/K8sRemoteShuffleFileConfigsParametersTest.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.operator.parameters; + +import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesTestBase; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; + +/** Test for {@link K8sRemoteShuffleFileConfigsParameters}.
*/ +public class K8sRemoteShuffleFileConfigsParametersTest extends KubernetesTestBase { + + private K8sRemoteShuffleFileConfigsParameters shuffleFileConfigsParameters; + + @Before + public void setup() { + shuffleFileConfigsParameters = + new K8sRemoteShuffleFileConfigsParameters( + NAMESPACE, CLUSTER_ID, CONFIG_MAP_VOLUME.getItems()); + } + + @Test + public void testGetConfigMapName() { + assertThat(shuffleFileConfigsParameters.getConfigMapName(), is(CLUSTER_ID + "-configmap")); + } + + @Test + public void testGetData() { + Assert.assertEquals(shuffleFileConfigsParameters.getData(), CONFIG_MAP_VOLUME.getItems()); + } + + @Test + public void testGetNamespace() { + assertThat(shuffleFileConfigsParameters.getNamespace(), is(NAMESPACE)); + } + + @Test + public void testGetLabels() { + assertEquals(shuffleFileConfigsParameters.getLabels(), getCommonLabels()); + } +} diff --git a/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesShuffleManagerParametersTest.java b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesShuffleManagerParametersTest.java new file mode 100644 index 00000000..b10e6123 --- /dev/null +++ b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesShuffleManagerParametersTest.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.operator.parameters; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.config.MemorySize; +import com.alibaba.flink.shuffle.core.config.ClusterOptions; +import com.alibaba.flink.shuffle.core.config.KubernetesOptions; +import com.alibaba.flink.shuffle.core.config.ManagerOptions; +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.util.ContainerCommandAndArgs; +import com.alibaba.flink.shuffle.kubernetes.operator.util.Constants; +import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesInternalOptions; +import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesTestBase; + +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; + +/** Test for {@link KubernetesShuffleManagerParameters}. 
*/ +public class KubernetesShuffleManagerParametersTest extends KubernetesTestBase { + + private Configuration conf; + private KubernetesShuffleManagerParameters shuffleManagerParameters; + + @Before + public void setup() { + conf = new Configuration(); + + // cluster id and namespace + conf.setString(ClusterOptions.REMOTE_SHUFFLE_CLUSTER_ID, CLUSTER_ID); + conf.setString(KubernetesInternalOptions.NAMESPACE, NAMESPACE); + + // memory config + conf.setMemorySize( + ManagerOptions.FRAMEWORK_HEAP_MEMORY, + MemorySize.parse(CONTAINER_FRAMEWORK_HEAP_MEMORY_MB + "m")); + conf.setMemorySize( + ManagerOptions.FRAMEWORK_OFF_HEAP_MEMORY, + MemorySize.parse(CONTAINER_FRAMEWORK_OFF_HEAP_MEMORY_MB + "m")); + conf.setMemorySize( + ManagerOptions.JVM_METASPACE, MemorySize.parse(CONTAINER_JVM_METASPACE_MB + "m")); + conf.setMemorySize( + ManagerOptions.JVM_OVERHEAD, MemorySize.parse(CONTAINER_JVM_OVERHEAD_MB + "m")); + + shuffleManagerParameters = new KubernetesShuffleManagerParameters(conf); + } + + @Test + public void testGetLabels() { + + conf.setMap(KubernetesOptions.SHUFFLE_MANAGER_LABELS, USER_LABELS); + + final Map<String, String> expectedLabels = new HashMap<>(getCommonLabels()); + expectedLabels.put( + Constants.LABEL_COMPONENT_KEY, Constants.LABEL_COMPONENT_SHUFFLE_MANAGER); + expectedLabels.putAll(USER_LABELS); + assertEquals(shuffleManagerParameters.getLabels(), expectedLabels); + } + + @Test + public void testGetNodeSelector() { + conf.setMap(KubernetesOptions.SHUFFLE_MANAGER_NODE_SELECTOR, NODE_SELECTOR); + assertEquals(shuffleManagerParameters.getNodeSelector(), NODE_SELECTOR); + } + + @Test + public void testGetContainerVolumeMounts() { + conf.setList(KubernetesOptions.SHUFFLE_MANAGER_HOST_PATH_VOLUMES, HOST_PATH_VOLUMES); + assertEquals(shuffleManagerParameters.getContainerVolumeMounts(), CONTAINER_VOLUME_MOUNTS); + } + + @Test + public void testGetContainerMemoryMB() { + assertThat(shuffleManagerParameters.getContainerMemoryMB(), is(getTotalMemory(false))); + } + + @Test + public void testGetContainerCPU() { + conf.setDouble(KubernetesOptions.SHUFFLE_MANAGER_CPU, CONTAINER_CPU); + assertThat(shuffleManagerParameters.getContainerCPU(), is(CONTAINER_CPU)); + } + + @Test + public void testGetResourceLimitFactors() { + final Map<String, String> limitFactors = new HashMap<>(); + limitFactors.put("cpu", "3.2"); + limitFactors.put("memory", "1.6"); + + conf.setString("remote-shuffle.kubernetes.manager.limit-factor.cpu", "3.2"); + conf.setString("remote-shuffle.kubernetes.manager.limit-factor.memory", "1.6"); + + assertEquals(shuffleManagerParameters.getResourceLimitFactors(), limitFactors); + } + + @Test + public void testGetContainerCommandAndArgs() { + + conf.setString(ManagerOptions.JVM_OPTIONS, CONTAINER_JVM_OPTIONS); + + ContainerCommandAndArgs commandAndArgs = + shuffleManagerParameters.getContainerCommandAndArgs(); + + assertThat("bash", is(commandAndArgs.getCommand())); + assertEquals( + commandAndArgs.getArgs(), + Arrays.asList( + "-c", + "/flink-remote-shuffle/bin/kubernetes-shufflemanager.sh" + + " -D 'remote-shuffle.cluster.id=TestingCluster'" + + " -D 'remote-shuffle.manager.jvm-opts=-verbose:gc -XX:+PrintGCDetails -XX:+PrintGCDateStamps -XX:ParallelGCThreads=4'" + + " -D 'remote-shuffle.manager.memory.heap-size=256 mb'" + + " -D 'remote-shuffle.manager.memory.jvm-metaspace-size=32 mb'" + + " -D 'remote-shuffle.manager.memory.jvm-overhead-size=32 mb'" + + " -D 'remote-shuffle.manager.memory.off-heap-size=128 mb'")); + } + + @Test + public void testGetDeploymentName() { + assertThat(
shuffleManagerParameters.getDeploymentName(), is(CLUSTER_ID + "-shufflemanager")); + } + + @Test + public void testGetReplicas() { + assertThat(shuffleManagerParameters.getReplicas(), is(1)); + } +} diff --git a/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesShuffleWorkerParametersTest.java b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesShuffleWorkerParametersTest.java new file mode 100644 index 00000000..18a6d03d --- /dev/null +++ b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/parameters/KubernetesShuffleWorkerParametersTest.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.operator.parameters; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.config.MemorySize; +import com.alibaba.flink.shuffle.core.config.ClusterOptions; +import com.alibaba.flink.shuffle.core.config.KubernetesOptions; +import com.alibaba.flink.shuffle.core.config.MemoryOptions; +import com.alibaba.flink.shuffle.core.config.WorkerOptions; +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.util.ContainerCommandAndArgs; +import com.alibaba.flink.shuffle.kubernetes.operator.util.Constants; +import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesInternalOptions; +import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesTestBase; +import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesUtils; + +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; + +/** Test for {@link KubernetesShuffleWorkerParameters}. 
*/ +public class KubernetesShuffleWorkerParametersTest extends KubernetesTestBase { + + private Configuration conf; + private KubernetesShuffleWorkerParameters shuffleWorkerParameters; + + @Before + public void setup() { + conf = new Configuration(); + + // cluster id and namespace + conf.setString(ClusterOptions.REMOTE_SHUFFLE_CLUSTER_ID, CLUSTER_ID); + conf.setString(KubernetesInternalOptions.NAMESPACE, NAMESPACE); + + // memory config + conf.setMemorySize( + WorkerOptions.FRAMEWORK_HEAP_MEMORY, + MemorySize.parse(CONTAINER_FRAMEWORK_HEAP_MEMORY_MB + "m")); + conf.setMemorySize( + WorkerOptions.FRAMEWORK_OFF_HEAP_MEMORY, + MemorySize.parse(CONTAINER_FRAMEWORK_OFF_HEAP_MEMORY_MB + "m")); + conf.setMemorySize(MemoryOptions.MEMORY_BUFFER_SIZE, NETWORK_MEMORY_BUFFER_SIZE); + conf.setMemorySize(MemoryOptions.MEMORY_SIZE_FOR_DATA_READING, NETWORK_READING_MEMORY_SIZE); + conf.setMemorySize(MemoryOptions.MEMORY_SIZE_FOR_DATA_WRITING, NETWORK_WRITING_MEMORY_SIZE); + conf.setMemorySize( + WorkerOptions.JVM_METASPACE, MemorySize.parse(CONTAINER_JVM_METASPACE_MB + "m")); + conf.setMemorySize( + WorkerOptions.JVM_OVERHEAD, MemorySize.parse(CONTAINER_JVM_OVERHEAD_MB + "m")); + + shuffleWorkerParameters = new KubernetesShuffleWorkerParameters(conf); + } + + @Test + public void testGetLabels() { + conf.setMap(KubernetesOptions.SHUFFLE_WORKER_LABELS, USER_LABELS); + + final Map<String, String> expectedLabels = new HashMap<>(getCommonLabels()); + expectedLabels.put(Constants.LABEL_COMPONENT_KEY, Constants.LABEL_COMPONENT_SHUFFLE_WORKER); + expectedLabels.putAll(USER_LABELS); + assertEquals(shuffleWorkerParameters.getLabels(), expectedLabels); + } + + @Test + public void testGetNodeSelector() { + conf.setMap(KubernetesOptions.SHUFFLE_WORKER_NODE_SELECTOR, NODE_SELECTOR); + assertEquals(shuffleWorkerParameters.getNodeSelector(), NODE_SELECTOR); + } + + @Test + public void testGetContainerName() { + MatcherAssert.assertThat( + shuffleWorkerParameters.getContainerName(), + CoreMatchers.is(KubernetesUtils.SHUFFLE_WORKER_CONTAINER_NAME)); + } + + @Test + public void testGetContainerMemoryMB() { + assertThat(shuffleWorkerParameters.getContainerMemoryMB(), is(getTotalMemory(true))); + } + + @Test + public void testGetContainerCPU() { + conf.setDouble(KubernetesOptions.SHUFFLE_WORKER_CPU, CONTAINER_CPU); + assertThat(shuffleWorkerParameters.getContainerCPU(), is(CONTAINER_CPU)); + } + + @Test + public void testGetResourceLimitFactors() { + final Map<String, String> limitFactors = new HashMap<>(); + limitFactors.put("cpu", "3.0"); + limitFactors.put("memory", "1.5"); + + conf.setString("remote-shuffle.kubernetes.worker.limit-factor.cpu", "3.0"); + conf.setString("remote-shuffle.kubernetes.worker.limit-factor.memory", "1.5"); + + assertEquals(shuffleWorkerParameters.getResourceLimitFactors(), limitFactors); + } + + @Test + public void testGetContainerVolumeMounts() { + conf.setList(KubernetesOptions.SHUFFLE_WORKER_HOST_PATH_VOLUMES, HOST_PATH_VOLUMES); + assertEquals(shuffleWorkerParameters.getContainerVolumeMounts(), CONTAINER_VOLUME_MOUNTS); + } + + @Test + public void testGetContainerCommandAndArgs() { + conf.setString(WorkerOptions.JVM_OPTIONS, CONTAINER_JVM_OPTIONS); + + ContainerCommandAndArgs commandAndArgs = + shuffleWorkerParameters.getContainerCommandAndArgs(); + + assertThat("bash", is(commandAndArgs.getCommand())); + assertEquals( + commandAndArgs.getArgs(), + Arrays.asList( + "-c", + "/flink-remote-shuffle/bin/kubernetes-shuffleworker.sh" + + " -D 'remote-shuffle.cluster.id=TestingCluster'" + + " -D
'remote-shuffle.memory.buffer-size=1 mb'" + + " -D 'remote-shuffle.memory.data-reading-size=32 mb'" + + " -D 'remote-shuffle.memory.data-writing-size=32 mb'" + + " -D 'remote-shuffle.worker.jvm-opts=-verbose:gc -XX:+PrintGCDetails -XX:+PrintGCDateStamps -XX:ParallelGCThreads=4'" + + " -D 'remote-shuffle.worker.memory.heap-size=256 mb'" + + " -D 'remote-shuffle.worker.memory.jvm-metaspace-size=32 mb'" + + " -D 'remote-shuffle.worker.memory.jvm-overhead-size=32 mb'" + + " -D 'remote-shuffle.worker.memory.off-heap-size=128 mb'")); + } + + @Test + public void testGetDaemonSetName() { + assertThat(shuffleWorkerParameters.getDaemonSetName(), is(CLUSTER_ID + "-shuffleworker")); + } +} diff --git a/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/reconciler/RemoteShuffleApplicationReconcilerTest.java b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/reconciler/RemoteShuffleApplicationReconcilerTest.java new file mode 100644 index 00000000..bfbcc471 --- /dev/null +++ b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/reconciler/RemoteShuffleApplicationReconcilerTest.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.operator.reconciler; + +import com.alibaba.flink.shuffle.core.config.KubernetesOptions; +import com.alibaba.flink.shuffle.kubernetes.operator.crd.RemoteShuffleApplication; +import com.alibaba.flink.shuffle.kubernetes.operator.crd.RemoteShuffleApplicationSpec; +import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesUtils; + +import io.fabric8.kubernetes.api.model.ConfigMap; +import io.fabric8.kubernetes.api.model.ConfigMapBuilder; +import io.fabric8.kubernetes.api.model.apps.DaemonSet; +import io.fabric8.kubernetes.api.model.apps.DaemonSetBuilder; +import io.fabric8.kubernetes.api.model.apps.Deployment; +import io.fabric8.kubernetes.api.model.apps.DeploymentBuilder; +import io.fabric8.kubernetes.client.KubernetesClient; +import io.fabric8.kubernetes.client.server.mock.KubernetesServer; +import io.fabric8.kubernetes.client.utils.Serialization; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.Assert; +import org.junit.Test; + +import java.net.HttpURLConnection; +import java.net.InetAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** Test for {@link RemoteShuffleApplicationReconciler}. 
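+ * + * <p>The test runs against a fabric8 mock {@code KubernetesServer} and asserts on the raw POST + * requests that the reconciler issues for the ConfigMap, the ShuffleManager Deployment and the + * ShuffleWorker DaemonSet.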
*/ +public class RemoteShuffleApplicationReconcilerTest { + + Map<String, String> dynamicConfigs = + new HashMap<String, String>() { + { + put(KubernetesOptions.CONTAINER_IMAGE.key(), "flink-remote-shuffle:latest"); + } + }; + + Map<String, String> fileConfigs = + new HashMap<String, String>() { + { + put("file1", "This is a test for config file."); + } + }; + + @Test + public void testReconcile() throws Exception { + KubernetesServer server = + new KubernetesServer( + false, + false, + InetAddress.getByName("127.0.0.1"), + 0, + Collections.emptyList()); + server.before(); + + String namespace = "default"; + RemoteShuffleApplication testShuffleApp = + getRemoteShuffleApplication( + "testShuffleApp", namespace, "0800cff3-9d80-11ea-8973-0e13a02d8ebd"); + + // setup for ConfigMap. + server.expect() + .post() + .withPath("/api/v1/namespaces/" + namespace + "/configmaps") + .andReturn( + HttpURLConnection.HTTP_CREATED, + new ConfigMapBuilder().withNewMetadata().endMetadata().build()) + .times(1); + + // setup for ShuffleManager. + server.expect() + .post() + .withPath("/apis/apps/v1/namespaces/" + namespace + "/deployments") + .andReturn( + HttpURLConnection.HTTP_CREATED, + new DeploymentBuilder().withNewMetadata().endMetadata().build()) + .times(1); + + // setup for ShuffleWorkers. + server.expect() + .post() + .withPath("/apis/apps/v1/namespaces/" + namespace + "/daemonsets") + .andReturn( + HttpURLConnection.HTTP_CREATED, + new DaemonSetBuilder().withNewMetadata().endMetadata().build()) + .times(1); + + KubernetesClient client = server.getClient(); + RemoteShuffleApplicationReconciler reconciler = + new RemoteShuffleApplicationReconciler(client); + + // trigger reconcile + reconciler.reconcile(testShuffleApp); + + // check the deploy-ConfigMap request that the server has received. + RecordedRequest configMapRequest = server.getMockServer().takeRequest(); + assertEquals("POST", configMapRequest.getMethod()); + + String configMapRequestBody = configMapRequest.getBody().readUtf8(); + ConfigMap configMapInRequest = + Serialization.jsonMapper().readValue(configMapRequestBody, ConfigMap.class); + assertNotNull(configMapInRequest); + assertEquals( + configMapInRequest.getMetadata().getName(), + testShuffleApp.getMetadata().getName() + "-configmap"); + assertEquals(1, configMapInRequest.getMetadata().getOwnerReferences().size()); + assertEquals( + testShuffleApp.getMetadata().getName(), + configMapInRequest.getMetadata().getOwnerReferences().get(0).getName()); + + // check the deploy-ShuffleManager request that the server has received. + RecordedRequest shuffleManagerRequest = server.getMockServer().takeRequest(); + assertEquals("POST", shuffleManagerRequest.getMethod()); + + String shuffleManagerRequestBody = shuffleManagerRequest.getBody().readUtf8(); + Deployment deploymentInRequest = + Serialization.jsonMapper().readValue(shuffleManagerRequestBody, Deployment.class); + assertNotNull(deploymentInRequest); + Assert.assertEquals( + KubernetesUtils.getShuffleManagerNameWithClusterId( + testShuffleApp.getMetadata().getName()), + deploymentInRequest.getMetadata().getName()); + assertEquals(1, deploymentInRequest.getMetadata().getOwnerReferences().size()); + assertEquals( + testShuffleApp.getMetadata().getName(), + deploymentInRequest.getMetadata().getOwnerReferences().get(0).getName()); + + // check the deploy-ShuffleWorkers request that the server has received.
+ RecordedRequest shuffleWorkersRequest = server.getMockServer().takeRequest(); + assertEquals("POST", shuffleWorkersRequest.getMethod()); + + String shuffleWorkersRequestBody = shuffleWorkersRequest.getBody().readUtf8(); + DaemonSet daemonSetInRequest = + Serialization.jsonMapper().readValue(shuffleWorkersRequestBody, DaemonSet.class); + assertNotNull(daemonSetInRequest); + assertEquals( + KubernetesUtils.getShuffleWorkersNameWithClusterId( + testShuffleApp.getMetadata().getName()), + daemonSetInRequest.getMetadata().getName()); + assertEquals(1, daemonSetInRequest.getMetadata().getOwnerReferences().size()); + assertEquals( + testShuffleApp.getMetadata().getName(), + daemonSetInRequest.getMetadata().getOwnerReferences().get(0).getName()); + } + + private RemoteShuffleApplication getRemoteShuffleApplication( + String name, String namespace, String uid) { + + RemoteShuffleApplicationSpec shuffleAppSpec = new RemoteShuffleApplicationSpec(); + shuffleAppSpec.setShuffleDynamicConfigs(dynamicConfigs); + shuffleAppSpec.setShuffleFileConfigs(fileConfigs); + + RemoteShuffleApplication shuffleApp = new RemoteShuffleApplication(); + shuffleApp.getMetadata().setName(name); + shuffleApp.getMetadata().setNamespace(namespace); + shuffleApp.getMetadata().setUid(uid); + shuffleApp.setSpec(shuffleAppSpec); + + return shuffleApp; + } +} diff --git a/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesConfigMapBuilderTest.java b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesConfigMapBuilderTest.java new file mode 100644 index 00000000..83bd8be8 --- /dev/null +++ b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesConfigMapBuilderTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.operator.resources; + +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.KubernetesConfigMapParameters; +import com.alibaba.flink.shuffle.kubernetes.operator.util.Constants; +import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesTestBase; + +import io.fabric8.kubernetes.api.model.ConfigMap; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.Map; + +import static org.junit.Assert.assertEquals; + +/** Test for {@link KubernetesConfigMapBuilder}. 
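+ * + * <p>A minimal parameter fixture ({@code TestingConfigMapParameters}, defined below) feeds the + * builder, and each metadata field of the resulting ConfigMap is asserted separately.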
*/ +public class KubernetesConfigMapBuilderTest extends KubernetesTestBase { + + private KubernetesConfigMapParameters configMapParameters; + private ConfigMap resultConfigMap; + + @Before + public void setup() { + configMapParameters = new TestingConfigMapParameters(); + resultConfigMap = + new KubernetesConfigMapBuilder().buildKubernetesResourceFrom(configMapParameters); + } + + @Test + public void testApiVersion() { + Assert.assertEquals(Constants.API_VERSION, resultConfigMap.getApiVersion()); + } + + @Test + public void testConfigMapName() { + assertEquals(CONFIG_MAP_NAME, resultConfigMap.getMetadata().getName()); + } + + @Test + public void testNameSpace() { + assertEquals(NAMESPACE, resultConfigMap.getMetadata().getNamespace()); + } + + @Test + public void testLabels() { + assertEquals(USER_LABELS, resultConfigMap.getMetadata().getLabels()); + } + + /** Simple {@link KubernetesConfigMapParameters} implementation for testing purposes. */ + public static class TestingConfigMapParameters implements KubernetesConfigMapParameters { + + @Override + public String getConfigMapName() { + return CONFIG_MAP_NAME; + } + + @Override + public Map<String, String> getData() { + return CONFIG_MAP_VOLUME.getItems(); + } + + @Override + public String getNamespace() { + return NAMESPACE; + } + + @Override + public Map<String, String> getLabels() { + return USER_LABELS; + } + } +} diff --git a/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesContainerBuilderTest.java b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesContainerBuilderTest.java new file mode 100644 index 00000000..8c9eabc8 --- /dev/null +++ b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesContainerBuilderTest.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.operator.resources; + +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.KubernetesContainerParameters; +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.util.ContainerCommandAndArgs; +import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesTestBase; + +import io.fabric8.kubernetes.api.model.Container; +import io.fabric8.kubernetes.api.model.Quantity; +import io.fabric8.kubernetes.api.model.ResourceRequirements; +import io.fabric8.kubernetes.api.model.VolumeMount; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; + +/** Test for {@link KubernetesContainerBuilder}.
*/ +public class KubernetesContainerBuilderTest extends KubernetesTestBase { + + private Container resultContainer; + + @Before + public void setup() { + resultContainer = + new KubernetesContainerBuilder() + .buildKubernetesResourceFrom(new TestingContainerParameters()); + } + + @Test + public void testContainerName() { + assertEquals(CONTAINER_NAME, resultContainer.getName()); + } + + @Test + public void testContainerImage() { + assertEquals(CONTAINER_IMAGE, resultContainer.getImage()); + assertEquals("Always", resultContainer.getImagePullPolicy()); + } + + @Test + public void testMainContainerResourceRequirements() { + final ResourceRequirements resourceRequirements = resultContainer.getResources(); + + final Map<String, Quantity> requests = resourceRequirements.getRequests(); + final Map<String, Quantity> limits = resourceRequirements.getLimits(); + + assertEquals(Double.toString(CONTAINER_CPU), requests.get("cpu").getAmount()); + assertEquals( + Double.toString( + CONTAINER_CPU * Double.parseDouble(RESOURCE_LIMIT_FACTOR.get("cpu"))), + limits.get("cpu").getAmount()); + assertEquals(String.valueOf(getTotalMemory(false)), requests.get("memory").getAmount()); + assertEquals( + String.valueOf( + getTotalMemory(false) + * Double.parseDouble(RESOURCE_LIMIT_FACTOR.get("memory"))), + limits.get("memory").getAmount()); + } + + @Test + public void testContainerCommandAndArgs() { + assertEquals(Collections.singletonList("exec"), resultContainer.getCommand()); + assertEquals(Arrays.asList("bash", "-c", "sleep"), resultContainer.getArgs()); + } + + @Test + public void testContainerVolumeMounts() { + List<Map<String, String>> resultVolumeMounts = new ArrayList<>(); + for (VolumeMount volumeMount : resultContainer.getVolumeMounts()) { + Map<String, String> resultVolumeMount = new HashMap<>(); + resultVolumeMount.put("name", volumeMount.getName()); + resultVolumeMount.put("mountPath", volumeMount.getMountPath()); + resultVolumeMounts.add(resultVolumeMount); + } + + assertEquals(CONTAINER_VOLUME_MOUNTS, resultVolumeMounts); + } + + /** Simple {@link KubernetesContainerParameters} implementation for testing purposes.
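+ * + * <p>Most values are the shared {@link KubernetesTestBase} constants, so the assertions above + * compare the built container against the exact inputs used here.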
*/ + public static class TestingContainerParameters implements KubernetesContainerParameters { + + @Override + public String getContainerName() { + return CONTAINER_NAME; + } + + @Override + public String getContainerImage() { + return CONTAINER_IMAGE; + } + + @Override + public String getContainerImagePullPolicy() { + return "Always"; + } + + @Override + public List> getContainerVolumeMounts() { + return CONTAINER_VOLUME_MOUNTS; + } + + @Override + public Integer getContainerMemoryMB() { + return getTotalMemory(false); + } + + @Override + public Double getContainerCPU() { + return CONTAINER_CPU; + } + + @Override + public Map getResourceLimitFactors() { + return RESOURCE_LIMIT_FACTOR; + } + + @Override + public ContainerCommandAndArgs getContainerCommandAndArgs() { + return new ContainerCommandAndArgs("exec", Arrays.asList("bash", "-c", "sleep")); + } + + @Override + public Map getEnvironmentVars() { + return Collections.emptyMap(); + } + } +} diff --git a/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesDaemonSetBuilderTest.java b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesDaemonSetBuilderTest.java new file mode 100644 index 00000000..3a63fc7e --- /dev/null +++ b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesDaemonSetBuilderTest.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.operator.resources; + +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.KubernetesDaemonSetParameters; +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.KubernetesPodParameters; +import com.alibaba.flink.shuffle.kubernetes.operator.util.Constants; +import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesTestBase; + +import io.fabric8.kubernetes.api.model.apps.DaemonSet; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.Map; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertEquals; + +/** Test for {@link KubernetesDaemonSetBuilder}. 
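[Editor's note] TestingContainerParameters above returns ContainerCommandAndArgs("exec", Arrays.asList("bash", "-c", "sleep")), and testContainerCommandAndArgs asserts that the single command string becomes the container's command (the image entrypoint override) while the list becomes its args. The same split expressed directly against the fabric8 model, as an illustration:

```java
import io.fabric8.kubernetes.api.model.Container;
import io.fabric8.kubernetes.api.model.ContainerBuilder;

public class CommandArgsSketch {
    public static void main(String[] args) {
        // Mirrors the assertions: command == ["exec"], args == ["bash", "-c", "sleep"].
        Container container = new ContainerBuilder()
                .withName("TestingContainer")
                .withImage("flink-remote-shuffle-k8s-test:latest")
                .withCommand("exec")
                .withArgs("bash", "-c", "sleep")
                .build();
        System.out.println(container.getCommand()); // [exec]
        System.out.println(container.getArgs());    // [bash, -c, sleep]
    }
}
```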
*/ +public class KubernetesDaemonSetBuilderTest extends KubernetesTestBase { + + private KubernetesDaemonSetParameters daemonSetParameters; + private DaemonSet resultDaemonSet; + + @Before + public void setup() { + daemonSetParameters = new TestingDaemonSetParameters(); + resultDaemonSet = + new KubernetesDaemonSetBuilder().buildKubernetesResourceFrom(daemonSetParameters); + } + + @Test + public void testApiVersion() { + Assert.assertEquals(Constants.APPS_API_VERSION, resultDaemonSet.getApiVersion()); + } + + @Test + public void testDaemonSetName() { + assertThat(DAEMON_SET_NAME, is(resultDaemonSet.getMetadata().getName())); + } + + @Test + public void testNameSpace() { + assertEquals(NAMESPACE, resultDaemonSet.getMetadata().getNamespace()); + } + + @Test + public void testLabels() { + assertEquals(USER_LABELS, resultDaemonSet.getMetadata().getLabels()); + } + + /** Simple {@link KubernetesDaemonSetParameters} implementation for testing purposes. */ + public static class TestingDaemonSetParameters implements KubernetesDaemonSetParameters { + + @Override + public String getDaemonSetName() { + return DAEMON_SET_NAME; + } + + @Override + public KubernetesPodParameters getPodTemplateParameters() { + return new KubernetesPodBuilderTest.TestingPodParameters(); + } + + @Override + public String getNamespace() { + return NAMESPACE; + } + + @Override + public Map getLabels() { + return USER_LABELS; + } + } +} diff --git a/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesDeploymentBuilderTest.java b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesDeploymentBuilderTest.java new file mode 100644 index 00000000..dd4f86cc --- /dev/null +++ b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesDeploymentBuilderTest.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.operator.resources; + +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.KubernetesDeploymentParameters; +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.KubernetesPodParameters; +import com.alibaba.flink.shuffle.kubernetes.operator.util.Constants; +import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesTestBase; + +import io.fabric8.kubernetes.api.model.apps.Deployment; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.Map; + +import static org.junit.Assert.assertEquals; + +/** Test for {@link KubernetesDeploymentBuilder}. 
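[Editor's note] Every builder in this test suite is driven through the same entry point, buildKubernetesResourceFrom(parameters). The builder-side interface itself is not shown in this diff; the call sites imply a contract roughly like the following hypothetical sketch:

```java
/**
 * Hypothetical contract implied by the tests; the repo's actual declaration
 * (name and generics) is not visible in this patch.
 *
 * @param <R> the resource produced, e.g. ConfigMap, Container, DaemonSet, Deployment, Pod
 * @param <P> the parameters object consumed, e.g. KubernetesDaemonSetParameters
 */
interface KubernetesResourceBuilderSketch<R, P> {

    /** Turns one typed parameters object into one fully populated resource. */
    R buildKubernetesResourceFrom(P parameters);
}
```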
*/ +public class KubernetesDeploymentBuilderTest extends KubernetesTestBase { + + private KubernetesDeploymentParameters deploymentParameters; + private Deployment resultDeployment; + + @Before + public void setup() { + deploymentParameters = new TestingDeploymentParameters(); + resultDeployment = + new KubernetesDeploymentBuilder().buildKubernetesResourceFrom(deploymentParameters); + } + + @Test + public void testApiVersion() { + Assert.assertEquals(Constants.APPS_API_VERSION, resultDeployment.getApiVersion()); + } + + @Test + public void testDeploymentName() { + assertEquals(DEPLOYMENT_NAME, resultDeployment.getMetadata().getName()); + } + + @Test + public void testNameSpace() { + assertEquals(NAMESPACE, resultDeployment.getMetadata().getNamespace()); + } + + @Test + public void testLabels() { + assertEquals(USER_LABELS, resultDeployment.getMetadata().getLabels()); + } + + @Test + public void testReplicas() { + assertEquals(Integer.valueOf(1), resultDeployment.getSpec().getReplicas()); + } + + @Test + public void testSelector() { + assertEquals( + USER_LABELS, resultDeployment.getSpec().getTemplate().getMetadata().getLabels()); + } + + /** Simple {@link KubernetesDeploymentParameters} implementation for testing purposes. */ + public static class TestingDeploymentParameters implements KubernetesDeploymentParameters { + + @Override + public String getDeploymentName() { + return DEPLOYMENT_NAME; + } + + @Override + public int getReplicas() { + return 1; + } + + @Override + public KubernetesPodParameters getPodTemplateParameters() { + return new KubernetesPodBuilderTest.TestingPodParameters(); + } + + @Override + public String getNamespace() { + return NAMESPACE; + } + + @Override + public Map getLabels() { + return USER_LABELS; + } + } +} diff --git a/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesPodBuilderTest.java b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesPodBuilderTest.java new file mode 100644 index 00000000..49ea9725 --- /dev/null +++ b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/resources/KubernetesPodBuilderTest.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.kubernetes.operator.resources; + +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.KubernetesContainerParameters; +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.KubernetesPodParameters; +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.util.ConfigMapVolume; +import com.alibaba.flink.shuffle.kubernetes.operator.util.Constants; +import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesTestBase; +import com.alibaba.flink.shuffle.kubernetes.operator.util.KubernetesUtils; + +import io.fabric8.kubernetes.api.model.Pod; +import io.fabric8.kubernetes.api.model.Toleration; +import io.fabric8.kubernetes.api.model.Volume; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; +import static org.junit.Assert.assertEquals; + +/** Test for {@link KubernetesPodBuilder}. */ +public class KubernetesPodBuilderTest extends KubernetesTestBase { + + private KubernetesPodParameters podParameters; + private Pod resultPod; + + @Before + public void setup() { + podParameters = new TestingPodParameters(); + resultPod = new KubernetesPodBuilder().buildKubernetesResourceFrom(podParameters); + } + + @Test + public void testApiVersion() { + Assert.assertEquals(Constants.API_VERSION, this.resultPod.getApiVersion()); + } + + @Test + public void testLabels() { + assertEquals(USER_LABELS, resultPod.getMetadata().getLabels()); + } + + @Test + public void testHostNetwork() { + assertEquals(false, resultPod.getSpec().getHostNetwork()); + } + + @Test + public void testNodeSelector() { + assertThat(this.resultPod.getSpec().getNodeSelector(), is(equalTo(NODE_SELECTOR))); + } + + @Test + public void testVolumes() { + List volumes = this.resultPod.getSpec().getVolumes(); + assertThat(volumes.size(), is(EMPTY_DIR_VOLUMES.size() + HOST_PATH_VOLUMES.size() + 1)); + List emptyDirVolumes = Arrays.asList(volumes.get(0), volumes.get(1)); + List hostPathVolumes = Arrays.asList(volumes.get(2), volumes.get(3)); + Volume configMapVolume = volumes.get(4); + + // emptyDir + List> emptyDirParams = new ArrayList<>(); + for (Volume volume : emptyDirVolumes) { + assertThat(volume.getEmptyDir(), not(nullValue())); + + Map emptyDirParam = new HashMap<>(); + emptyDirParam.put("name", volume.getName()); + emptyDirParam.put("sizeLimit", volume.getEmptyDir().getSizeLimit().toString()); + emptyDirParams.add(emptyDirParam); + } + + assertEquals( + EMPTY_DIR_VOLUMES.stream() + .map(KubernetesUtils::filterEmptyDirVolumeConfigs) + .collect(Collectors.toList()), + emptyDirParams); + + // hostPath + List> hostPathParams = new ArrayList<>(); + for (Volume volume : hostPathVolumes) { + assertThat(volume.getHostPath(), not(nullValue())); + + Map hostPathParam = new HashMap<>(); + hostPathParam.put("name", volume.getName()); + hostPathParam.put("path", volume.getHostPath().getPath()); + hostPathParams.add(hostPathParam); + } + assertEquals( + HOST_PATH_VOLUMES.stream() + .map(KubernetesUtils::filterHostPathVolumeConfigs) + .collect(Collectors.toList()), + hostPathParams); + + // configmap + 
assertThat(configMapVolume.getConfigMap(), not(nullValue())); + assertEquals(configMapVolume.getName(), CONFIG_MAP_VOLUME.getVolumeName()); + assertEquals( + configMapVolume.getConfigMap().getName(), CONFIG_MAP_VOLUME.getConfigMapName()); + assertEquals( + configMapVolume.getConfigMap().getItems().stream() + .collect(Collectors.toMap(item -> item.getKey(), item -> item.getPath())), + CONFIG_MAP_VOLUME.getItems()); + } + + @Test + public void testTolerations() { + List> tolerationParams = new ArrayList<>(); + for (Toleration toleration : resultPod.getSpec().getTolerations()) { + Map tolerationParam = new HashMap<>(); + tolerationParam.put("effect", toleration.getEffect()); + tolerationParam.put("key", toleration.getKey()); + tolerationParam.put("operator", toleration.getOperator()); + tolerationParam.put("value", toleration.getValue()); + + tolerationParams.add(tolerationParam); + } + assertEquals(tolerationParams, TOLERATIONS); + } + + /** Simple {@link KubernetesPodParameters} implementation for testing purposes. */ + public static class TestingPodParameters implements KubernetesPodParameters { + + @Override + public Map getNodeSelector() { + return NODE_SELECTOR; + } + + @Override + public boolean enablePodHostNetwork() { + return false; + } + + @Override + public List> getEmptyDirVolumes() { + return EMPTY_DIR_VOLUMES; + } + + @Override + public List> getHostPathVolumes() { + return HOST_PATH_VOLUMES; + } + + @Override + public List getConfigMapVolumes() { + return Collections.singletonList(CONFIG_MAP_VOLUME); + } + + @Override + public List> getTolerations() { + return TOLERATIONS; + } + + @Override + public KubernetesContainerParameters getContainerParameters() { + return new KubernetesContainerBuilderTest.TestingContainerParameters(); + } + + @Override + public String getNamespace() { + return NAMESPACE; + } + + @Override + public Map getLabels() { + return USER_LABELS; + } + } +} diff --git a/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/util/KubernetesTestBase.java b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/util/KubernetesTestBase.java new file mode 100644 index 00000000..9a76a992 --- /dev/null +++ b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/util/KubernetesTestBase.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.operator.util; + +import com.alibaba.flink.shuffle.common.config.MemorySize; +import com.alibaba.flink.shuffle.kubernetes.operator.parameters.util.ConfigMapVolume; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** Base class for kubernetes unit test. 
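[Editor's note] testVolumes above checks three volume flavors in a fixed order: emptyDir with a size limit, hostPath, and a configMap volume with key-to-path items, all fed from fixtures in the KubernetesTestBase that follows. A sketch of how such volumes are typically assembled with fabric8's VolumeBuilder, using the fixture values; this is not the repo's builder code:

```java
import io.fabric8.kubernetes.api.model.Quantity;
import io.fabric8.kubernetes.api.model.Volume;
import io.fabric8.kubernetes.api.model.VolumeBuilder;

public class VolumeSketch {
    public static void main(String[] args) {
        // emptyDir volume with a size limit, as asserted for "disk1"/"5Gi"
        Volume emptyDir = new VolumeBuilder()
                .withName("disk1")
                .withNewEmptyDir().withSizeLimit(new Quantity("5Gi")).endEmptyDir()
                .build();

        // hostPath volume, as asserted for "/dump/1"
        Volume hostPath = new VolumeBuilder()
                .withName("disk1")
                .withNewHostPath().withPath("/dump/1").endHostPath()
                .build();

        // configMap volume with key-to-path items, as asserted against CONFIG_MAP_VOLUME
        Volume configMap = new VolumeBuilder()
                .withName("disk1")
                .withNewConfigMap()
                .withName("config-map")
                .addNewItem().withKey("log4j2.xml").withPath("log4j2.xml").endItem()
                .endConfigMap()
                .build();

        System.out.println(emptyDir.getEmptyDir().getSizeLimit()); // 5Gi
    }
}
```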
*/ +public class KubernetesTestBase { + + // common config. + protected static final String CLUSTER_ID = "TestingCluster"; + protected static final String NAMESPACE = "default"; + protected static final String SERVICE_ACCOUNT = "default"; + + // container image + protected static final String CONTAINER_NAME = "TestingContainer"; + protected static final String CONTAINER_IMAGE = "flink-remote-shuffle-k8s-test:latest"; + + // container cpu and memory + protected static final double CONTAINER_CPU = 1.5; + protected static final int CONTAINER_FRAMEWORK_HEAP_MEMORY_MB = 256; + protected static final int CONTAINER_FRAMEWORK_OFF_HEAP_MEMORY_MB = 128; + protected static final MemorySize NETWORK_MEMORY_BUFFER_SIZE = MemorySize.parse("1m"); + protected static final MemorySize NETWORK_WRITING_MEMORY_SIZE = MemorySize.parse("32m"); + protected static final MemorySize NETWORK_READING_MEMORY_SIZE = MemorySize.parse("32m"); + + protected static final int CONTAINER_JVM_METASPACE_MB = 32; + protected static final int CONTAINER_JVM_OVERHEAD_MB = 32; + + // container command + protected static final String CONTAINER_JVM_OPTIONS = + "-verbose:gc -XX:+PrintGCDetails -XX:+PrintGCDateStamps -XX:ParallelGCThreads=4"; + + // deployment, daemonset, configmap name + protected static final String DEPLOYMENT_NAME = "TestingDeployment"; + protected static final String DAEMON_SET_NAME = "TestingDaemonSet"; + protected static final String CONFIG_MAP_NAME = "TestingConfigMap"; + + protected static final List> CONTAINER_VOLUME_MOUNTS = + new ArrayList>() { + { + add( + new HashMap() { + { + put(Constants.VOLUME_NAME, "disk1"); + put(Constants.VOLUME_MOUNT_PATH, "/opt/disk1"); + } + }); + add( + new HashMap() { + { + put(Constants.VOLUME_NAME, "disk2"); + put(Constants.VOLUME_MOUNT_PATH, "/opt/disk2"); + } + }); + } + }; + + protected static final List> EMPTY_DIR_VOLUMES = + new ArrayList>() { + { + add( + new HashMap() { + { + put(Constants.VOLUME_NAME, "disk1"); + put(Constants.EMPTY_DIR_VOLUME_SIZE_LIMIT, "5Gi"); + put(Constants.VOLUME_MOUNT_PATH, "/opt/disk1"); + } + }); + add( + new HashMap() { + { + put(Constants.VOLUME_NAME, "disk2"); + put(Constants.EMPTY_DIR_VOLUME_SIZE_LIMIT, "10Gi"); + put(Constants.VOLUME_MOUNT_PATH, "/opt/disk2"); + } + }); + } + }; + + protected static final List> HOST_PATH_VOLUMES = + new ArrayList>() { + { + add( + new HashMap() { + { + put(Constants.VOLUME_NAME, "disk1"); + put(Constants.HOST_PATH_VOLUME_PATH, "/dump/1"); + put(Constants.VOLUME_MOUNT_PATH, "/opt/disk1"); + } + }); + add( + new HashMap() { + { + put(Constants.VOLUME_NAME, "disk2"); + put(Constants.HOST_PATH_VOLUME_PATH, "/dump/2"); + put(Constants.VOLUME_MOUNT_PATH, "/opt/disk2"); + } + }); + } + }; + + protected static final ConfigMapVolume CONFIG_MAP_VOLUME = + new ConfigMapVolume( + "disk1", + "config-map", + new HashMap() { + { + put("log4j2.xml", "log4j2.xml"); + put("log4j.properties", "log4j.properties"); + } + }, + "/opt/conf"); + + protected static final List> TOLERATIONS = + new ArrayList>() { + { + add( + new HashMap() { + { + put("key", "key1"); + put("operator", "Equal"); + put("value", "value1"); + put("effect", "NoSchedule"); + } + }); + } + }; + + protected static final Map USER_LABELS = + new HashMap() { + { + put("label1", "value1"); + put("label2", "value2"); + } + }; + + protected static final Map NODE_SELECTOR = + new HashMap() { + { + put("env", "production"); + put("disk", "ssd"); + } + }; + + protected static final Map RESOURCE_LIMIT_FACTOR = + new HashMap() { + { + put("cpu", "3.0"); + put("memory", 
"1.5"); + } + }; + + protected static int getTotalMemory(boolean containsNetwork) { + int totalMemory = + CONTAINER_FRAMEWORK_HEAP_MEMORY_MB + + CONTAINER_FRAMEWORK_OFF_HEAP_MEMORY_MB + + CONTAINER_JVM_METASPACE_MB + + CONTAINER_JVM_OVERHEAD_MB; + if (containsNetwork) { + totalMemory += + NETWORK_READING_MEMORY_SIZE.getMebiBytes() + + NETWORK_WRITING_MEMORY_SIZE.getMebiBytes(); + } + return totalMemory; + } + + protected static Map getCommonLabels() { + Map labels = new HashMap<>(); + labels.put(Constants.LABEL_APPTYPE_KEY, Constants.LABEL_APPTYPE_VALUE); + labels.put(Constants.LABEL_APP_KEY, CLUSTER_ID); + return labels; + } +} diff --git a/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/util/KubernetesUtilsTest.java b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/util/KubernetesUtilsTest.java new file mode 100644 index 00000000..0cad483d --- /dev/null +++ b/shuffle-kubernetes-operator/src/test/java/com/alibaba/flink/shuffle/kubernetes/operator/util/KubernetesUtilsTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.operator.util; + +import io.fabric8.kubernetes.api.model.HasMetadata; +import io.fabric8.kubernetes.api.model.ObjectMeta; +import io.fabric8.kubernetes.api.model.ObjectMetaBuilder; +import io.fabric8.kubernetes.api.model.OwnerReference; +import io.fabric8.kubernetes.api.model.apps.Deployment; +import io.fabric8.kubernetes.api.model.apps.DeploymentBuilder; +import org.junit.Test; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; + +/** Test for {@link KubernetesUtils}. 
*/ +public class KubernetesUtilsTest { + + @Test + public void testSetAndGetOwnerReference() { + Deployment owner = + new DeploymentBuilder() + .editOrNewMetadata() + .withName("testOwner") + .withNamespace("default") + .endMetadata() + .build(); + Deployment resource = new DeploymentBuilder().editOrNewMetadata().endMetadata().build(); + + KubernetesUtils.setOwnerReference(resource, owner); + assertEquals(resource.getMetadata().getOwnerReferences().size(), 1); + OwnerReference ownerReference = KubernetesUtils.getControllerOf(resource); + assertEquals("testOwner", ownerReference.getName()); + assertEquals("apps/v1", ownerReference.getApiVersion()); + assertEquals(true, ownerReference.getController()); + assertEquals("Deployment", ownerReference.getKind()); + assertEquals(owner.getMetadata().getUid(), ownerReference.getUid()); + } + + @Test + public void testGetShuffleManagerNameWithClusterId() { + assertEquals( + "test-shufflemanager", KubernetesUtils.getShuffleManagerNameWithClusterId("test")); + } + + @Test + public void testGetShuffleWorkersNameWithClusterId() { + assertEquals( + "test-shuffleworker", KubernetesUtils.getShuffleWorkersNameWithClusterId("test")); + } + + @Test + public void testGetResourceFullName() { + HasMetadata resource = + new HasMetadata() { + @Override + public ObjectMeta getMetadata() { + return new ObjectMetaBuilder() + .withNamespace("testNamespace") + .withName("testResource") + .build(); + } + + @Override + public void setMetadata(ObjectMeta objectMeta) {} + + @Override + public void setApiVersion(String s) {} + }; + + assertThat(KubernetesUtils.getResourceFullName(resource), is("testNamespace/testResource")); + } +} diff --git a/shuffle-kubernetes-operator/src/test/resources/log4j2-test.properties b/shuffle-kubernetes-operator/src/test/resources/log4j2-test.properties new file mode 100644 index 00000000..d7fcb327 --- /dev/null +++ b/shuffle-kubernetes-operator/src/test/resources/log4j2-test.properties @@ -0,0 +1,26 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################
+# Set root logger level to OFF to not flood build logs
+# set manually to INFO for debugging purposes
+rootLogger.level=OFF
+rootLogger.appenderRef.test.ref=TestLogger
+appender.testlogger.name=TestLogger
+appender.testlogger.type=CONSOLE
+appender.testlogger.target=SYSTEM_ERR
+appender.testlogger.layout.type=PatternLayout
+appender.testlogger.layout.pattern=%-4r [%t] %-5p %c %x - %m%n
diff --git a/shuffle-kubernetes/pom.xml b/shuffle-kubernetes/pom.xml
new file mode 100644
index 00000000..a0dbe823
--- /dev/null
+++ b/shuffle-kubernetes/pom.xml
@@ -0,0 +1,38 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <parent>
+        <groupId>com.alibaba.flink.shuffle</groupId>
+        <artifactId>flink-shuffle-parent</artifactId>
+        <version>1.0-SNAPSHOT</version>
+    </parent>
+    <modelVersion>4.0.0</modelVersion>
+
+    <artifactId>shuffle-kubernetes</artifactId>
+
+    <dependencies>
+        <dependency>
+            <groupId>com.alibaba.flink.shuffle</groupId>
+            <artifactId>shuffle-coordinator</artifactId>
+            <version>${project.version}</version>
+        </dependency>
+    </dependencies>
+</project>
diff --git a/shuffle-kubernetes/src/main/java/com/alibaba/flink/shuffle/kubernetes/manager/KubernetesShuffleManagerRunner.java b/shuffle-kubernetes/src/main/java/com/alibaba/flink/shuffle/kubernetes/manager/KubernetesShuffleManagerRunner.java
new file mode 100644
index 00000000..6673f609
--- /dev/null
+++ b/shuffle-kubernetes/src/main/java/com/alibaba/flink/shuffle/kubernetes/manager/KubernetesShuffleManagerRunner.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.kubernetes.manager;
+
+import com.alibaba.flink.shuffle.common.config.Configuration;
+import com.alibaba.flink.shuffle.common.utils.JvmShutdownSafeguard;
+import com.alibaba.flink.shuffle.common.utils.SignalHandler;
+import com.alibaba.flink.shuffle.coordinator.highavailability.HaMode;
+import com.alibaba.flink.shuffle.coordinator.manager.ShuffleManager;
+import com.alibaba.flink.shuffle.coordinator.manager.entrypoint.ShuffleManagerEntrypoint;
+import com.alibaba.flink.shuffle.coordinator.utils.ClusterEntrypointUtils;
+import com.alibaba.flink.shuffle.coordinator.utils.EnvironmentInformation;
+import com.alibaba.flink.shuffle.core.config.ManagerOptions;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState;
+import static com.alibaba.flink.shuffle.coordinator.utils.ClusterEntrypointUtils.STARTUP_FAILURE_RETURN_CODE;
+
+/** The Kubernetes entrypoint class for the {@link ShuffleManager}.
+ */
+public class KubernetesShuffleManagerRunner {
+
+    private static final Logger LOG = LoggerFactory.getLogger(KubernetesShuffleManagerRunner.class);
+
+    public static final String ENV_REMOTE_SHUFFLE_POD_IP_ADDRESS = "_POD_IP_ADDRESS";
+
+    public static void main(String[] args) {
+        // startup checks and logging
+        EnvironmentInformation.logEnvironmentInfo(LOG, "Shuffle Manager", args);
+        SignalHandler.register(LOG);
+        JvmShutdownSafeguard.installAsShutdownHook(LOG);
+
+        long maxOpenFileHandles = EnvironmentInformation.getOpenFileHandlesLimit();
+
+        if (maxOpenFileHandles != -1L) {
+            LOG.info("Maximum number of open file descriptors is {}.", maxOpenFileHandles);
+        } else {
+            LOG.info("Cannot determine the maximum number of open file descriptors.");
+        }
+
+        try {
+            Configuration configuration = ClusterEntrypointUtils.parseParametersOrExit(args);
+            ShuffleManagerEntrypoint shuffleManagerEntrypoint =
+                    new ShuffleManagerEntrypoint(loadConfiguration(configuration));
+            ShuffleManagerEntrypoint.runShuffleManagerEntrypoint(shuffleManagerEntrypoint);
+        } catch (Throwable t) {
+            LOG.error("ShuffleManager initialization failed.", t);
+            System.exit(STARTUP_FAILURE_RETURN_CODE);
+        }
+    }
+
+    /**
+     * For an HA cluster, {@link ManagerOptions#RPC_ADDRESS} is set to the pod IP address; shuffle
+     * workers then use ZooKeeper or another high-availability service to discover the address of
+     * the ShuffleManager.
+     *
+     * @param conf the base configuration to copy
+     * @return Updated configuration
+     */
+    static Configuration loadConfiguration(Configuration conf) {
+        Configuration configuration = new Configuration(conf);
+        if (HaMode.isHighAvailabilityModeActivated(configuration)) {
+            final String ipAddress = System.getenv(ENV_REMOTE_SHUFFLE_POD_IP_ADDRESS);
+            checkState(
+                    ipAddress != null,
+                    "ShuffleManager IP address environment variable "
+                            + ENV_REMOTE_SHUFFLE_POD_IP_ADDRESS
+                            + " is not set");
+            configuration.setString(ManagerOptions.RPC_ADDRESS, ipAddress);
+        }
+
+        return configuration;
+    }
+}
diff --git a/shuffle-kubernetes/src/test/java/com/alibaba/flink/shuffle/kubernetes/manager/KubernetesShuffleManagerRunnerTest.java b/shuffle-kubernetes/src/test/java/com/alibaba/flink/shuffle/kubernetes/manager/KubernetesShuffleManagerRunnerTest.java
new file mode 100644
index 00000000..674a610f
--- /dev/null
+++ b/shuffle-kubernetes/src/test/java/com/alibaba/flink/shuffle/kubernetes/manager/KubernetesShuffleManagerRunnerTest.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package com.alibaba.flink.shuffle.kubernetes.manager; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.core.config.HighAvailabilityOptions; +import com.alibaba.flink.shuffle.core.config.ManagerOptions; +import com.alibaba.flink.shuffle.kubernetes.manager.util.EnvUtils; + +import org.junit.Test; + +import java.util.Collections; + +import static com.alibaba.flink.shuffle.kubernetes.manager.KubernetesShuffleManagerRunner.ENV_REMOTE_SHUFFLE_POD_IP_ADDRESS; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; + +/** Test for {@link KubernetesShuffleManagerRunner}. */ +public class KubernetesShuffleManagerRunnerTest { + + @Test + public void testLoadConfigurationWithNoneHaMode() { + EnvUtils.setEnv(Collections.singletonMap(ENV_REMOTE_SHUFFLE_POD_IP_ADDRESS, "192.168.1.1")); + Configuration configuration = new Configuration(); + configuration.setString(ManagerOptions.RPC_ADDRESS, "localhost"); + configuration.setString(HighAvailabilityOptions.HA_MODE, "NONE"); + + Configuration updatedConfiguration = + KubernetesShuffleManagerRunner.loadConfiguration(configuration); + assertThat(updatedConfiguration.getString(ManagerOptions.RPC_ADDRESS), is("localhost")); + } + + @Test + public void testLoadConfigurationWithZooKeeperHaMode() { + EnvUtils.setEnv(Collections.singletonMap(ENV_REMOTE_SHUFFLE_POD_IP_ADDRESS, "192.168.1.1")); + Configuration configuration = new Configuration(); + configuration.setString(ManagerOptions.RPC_ADDRESS, "localhost"); + configuration.setString(HighAvailabilityOptions.HA_MODE, "ZOOKEEPER"); + + Configuration updatedConfiguration = + KubernetesShuffleManagerRunner.loadConfiguration(configuration); + assertThat(updatedConfiguration.getString(ManagerOptions.RPC_ADDRESS), is("192.168.1.1")); + } +} diff --git a/shuffle-kubernetes/src/test/java/com/alibaba/flink/shuffle/kubernetes/manager/util/EnvUtils.java b/shuffle-kubernetes/src/test/java/com/alibaba/flink/shuffle/kubernetes/manager/util/EnvUtils.java new file mode 100644 index 00000000..81a51576 --- /dev/null +++ b/shuffle-kubernetes/src/test/java/com/alibaba/flink/shuffle/kubernetes/manager/util/EnvUtils.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.kubernetes.manager.util; + +import java.lang.reflect.Field; +import java.util.Map; + +/** Utility class for setting environment variables. Borrowed from Apache Flink. */ +public class EnvUtils { + + public static void setEnv(Map newenv) { + setEnv(newenv, true); + } + + // This code is taken slightly modified from: http://stackoverflow.com/a/7201825/568695 + // it changes the environment variables of this JVM. Use only for testing purposes! 
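[Editor's note] The reflective trick below is what lets KubernetesShuffleManagerRunnerTest above fake _POD_IP_ADDRESS. Typical usage in a test, mirroring that test class; note that the one-argument overload clears every other environment variable first:

```java
import java.util.Collections;

public class EnvUtilsUsageSketch {
    public static void main(String[] args) {
        // One-arg overload delegates to setEnv(map, true), wiping the cached environment
        // before injecting the fake value. Use only in tests, never in production code.
        EnvUtils.setEnv(Collections.singletonMap("_POD_IP_ADDRESS", "192.168.1.1"));

        // System.getenv() now reflects the injected value for the rest of this JVM.
        System.out.println(System.getenv("_POD_IP_ADDRESS")); // 192.168.1.1
    }
}
```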
+    @SuppressWarnings("unchecked")
+    public static void setEnv(Map<String, String> newenv, boolean clearExisting) {
+        try {
+            Map<String, String> env = System.getenv();
+            Class<?> clazz = env.getClass();
+            Field field = clazz.getDeclaredField("m");
+            field.setAccessible(true);
+            Map<String, String> map = (Map<String, String>) field.get(env);
+            if (clearExisting) {
+                map.clear();
+            }
+            map.putAll(newenv);
+
+            // only for Windows
+            Class<?> processEnvironmentClass = Class.forName("java.lang.ProcessEnvironment");
+            try {
+                Field theCaseInsensitiveEnvironmentField =
+                        processEnvironmentClass.getDeclaredField("theCaseInsensitiveEnvironment");
+                theCaseInsensitiveEnvironmentField.setAccessible(true);
+                Map<String, String> cienv =
+                        (Map<String, String>) theCaseInsensitiveEnvironmentField.get(null);
+                if (clearExisting) {
+                    cienv.clear();
+                }
+                cienv.putAll(newenv);
+            } catch (NoSuchFieldException ignored) {
+            }
+
+        } catch (Exception e1) {
+            throw new RuntimeException(e1);
+        }
+    }
+}
diff --git a/shuffle-metrics/pom.xml b/shuffle-metrics/pom.xml
new file mode 100644
index 00000000..3e950358
--- /dev/null
+++ b/shuffle-metrics/pom.xml
@@ -0,0 +1,72 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <parent>
+        <groupId>com.alibaba.flink.shuffle</groupId>
+        <artifactId>flink-shuffle-parent</artifactId>
+        <version>1.0-SNAPSHOT</version>
+    </parent>
+    <modelVersion>4.0.0</modelVersion>
+
+    <artifactId>shuffle-metrics</artifactId>
+
+    <dependencies>
+        <dependency>
+            <groupId>com.alibaba.middleware</groupId>
+            <artifactId>metrics-core-api</artifactId>
+            <version>${alibaba.metrics.version}</version>
+        </dependency>
+        <dependency>
+            <groupId>com.alibaba.middleware</groupId>
+            <artifactId>metrics-core-impl</artifactId>
+            <version>${alibaba.metrics.version}</version>
+        </dependency>
+        <dependency>
+            <groupId>com.alibaba.middleware</groupId>
+            <artifactId>metrics-integration</artifactId>
+            <version>${alibaba.metrics.version}</version>
+        </dependency>
+        <dependency>
+            <groupId>com.alibaba.middleware</groupId>
+            <artifactId>metrics-rest</artifactId>
+            <version>${alibaba.metrics.version}</version>
+        </dependency>
+        <dependency>
+            <groupId>com.alibaba.middleware</groupId>
+            <artifactId>metrics-reporter</artifactId>
+            <version>${alibaba.metrics.version}</version>
+        </dependency>
+
+        <dependency>
+            <groupId>com.alibaba.flink.shuffle</groupId>
+            <artifactId>shuffle-common</artifactId>
+            <version>${project.version}</version>
+        </dependency>
+
+        <dependency>
+            <groupId>com.alibaba.flink.shuffle</groupId>
+            <artifactId>shuffle-core</artifactId>
+            <version>${project.version}</version>
+        </dependency>
+    </dependencies>
+</project>
diff --git a/shuffle-metrics/src/main/java/com/alibaba/flink/shuffle/metrics/entry/MetricBootstrap.java b/shuffle-metrics/src/main/java/com/alibaba/flink/shuffle/metrics/entry/MetricBootstrap.java
new file mode 100644
index 00000000..6ba7a899
--- /dev/null
+++ b/shuffle-metrics/src/main/java/com/alibaba/flink/shuffle/metrics/entry/MetricBootstrap.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.metrics.entry;
+
+import com.alibaba.metrics.integrate.MetricsIntegrateUtils;
+import com.alibaba.metrics.rest.server.MetricsHttpServer;
+
+/** Bootstrap for metrics.
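[Editor's note] Lifecycle-wise the class below is a pair of static calls: init optionally starts the alibaba-metrics HTTP server and registers the integrate module's built-in metrics, while destroy stops the server. A sketch of typical wiring (MetricUtils later in this diff performs essentially this sequence):

```java
import com.alibaba.flink.shuffle.common.config.Configuration;

public class BootstrapUsageSketch {
    public static void main(String[] args) {
        Configuration clusterConf = new Configuration();
        // true selects the ShuffleManager HTTP port option for the endpoint.
        MetricConfiguration conf = new MetricConfiguration(clusterConf, true);

        MetricBootstrap.init(conf); // starts the HTTP server only if enabled in conf
        try {
            // ... run the process ...
        } finally {
            MetricBootstrap.destroy(); // stops the HTTP server
        }
    }
}
```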
*/ +public class MetricBootstrap { + + private static final MetricsHttpServer metricsHttpServer = new MetricsHttpServer(); + + public static void init(MetricConfiguration conf) { + if (conf.isHttpServerEnabled()) { + startHttpServer(); + } + MetricsIntegrateUtils.registerAllMetrics(conf.getProperties()); + } + + public static void destroy() { + metricsHttpServer.stop(); + } + + private static void startHttpServer() { + metricsHttpServer.start(); + } +} diff --git a/shuffle-metrics/src/main/java/com/alibaba/flink/shuffle/metrics/entry/MetricConfiguration.java b/shuffle-metrics/src/main/java/com/alibaba/flink/shuffle/metrics/entry/MetricConfiguration.java new file mode 100644 index 00000000..b4a846d4 --- /dev/null +++ b/shuffle-metrics/src/main/java/com/alibaba/flink/shuffle/metrics/entry/MetricConfiguration.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.metrics.entry; + +import com.alibaba.flink.shuffle.common.config.Configuration; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Properties; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; +import static com.alibaba.flink.shuffle.core.config.MetricOptions.METRICS_BIND_HOST; +import static com.alibaba.flink.shuffle.core.config.MetricOptions.METRICS_HTTP_SERVER_ENABLE; +import static com.alibaba.flink.shuffle.core.config.MetricOptions.METRICS_SHUFFLE_MANAGER_HTTP_BIND_PORT; +import static com.alibaba.flink.shuffle.core.config.MetricOptions.METRICS_SHUFFLE_WORKER_HTTP_BIND_PORT; + +/** + * This class is used to transform configurations, because the configurations can't be used directly + * in the dependency metrics project. + */ +public class MetricConfiguration { + private static final Logger LOG = LoggerFactory.getLogger(MetricConfiguration.class); + // Config key used in the dependency metrics project. + private static final String ALI_METRICS_BINDING_HOST = "com.alibaba.metrics.http.binding.host"; + private static final String ALI_METRICS_HTTP_PORT = "com.alibaba.metrics.http.port"; + + private final Configuration conf; + + private final Properties properties; + + private final boolean isManager; + + public MetricConfiguration(Configuration configuration, boolean isManager) { + this.conf = configuration; + this.isManager = isManager; + this.properties = parseMetricProperties(configuration); + } + + /** Transform configurations into new formats used in the dependency metrics project. 
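[Editor's note] Concretely, the translation below is one-way and JVM-global: the bind host and the role-specific HTTP port are copied into system properties under the com.alibaba.metrics.* keys hard-coded above, and the remaining options are passed through as java.util.Properties. A small demonstration of the observable effect, assuming the bind-host and port options carry defaults:

```java
import com.alibaba.flink.shuffle.common.config.Configuration;

public class MetricConfigMappingSketch {
    public static void main(String[] args) {
        Configuration clusterConf = new Configuration();
        // isManager=true selects METRICS_SHUFFLE_MANAGER_HTTP_BIND_PORT for the port.
        new MetricConfiguration(clusterConf, true);

        // The dependency metrics project reads these JVM-wide system properties:
        System.out.println(System.getProperty("com.alibaba.metrics.http.binding.host"));
        System.out.println(System.getProperty("com.alibaba.metrics.http.port"));
    }
}
```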
+     */
+    private Properties parseMetricProperties(Configuration configuration) {
+        checkNotNull(configuration);
+        Properties properties = new Properties();
+
+        // Transfer the bind host setting from the cluster config
+        String bindHost = configuration.getString(METRICS_BIND_HOST);
+        System.setProperty(ALI_METRICS_BINDING_HOST, bindHost);
+
+        int bindPort;
+        if (isManager) {
+            bindPort = configuration.getInteger(METRICS_SHUFFLE_MANAGER_HTTP_BIND_PORT);
+        } else {
+            bindPort = configuration.getInteger(METRICS_SHUFFLE_WORKER_HTTP_BIND_PORT);
+        }
+        System.setProperty(ALI_METRICS_HTTP_PORT, Integer.toString(bindPort));
+        LOG.info("Metrics HTTP server port is set to {}.", bindPort);
+
+        properties.putAll(configuration.toProperties());
+        return properties;
+    }
+
+    boolean isHttpServerEnabled() {
+        return conf != null && conf.getBoolean(METRICS_HTTP_SERVER_ENABLE);
+    }
+
+    public Properties getProperties() {
+        return properties;
+    }
+
+    public Configuration getConfiguration() {
+        return conf;
+    }
+}
diff --git a/shuffle-metrics/src/main/java/com/alibaba/flink/shuffle/metrics/entry/MetricUtils.java b/shuffle-metrics/src/main/java/com/alibaba/flink/shuffle/metrics/entry/MetricUtils.java
new file mode 100644
index 00000000..668548fc
--- /dev/null
+++ b/shuffle-metrics/src/main/java/com/alibaba/flink/shuffle/metrics/entry/MetricUtils.java
@@ -0,0 +1,252 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.metrics.entry;
+
+import com.alibaba.flink.shuffle.common.config.Configuration;
+import com.alibaba.flink.shuffle.metrics.reporter.ReporterSetup;
+
+import com.alibaba.metrics.Compass;
+import com.alibaba.metrics.Counter;
+import com.alibaba.metrics.FastCompass;
+import com.alibaba.metrics.Gauge;
+import com.alibaba.metrics.Histogram;
+import com.alibaba.metrics.Meter;
+import com.alibaba.metrics.Metric;
+import com.alibaba.metrics.MetricLevel;
+import com.alibaba.metrics.MetricManager;
+import com.alibaba.metrics.MetricName;
+import com.alibaba.metrics.Timer;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull;
+
+/** Utils to manage metrics.
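[Editor's note] The helpers declared below are thin wrappers over MetricManager that normalize a group name, a metric name, an optional MetricLevel, and alternating key/value tag strings into a MetricName. Typical call sites look like this; the group and metric names here are made up for illustration:

```java
import com.alibaba.metrics.Counter;
import com.alibaba.metrics.Meter;

public class MetricUtilsUsageSketch {
    public static void main(String[] args) {
        // Tags are alternating key/value strings and end up in MetricName.tagged(...).
        Counter bytesWritten =
                MetricUtils.getCounter("remote-shuffle", "worker.bytes_written", "cluster", "demo");
        bytesWritten.inc(1024);

        Meter requests = MetricUtils.getMeter("remote-shuffle", "manager.requests");
        requests.mark();
    }
}
```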
*/ +public class MetricUtils { + private static final Logger LOG = LoggerFactory.getLogger(MetricUtils.class); + + // --------------------------------------------------------------- + // Manage metric system + // --------------------------------------------------------------- + + public static void startManagerMetricSystem(Configuration config) { + MetricConfiguration metricConf = new MetricConfiguration(config, true); + startMetricSystemInternal(metricConf); + } + + public static void startWorkerMetricSystem(Configuration config) { + MetricConfiguration metricConf = new MetricConfiguration(config, false); + startMetricSystemInternal(metricConf); + } + + private static void startMetricSystemInternal(MetricConfiguration conf) { + try { + MetricBootstrap.init(conf); + ReporterSetup.fromConfiguration(conf.getConfiguration()); + LOG.info("Metric system start successfully"); + } catch (Throwable t) { + LOG.error("Start metric system failed, ", t); + } + } + + public static void stopMetricSystem() { + try { + MetricBootstrap.destroy(); + LOG.info("Metric system is stopped."); + } catch (Throwable throwable) { + LOG.error("Failed to stop metric system.", throwable); + } + } + + // --------------------------------------------------------------- + // Metrics Getter + // --------------------------------------------------------------- + + /** Get {@link Counter}. */ + public static Counter getCounter(String groupName, String metricName, String... tags) { + checkNotNull(groupName); + checkNotNull(metricName); + + return MetricManager.getCounter(groupName, generateMetricName(metricName, tags)); + } + + /** Get {@link Counter}. */ + public static Counter getCounter( + String groupName, String metricName, MetricLevel metricLevel, String... tags) { + checkNotNull(groupName); + checkNotNull(metricName); + + return MetricManager.getCounter( + groupName, generateMetricName(metricName, metricLevel, tags)); + } + + /** Get {@link Meter}. */ + public static Meter getMeter(String groupName, String metricName, String... tags) { + checkNotNull(groupName); + checkNotNull(metricName); + + return MetricManager.getMeter(groupName, generateMetricName(metricName, tags)); + } + + /** Get {@link Meter}. */ + public static Meter getMeter( + String groupName, String metricName, MetricLevel metricLevel, String... tags) { + checkNotNull(groupName); + checkNotNull(metricName); + + return MetricManager.getMeter(groupName, generateMetricName(metricName, metricLevel, tags)); + } + + /** Get {@link Histogram}. */ + public static Histogram getHistogram(String groupName, String metricName, String... tags) { + checkNotNull(groupName); + checkNotNull(metricName); + + return MetricManager.getHistogram(groupName, generateMetricName(metricName, tags)); + } + + /** Get {@link Histogram}. */ + public static Histogram getHistogram( + String groupName, String metricName, MetricLevel metricLevel, String... tags) { + checkNotNull(groupName); + checkNotNull(metricName); + + return MetricManager.getHistogram( + groupName, generateMetricName(metricName, metricLevel, tags)); + } + + /** Get {@link Timer}. */ + public static Timer getTimer(String groupName, String metricName, String... tags) { + checkNotNull(groupName); + checkNotNull(metricName); + + return MetricManager.getTimer(groupName, generateMetricName(metricName, tags)); + } + + /** Get {@link Timer}. */ + public static Timer getTimer( + String groupName, String metricName, MetricLevel metricLevel, String... 
tags) { + checkNotNull(groupName); + checkNotNull(metricName); + + return MetricManager.getTimer(groupName, generateMetricName(metricName, metricLevel, tags)); + } + + /** + * Get {@link Compass}. This metric is used when recording throughput, response time + * distribution, success rate or error code metrics. + */ + public static Compass getCompass(String groupName, String metricName, String... tags) { + checkNotNull(groupName); + checkNotNull(metricName); + + return MetricManager.getCompass(groupName, generateMetricName(metricName, tags)); + } + + /** + * Get {@link Compass}. This metric is used when recording throughput, response time + * distribution, success rate or error code metrics. + */ + public static Compass getCompass( + String groupName, String metricName, MetricLevel metricLevel, String... tags) { + checkNotNull(groupName); + checkNotNull(metricName); + + return MetricManager.getCompass( + groupName, generateMetricName(metricName, metricLevel, tags)); + } + + /** + * Get {@link FastCompass}. This metric is used when recording efficient statistical throughput, + * average RT and metric of custom dimensions. + */ + public static FastCompass getFastCompass(String groupName, String metricName, String... tags) { + checkNotNull(groupName); + checkNotNull(metricName); + + return MetricManager.getFastCompass(groupName, generateMetricName(metricName, tags)); + } + + /** + * Get {@link FastCompass}. This metric is used when recording efficient statistical throughput, + * average RT and metric of custom dimensions. + */ + public static FastCompass getFastCompass( + String groupName, String metricName, MetricLevel metricLevel, String... tags) { + checkNotNull(groupName); + checkNotNull(metricName); + + return MetricManager.getFastCompass( + groupName, generateMetricName(metricName, metricLevel, tags)); + } + + /** + * To get a {@link Gauge} metric, please register a Gauge metric to {@link MetricManager}. For + * example: + * + *
+     * <pre>
Gauge< Integer> listenerSizeGauge = new Gauge< Integer>() { @Override public Integer + * getValue() { return defaultEnv.getAllListeners().size(); } }; + * MetricManager.register("testGroup", MetricName.build("abc.defaultEnv.listenerSize"), + * listenerSizeGauge); + */ + public static void registerMetric( + String groupName, String metricName, Metric metric, String... tags) { + checkNotNull(groupName); + checkNotNull(metricName); + + try { + MetricManager.register(groupName, generateMetricName(metricName, tags), metric); + } catch (Exception exception) { + LOG.warn("Failed to register metric " + metricName, exception); + } + } + + public static void registerMetric( + String groupName, + String metricName, + Metric metric, + MetricLevel metricLevel, + String... tags) { + checkNotNull(groupName); + checkNotNull(metricName); + + try { + MetricManager.register( + groupName, generateMetricName(metricName, metricLevel, tags), metric); + } catch (Exception exception) { + LOG.warn("Failed to register metric " + metricName, exception); + } + } + + private static MetricName generateMetricName(String metricName, String... tags) { + return generateMetricName(metricName, null, tags); + } + + private static MetricName generateMetricName( + String metricName, MetricLevel metricLevel, String... tags) { + MetricName buildName = MetricName.build(metricName); + if (metricLevel != null) { + buildName.level(metricLevel); + } + if (tags != null && tags.length > 0) { + buildName.tagged(tags); + } + return buildName; + } +} diff --git a/shuffle-metrics/src/main/java/com/alibaba/flink/shuffle/metrics/reporter/MetricReporterFactory.java b/shuffle-metrics/src/main/java/com/alibaba/flink/shuffle/metrics/reporter/MetricReporterFactory.java new file mode 100644 index 00000000..3e752859 --- /dev/null +++ b/shuffle-metrics/src/main/java/com/alibaba/flink/shuffle/metrics/reporter/MetricReporterFactory.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.metrics.reporter; + +import com.alibaba.metrics.reporter.MetricManagerReporter; + +import java.util.Properties; + +/** + * {@link MetricManagerReporter} factory. Metric reporters that can be instantiated with a factory. + */ +public interface MetricReporterFactory { + + /** + * Creates a new metric reporter. 
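[Editor's note] The Gauge example in the registerMetric javadoc above arrives garbled in this patch. Reconstructed as compilable code, with names mirroring that javadoc, registering a gauge through MetricManager looks like this:

```java
import com.alibaba.metrics.Gauge;
import com.alibaba.metrics.MetricManager;
import com.alibaba.metrics.MetricName;

public class GaugeRegistrationSketch {
    public static void main(String[] args) {
        // Gauges are read-on-collection callbacks, so they are registered, not fetched.
        Gauge<Integer> listenerSizeGauge =
                new Gauge<Integer>() {
                    @Override
                    public Integer getValue() {
                        return 42; // e.g. defaultEnv.getAllListeners().size() in the javadoc
                    }
                };
        MetricManager.register(
                "testGroup", MetricName.build("abc.defaultEnv.listenerSize"), listenerSizeGauge);
    }
}
```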
+     *
+     * @param conf configurations for reporters
+     * @return created metric reporter
+     */
+    MetricManagerReporter createMetricReporter(final Properties conf);
+}
diff --git a/shuffle-metrics/src/main/java/com/alibaba/flink/shuffle/metrics/reporter/ReporterSetup.java b/shuffle-metrics/src/main/java/com/alibaba/flink/shuffle/metrics/reporter/ReporterSetup.java
new file mode 100644
index 00000000..ecb52159
--- /dev/null
+++ b/shuffle-metrics/src/main/java/com/alibaba/flink/shuffle/metrics/reporter/ReporterSetup.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.metrics.reporter;
+
+import com.alibaba.flink.shuffle.common.config.Configuration;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static com.alibaba.flink.shuffle.core.config.MetricOptions.METRICS_REPORTER_CLASSES;
+
+/** Encapsulates everything needed for the instantiation and configuration of a metric reporter. */
+public final class ReporterSetup {
+    private static final Logger LOG = LoggerFactory.getLogger(ReporterSetup.class);
+
+    private static final String CONFIGURATION_ARGS_DELIMITER = ";";
+
+    public static void fromConfiguration(final Configuration conf) {
+        String reportersString = conf.getString(METRICS_REPORTER_CLASSES);
+        if (reportersString == null) {
+            LOG.info("Metric reporter factories are not configured.");
+            return;
+        }
+
+        Set<String> reporterFactories =
+                Stream.of(reportersString.split(CONFIGURATION_ARGS_DELIMITER))
+                        .collect(Collectors.toSet());
+        reporterFactories.forEach(
+                factoryClass -> setupReporterViaReflection(factoryClass.trim(), conf));
+    }
+
+    private static void setupReporterViaReflection(
+            final String reporterFactory, final Configuration conf) {
+        try {
+            loadViaReflection(reporterFactory, conf);
+        } catch (Throwable th) {
+            LOG.error("Failed to set up reporter " + reporterFactory + ".", th);
+        }
+    }
+
+    /** Package-private so that it can be invoked directly from unit tests.
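[Editor's note] Wiring a reporter therefore takes two pieces: a factory class with a public no-arg constructor (ReporterSetup instantiates it reflectively via Class.forName and newInstance) and its fully qualified name in METRICS_REPORTER_CLASSES, semicolon-separated when there are several. A hypothetical factory stub, not part of this repo:

```java
package com.example.metrics; // hypothetical package, for illustration only

import java.util.Properties;

import com.alibaba.flink.shuffle.metrics.reporter.MetricReporterFactory;
import com.alibaba.metrics.reporter.MetricManagerReporter;

/** Hypothetical factory; it needs a public no-arg constructor for reflective loading. */
public class ConsoleReporterFactory implements MetricReporterFactory {

    @Override
    public MetricManagerReporter createMetricReporter(Properties conf) {
        // Construct, configure, and start a concrete MetricManagerReporter here using
        // values from `conf`; the concrete reporter type is deployment-specific.
        throw new UnsupportedOperationException("stub for illustration only");
    }
}
```

With METRICS_REPORTER_CLASSES set to "com.example.metrics.ConsoleReporterFactory", ReporterSetup.fromConfiguration(conf) picks the factory up at startup.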
*/ + static void loadViaReflection(final String reporterFactory, final Configuration conf) + throws Exception { + Class factoryClazz = Class.forName(reporterFactory); + MetricReporterFactory metricReporterFactory = + (MetricReporterFactory) factoryClazz.newInstance(); + metricReporterFactory.createMetricReporter(conf.toProperties()); + LOG.info("Setup metric reporter " + reporterFactory + " successfully"); + } +} diff --git a/shuffle-metrics/src/test/java/com/alibaba/flink/shuffle/metrics/entry/FakedMetricRegistryListener.java b/shuffle-metrics/src/test/java/com/alibaba/flink/shuffle/metrics/entry/FakedMetricRegistryListener.java new file mode 100644 index 00000000..a2ddf572 --- /dev/null +++ b/shuffle-metrics/src/test/java/com/alibaba/flink/shuffle/metrics/entry/FakedMetricRegistryListener.java @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.metrics.entry; + +import com.alibaba.metrics.Compass; +import com.alibaba.metrics.Counter; +import com.alibaba.metrics.FastCompass; +import com.alibaba.metrics.Gauge; +import com.alibaba.metrics.Histogram; +import com.alibaba.metrics.Meter; +import com.alibaba.metrics.MetricName; +import com.alibaba.metrics.MetricRegistryListener; +import com.alibaba.metrics.Timer; + +/** Faked implementation for {@link MetricRegistryListener}. 
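[Editor's note] The counting listener below is meant to be attached to a registry so a test can assert how many callbacks fired. Assuming alibaba-metrics keeps the Dropwizard-style registry listener API (an assumption; no test exercising it appears in this hunk, and getIMetricManager/getMetricRegistryByGroup are not shown in this diff), usage would look roughly like:

```java
import com.alibaba.metrics.MetricManager;
import com.alibaba.metrics.MetricName;
import com.alibaba.metrics.MetricRegistry;

public class ListenerUsageSketch {
    public static void main(String[] args) {
        FakedMetricRegistryListener listener = new FakedMetricRegistryListener();

        // Assumption: the group registry exposes Dropwizard-style addListener().
        MetricRegistry registry =
                MetricManager.getIMetricManager().getMetricRegistryByGroup("testGroup");
        registry.addListener(listener);

        // Creating a counter in that group should trigger the onCounterAdded callback.
        MetricManager.getCounter("testGroup", MetricName.build("demo.counter"));
        System.out.println(listener.getOnCounterAddedCalledTimes()); // expect 1
    }
}
```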
*/ +public class FakedMetricRegistryListener implements MetricRegistryListener { + private int onGaugeAddedCalledTimes = 0; + private int onGaugeRemovedCalledTimes = 0; + private int onCounterAddedCalledTimes = 0; + private int onCounterRemovedCalledTimes = 0; + private int onHistogramAddedCalledTimes = 0; + private int onHistogramRemovedCalledTimes = 0; + private int onMeterAddedCalledTimes = 0; + private int onMeterRemovedCalledTimes = 0; + private int onTimerAddedCalledTimes = 0; + private int onTimerRemovedCalledTimes = 0; + private int onCompassAddedCalledTimes = 0; + private int onCompassRemovedCalledTimes = 0; + private int onFastCompassAddedCalledTimes = 0; + private int onFastCompassRemovedCalledTimes = 0; + + @Override + public void onGaugeAdded(MetricName name, Gauge gauge) { + onGaugeAddedCalledTimes++; + } + + @Override + public void onGaugeRemoved(MetricName name) { + onGaugeRemovedCalledTimes++; + } + + @Override + public void onCounterAdded(MetricName name, Counter counter) { + onCounterAddedCalledTimes++; + } + + @Override + public void onCounterRemoved(MetricName name) { + onCounterRemovedCalledTimes++; + } + + @Override + public void onHistogramAdded(MetricName name, Histogram histogram) { + onHistogramAddedCalledTimes++; + } + + @Override + public void onHistogramRemoved(MetricName name) { + onHistogramRemovedCalledTimes++; + } + + @Override + public void onMeterAdded(MetricName name, Meter meter) { + onMeterAddedCalledTimes++; + } + + @Override + public void onMeterRemoved(MetricName name) { + onMeterRemovedCalledTimes++; + } + + @Override + public void onTimerAdded(MetricName name, Timer timer) { + onTimerAddedCalledTimes++; + } + + @Override + public void onTimerRemoved(MetricName name) { + onTimerRemovedCalledTimes++; + } + + @Override + public void onCompassAdded(MetricName name, Compass compass) { + onCompassAddedCalledTimes++; + } + + @Override + public void onCompassRemoved(MetricName name) { + onCompassRemovedCalledTimes++; + } + + @Override + public void onFastCompassAdded(MetricName name, FastCompass compass) { + onFastCompassAddedCalledTimes++; + } + + @Override + public void onFastCompassRemoved(MetricName name) { + onFastCompassRemovedCalledTimes++; + } + + public int getOnGaugeAddedCalledTimes() { + return onGaugeAddedCalledTimes; + } + + public void setOnGaugeAddedCalledTimes(int onGaugeAddedCalledTimes) { + this.onGaugeAddedCalledTimes = onGaugeAddedCalledTimes; + } + + public int getOnCounterAddedCalledTimes() { + return onCounterAddedCalledTimes; + } + + public void setOnCounterAddedCalledTimes(int onCounterAddedCalledTimes) { + this.onCounterAddedCalledTimes = onCounterAddedCalledTimes; + } + + public int getOnGaugeRemovedCalledTimes() { + return onGaugeRemovedCalledTimes; + } + + public void setOnGaugeRemovedCalledTimes(int onGaugeRemovedCalledTimes) { + this.onGaugeRemovedCalledTimes = onGaugeRemovedCalledTimes; + } + + public int getOnCounterRemovedCalledTimes() { + return onCounterRemovedCalledTimes; + } + + public void setOnCounterRemovedCalledTimes(int onCounterRemovedCalledTimes) { + this.onCounterRemovedCalledTimes = onCounterRemovedCalledTimes; + } + + public int getOnHistogramAddedCalledTimes() { + return onHistogramAddedCalledTimes; + } + + public void setOnHistogramAddedCalledTimes(int onHistogramAddedCalledTimes) { + this.onHistogramAddedCalledTimes = onHistogramAddedCalledTimes; + } + + public int getOnHistogramRemovedCalledTimes() { + return onHistogramRemovedCalledTimes; + } + + public void setOnHistogramRemovedCalledTimes(int 
onHistogramRemovedCalledTimes) { + this.onHistogramRemovedCalledTimes = onHistogramRemovedCalledTimes; + } + + public int getOnMeterAddedCalledTimes() { + return onMeterAddedCalledTimes; + } + + public void setOnMeterAddedCalledTimes(int onMeterAddedCalledTimes) { + this.onMeterAddedCalledTimes = onMeterAddedCalledTimes; + } + + public int getOnMeterRemovedCalledTimes() { + return onMeterRemovedCalledTimes; + } + + public void setOnMeterRemovedCalledTimes(int onMeterRemovedCalledTimes) { + this.onMeterRemovedCalledTimes = onMeterRemovedCalledTimes; + } + + public int getOnTimerAddedCalledTimes() { + return onTimerAddedCalledTimes; + } + + public void setOnTimerAddedCalledTimes(int onTimerAddedCalledTimes) { + this.onTimerAddedCalledTimes = onTimerAddedCalledTimes; + } + + public int getOnTimerRemovedCalledTimes() { + return onTimerRemovedCalledTimes; + } + + public void setOnTimerRemovedCalledTimes(int onTimerRemovedCalledTimes) { + this.onTimerRemovedCalledTimes = onTimerRemovedCalledTimes; + } + + public int getOnCompassAddedCalledTimes() { + return onCompassAddedCalledTimes; + } + + public void setOnCompassAddedCalledTimes(int onCompassAddedCalledTimes) { + this.onCompassAddedCalledTimes = onCompassAddedCalledTimes; + } + + public int getOnCompassRemovedCalledTimes() { + return onCompassRemovedCalledTimes; + } + + public void setOnCompassRemovedCalledTimes(int onCompassRemovedCalledTimes) { + this.onCompassRemovedCalledTimes = onCompassRemovedCalledTimes; + } + + public int getOnFastCompassAddedCalledTimes() { + return onFastCompassAddedCalledTimes; + } + + public void setOnFastCompassAddedCalledTimes(int onFastCompassAddedCalledTimes) { + this.onFastCompassAddedCalledTimes = onFastCompassAddedCalledTimes; + } + + public int getOnFastCompassRemovedCalledTimes() { + return onFastCompassRemovedCalledTimes; + } + + public void setOnFastCompassRemovedCalledTimes(int onFastCompassRemovedCalledTimes) { + this.onFastCompassRemovedCalledTimes = onFastCompassRemovedCalledTimes; + } +} diff --git a/shuffle-metrics/src/test/java/com/alibaba/flink/shuffle/metrics/entry/FakedMetrics.java b/shuffle-metrics/src/test/java/com/alibaba/flink/shuffle/metrics/entry/FakedMetrics.java new file mode 100644 index 00000000..fef28887 --- /dev/null +++ b/shuffle-metrics/src/test/java/com/alibaba/flink/shuffle/metrics/entry/FakedMetrics.java @@ -0,0 +1,507 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package com.alibaba.flink.shuffle.metrics.entry;
+
+import com.alibaba.metrics.BucketCounter;
+import com.alibaba.metrics.Clock;
+import com.alibaba.metrics.Compass;
+import com.alibaba.metrics.Counter;
+import com.alibaba.metrics.FastCompass;
+import com.alibaba.metrics.Gauge;
+import com.alibaba.metrics.Histogram;
+import com.alibaba.metrics.Meter;
+import com.alibaba.metrics.Reservoir;
+import com.alibaba.metrics.Snapshot;
+import com.alibaba.metrics.Timer;
+
+import java.io.OutputStream;
+import java.util.Map;
+import java.util.concurrent.Callable;
+import java.util.concurrent.TimeUnit;
+
+/** Faked implementation for different metric types. */
+public class FakedMetrics {
+    /** A faked implementation for {@link Gauge}. */
+    public static class FakedGauge implements Gauge<Object> {
+        @Override
+        public Object getValue() {
+            return null;
+        }
+
+        @Override
+        public long lastUpdateTime() {
+            return 0;
+        }
+    }
+
+    /** A faked implementation for {@link Gauge}. */
+    public static class FakedStringGauge implements Gauge<String> {
+        @Override
+        public String getValue() {
+            return "abcde";
+        }
+
+        @Override
+        public long lastUpdateTime() {
+            return 0;
+        }
+    }
+
+    /** A faked implementation for {@link Gauge}. */
+    public static class FakedFloatGauge implements Gauge<Float> {
+        @Override
+        public Float getValue() {
+            return 1.0f;
+        }
+
+        @Override
+        public long lastUpdateTime() {
+            return 0;
+        }
+    }
+
+    /** A faked implementation for {@link Counter}. */
+    public static class FakedCounter implements Counter {
+        @Override
+        public void inc() {}
+
+        @Override
+        public void inc(long n) {}
+
+        @Override
+        public void dec() {}
+
+        @Override
+        public void dec(long n) {}
+
+        @Override
+        public long getCount() {
+            return 0;
+        }
+
+        @Override
+        public long lastUpdateTime() {
+            return 0;
+        }
+    }
+
+    /** A faked implementation for {@link Histogram}. */
+    public static class FakedHistogram implements Histogram {
+
+        @Override
+        public void update(int value) {}
+
+        @Override
+        public void update(long value) {}
+
+        @Override
+        public long getCount() {
+            return 0;
+        }
+
+        @Override
+        public long lastUpdateTime() {
+            return 0;
+        }
+
+        @Override
+        public Snapshot getSnapshot() {
+            return new FakedSnapshot();
+        }
+    }
+
+    /** A faked implementation for {@link Meter}. */
+    public static class FakedMeter implements Meter {
+        @Override
+        public void mark() {}
+
+        @Override
+        public void mark(long n) {}
+
+        @Override
+        public long getCount() {
+            return 0;
+        }
+
+        @Override
+        public Map<Long, Long> getInstantCount() {
+            return null;
+        }
+
+        @Override
+        public Map<Long, Long> getInstantCount(long startTime) {
+            return null;
+        }
+
+        @Override
+        public int getInstantCountInterval() {
+            return 0;
+        }
+
+        @Override
+        public double getFifteenMinuteRate() {
+            return 0;
+        }
+
+        @Override
+        public double getFiveMinuteRate() {
+            return 0;
+        }
+
+        @Override
+        public double getMeanRate() {
+            return 0;
+        }
+
+        @Override
+        public double getOneMinuteRate() {
+            return 0;
+        }
+
+        @Override
+        public long lastUpdateTime() {
+            return 0;
+        }
+    }
+
+    /** A faked implementation for {@link Timer}. */
+    public static class FakedTimer implements Timer {
+        @Override
+        public void update(long duration, TimeUnit unit) {}
+
+        @Override
+        public <T> T time(Callable<T> event) throws Exception {
+            return null;
+        }
+
+        @Override
+        public Context time() {
+            return null;
+        }
+
+        @Override
+        public long getCount() {
+            return 0;
+        }
+
+        @Override
+        public Map<Long, Long> getInstantCount() {
+            return null;
+        }
+
+        @Override
+        public Map<Long, Long> getInstantCount(long startTime) {
+            return null;
+        }
+
+        @Override
+        public int getInstantCountInterval() {
+            return 0;
+        }
+
+        @Override
+        public double getFifteenMinuteRate() {
+            return 0;
+        }
+
+        @Override
+        public double getFiveMinuteRate() {
+            return 0;
+        }
+
+        @Override
+        public double getMeanRate() {
+            return 0;
+        }
+
+        @Override
+        public double getOneMinuteRate() {
+            return 0;
+        }
+
+        @Override
+        public long lastUpdateTime() {
+            return 0;
+        }
+
+        @Override
+        public Snapshot getSnapshot() {
+            return null;
+        }
+    }
+
+    /** A faked implementation for {@link Compass}. */
+    public static class FakedCompass implements Compass {
+        @Override
+        public void update(long duration, TimeUnit unit) {}
+
+        @Override
+        public void update(
+                long duration, TimeUnit unit, boolean isSuccess, String errorCode, String addon) {}
+
+        @Override
+        public <T> T time(Callable<T> event) throws Exception {
+            return null;
+        }
+
+        @Override
+        public Context time() {
+            return null;
+        }
+
+        @Override
+        public Map<String, BucketCounter> getErrorCodeCounts() {
+            return null;
+        }
+
+        @Override
+        public double getSuccessRate() {
+            return 0;
+        }
+
+        @Override
+        public long getSuccessCount() {
+            return 0;
+        }
+
+        @Override
+        public BucketCounter getBucketSuccessCount() {
+            return null;
+        }
+
+        @Override
+        public Map<String, BucketCounter> getAddonCounts() {
+            return null;
+        }
+
+        @Override
+        public long getCount() {
+            return 0;
+        }
+
+        @Override
+        public Map<Long, Long> getInstantCount() {
+            return null;
+        }
+
+        @Override
+        public Map<Long, Long> getInstantCount(long startTime) {
+            return null;
+        }
+
+        @Override
+        public int getInstantCountInterval() {
+            return 0;
+        }
+
+        @Override
+        public double getFifteenMinuteRate() {
+            return 0;
+        }
+
+        @Override
+        public double getFiveMinuteRate() {
+            return 0;
+        }
+
+        @Override
+        public double getMeanRate() {
+            return 0;
+        }
+
+        @Override
+        public double getOneMinuteRate() {
+            return 0;
+        }
+
+        @Override
+        public long lastUpdateTime() {
+            return 0;
+        }
+
+        @Override
+        public Snapshot getSnapshot() {
+            return null;
+        }
+    }
+
+    /** A faked implementation for {@link FastCompass}. */
+    public static class FakedFastCompass implements FastCompass {
+        @Override
+        public void record(long duration, String subCategory) {}
+
+        @Override
+        public Map<String, Map<Long, Long>> getMethodCountPerCategory() {
+            return null;
+        }
+
+        @Override
+        public Map<String, Map<Long, Long>> getMethodCountPerCategory(long startTime) {
+            return null;
+        }
+
+        @Override
+        public Map<String, Map<Long, Long>> getMethodRtPerCategory() {
+            return null;
+        }
+
+        @Override
+        public Map<String, Map<Long, Long>> getMethodRtPerCategory(long startTime) {
+            return null;
+        }
+
+        @Override
+        public Map<String, Map<Long, Long>> getCountAndRtPerCategory() {
+            return null;
+        }
+
+        @Override
+        public Map<String, Map<Long, Long>> getCountAndRtPerCategory(long startTime) {
+            return null;
+        }
+
+        @Override
+        public int getBucketInterval() {
+            return 0;
+        }
+
+        @Override
+        public long lastUpdateTime() {
+            return 0;
+        }
+    }
+
+    /** A faked implementation for {@link Reservoir}.
*/ + public static class FakedReservoir implements Reservoir { + private Snapshot snapshot; + int updateCount = 0; + + @Override + public int size() { + return 0; + } + + @Override + public void update(long value) { + updateCount++; + } + + @Override + public Snapshot getSnapshot() { + return snapshot; + } + + public void setSnapshot(Snapshot snapshot) { + this.snapshot = snapshot; + } + + public void setUpdateCount(int updateCount) { + this.updateCount = updateCount; + } + + public int getUpdateCount() { + return updateCount; + } + } + + /** A faked implementation for {@link Clock}. */ + public static class FakedClock extends Clock { + long curTime = 0; + + @Override + public long getTick() { + return curTime; + } + + public void setCurTime(long curTime) { + this.curTime = curTime; + } + } + + /** A faked implementation for {@link Snapshot}. */ + public static class FakedSnapshot implements Snapshot { + @Override + public double getValue(double quantile) { + return 0; + } + + @Override + public long[] getValues() { + return new long[0]; + } + + @Override + public int size() { + return 0; + } + + @Override + public double getMedian() { + return 0; + } + + @Override + public double get75thPercentile() { + return 0; + } + + @Override + public double get95thPercentile() { + return 0; + } + + @Override + public double get98thPercentile() { + return 0; + } + + @Override + public double get99thPercentile() { + return 0; + } + + @Override + public double get999thPercentile() { + return 0; + } + + @Override + public long getMax() { + return 0; + } + + @Override + public double getMean() { + return 0; + } + + @Override + public long getMin() { + return 0; + } + + @Override + public double getStdDev() { + return 0; + } + + @Override + public void dump(OutputStream output) {} + } +} diff --git a/shuffle-metrics/src/test/java/com/alibaba/flink/shuffle/metrics/entry/MetricConfigTest.java b/shuffle-metrics/src/test/java/com/alibaba/flink/shuffle/metrics/entry/MetricConfigTest.java new file mode 100644 index 00000000..40fcd6b0 --- /dev/null +++ b/shuffle-metrics/src/test/java/com/alibaba/flink/shuffle/metrics/entry/MetricConfigTest.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.metrics.entry; + +import com.alibaba.metrics.Counter; +import com.alibaba.metrics.Gauge; +import com.alibaba.metrics.Metric; +import com.alibaba.metrics.MetricFilter; +import com.alibaba.metrics.MetricLevel; +import com.alibaba.metrics.MetricManager; +import com.alibaba.metrics.MetricName; +import com.alibaba.metrics.PersistentGauge; +import com.alibaba.metrics.integrate.ConfigFields; +import com.alibaba.metrics.integrate.MetricsIntegrateUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.Map; +import java.util.Properties; + +/** Tests for configuring metrics. */ +public class MetricConfigTest { + @Before + public void setUp() { + MetricManager.getIMetricManager().clear(); + } + + @Test(expected = IllegalArgumentException.class) + public void testEnumValueOf() { + MetricLevel.valueOf("test"); + } + + @Test + public void testEnabled() { + Properties properties = new Properties(); + Assert.assertTrue(MetricsIntegrateUtils.isEnabled(properties, "test")); + + properties.put("com.alibaba.metrics.tomcat.thread.enable", "false"); + Assert.assertFalse( + MetricsIntegrateUtils.isEnabled( + properties, "com.alibaba.metrics.tomcat.thread.enable")); + + properties.put("com.alibaba.metrics.tomcat.thread.enable", "true"); + Assert.assertTrue( + MetricsIntegrateUtils.isEnabled( + properties, "com.alibaba.metrics.tomcat.thread.enable")); + } + + @Test + public void testConfigMetricLevel() { + Properties properties = new Properties(); + properties.put("com.alibaba.metrics.jvm.class_load.level", "CRITICAL"); + MetricsIntegrateUtils.registerJvmMetrics(properties); + Map gauges = + MetricManager.getIMetricManager() + .getGauges( + "jvm", + new MetricFilter() { + @Override + public boolean matches(MetricName name, Metric metric) { + return name.getKey().equals("jvm.class_load.loaded"); + } + }); + Assert.assertEquals(1, gauges.size()); + Assert.assertEquals( + MetricLevel.CRITICAL, + gauges.entrySet().iterator().next().getKey().getMetricLevel()); + } + + @Test + public void testCleaner() { + Counter c = MetricManager.getCounter("cleaner", MetricName.build("test.cleaner")); + c.inc(); + Properties properties = new Properties(); + properties.put(ConfigFields.METRICS_CLEANER_ENABLE, "true"); + properties.put(ConfigFields.METRICS_CLEANER_KEEP_INTERVAL, "1"); + properties.put(ConfigFields.METRICS_CLEANER_DELAY, "1"); + MetricsIntegrateUtils.startMetricsCleaner(properties); + try { + Thread.sleep(2000); + Assert.assertEquals( + 0, + MetricManager.getIMetricManager() + .getCounters("cleaner", MetricFilter.ALL) + .size()); + } catch (InterruptedException e) { + e.printStackTrace(); + } + MetricsIntegrateUtils.stopMetricsCleaner(); + } + + @Test + public void testDisableCleaner() { + Counter c = MetricManager.getCounter("cleaner2", MetricName.build("test.cleaner2")); + c.inc(); + Properties properties = new Properties(); + properties.put(ConfigFields.METRICS_CLEANER_ENABLE, "false"); + MetricsIntegrateUtils.startMetricsCleaner(properties); + try { + Thread.sleep(2000); + Assert.assertEquals( + 1, + MetricManager.getIMetricManager() + .getCounters("cleaner2", MetricFilter.ALL) + .size()); + } catch (InterruptedException e) { + e.printStackTrace(); + } + MetricsIntegrateUtils.stopMetricsCleaner(); + } + + @Test + public void testCleanPersistentGauge() { + PersistentGauge g = + new PersistentGauge() { + @Override + public Integer getValue() { + return 1; + } + }; + MetricManager.register("ppp", MetricName.build("ppp1"), g); + 
Properties properties = new Properties(); + properties.put(ConfigFields.METRICS_CLEANER_ENABLE, "true"); + properties.put(ConfigFields.METRICS_CLEANER_KEEP_INTERVAL, "1"); + properties.put(ConfigFields.METRICS_CLEANER_DELAY, "1"); + MetricsIntegrateUtils.startMetricsCleaner(properties); + try { + Thread.sleep(2000); + Assert.assertEquals( + 1, MetricManager.getIMetricManager().getGauges("ppp", MetricFilter.ALL).size()); + } catch (InterruptedException e) { + e.printStackTrace(); + } + MetricsIntegrateUtils.stopMetricsCleaner(); + } + + @Test + public void testDoNotBeCleaned() { + MetricManager.getCounter("cleaner3", MetricName.build("test.cleaner")); + Properties properties = new Properties(); + properties.put(ConfigFields.METRICS_CLEANER_ENABLE, "true"); + properties.put(ConfigFields.METRICS_CLEANER_KEEP_INTERVAL, "10"); + properties.put(ConfigFields.METRICS_CLEANER_DELAY, "1"); + MetricsIntegrateUtils.startMetricsCleaner(properties); + try { + Thread.sleep(2000); + Assert.assertEquals( + "Because keep interval is 10 seconds, the counter should not be cleaned.", + 1, + MetricManager.getIMetricManager() + .getCounters("cleaner3", MetricFilter.ALL) + .size()); + } catch (InterruptedException e) { + e.printStackTrace(); + } + MetricsIntegrateUtils.stopMetricsCleaner(); + } + + @Test + public void testDisableFromSystemProperty() { + System.setProperty(ConfigFields.METRICS_CLEANER_ENABLE, "false"); + Assert.assertFalse( + MetricsIntegrateUtils.isEnabled(null, ConfigFields.METRICS_CLEANER_ENABLE)); + System.setProperty(ConfigFields.METRICS_CLEANER_ENABLE, "true"); + } +} diff --git a/shuffle-metrics/src/test/java/com/alibaba/flink/shuffle/metrics/entry/MetricUtilsTest.java b/shuffle-metrics/src/test/java/com/alibaba/flink/shuffle/metrics/entry/MetricUtilsTest.java new file mode 100644 index 00000000..db96b663 --- /dev/null +++ b/shuffle-metrics/src/test/java/com/alibaba/flink/shuffle/metrics/entry/MetricUtilsTest.java @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.metrics.entry; + +import com.alibaba.metrics.CachedGauge; +import com.alibaba.metrics.Clock; +import com.alibaba.metrics.Compass; +import com.alibaba.metrics.CompassImpl; +import com.alibaba.metrics.Counter; +import com.alibaba.metrics.CounterImpl; +import com.alibaba.metrics.FastCompass; +import com.alibaba.metrics.FastCompassImpl; +import com.alibaba.metrics.Gauge; +import com.alibaba.metrics.Histogram; +import com.alibaba.metrics.HistogramImpl; +import com.alibaba.metrics.ManualClock; +import com.alibaba.metrics.Meter; +import com.alibaba.metrics.MeterImpl; +import com.alibaba.metrics.MetricFilter; +import com.alibaba.metrics.MetricManager; +import com.alibaba.metrics.MetricName; +import com.alibaba.metrics.Reservoir; +import com.alibaba.metrics.ReservoirType; +import com.alibaba.metrics.Snapshot; +import com.alibaba.metrics.Timer; +import com.alibaba.metrics.TimerImpl; +import org.junit.Before; +import org.junit.Test; + +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +/** + * Tests for {@link MetricUtils}. This class is used to test metric values. Different metric values + * will be checked after updated in this class. + */ +public class MetricUtilsTest { + + private static final String TEST_GROUP = "test.group"; + + // Test for counter + private final Counter counter = new CounterImpl(); + + // Test for Timer + private final Reservoir reservoir = new FakedMetrics.FakedReservoir(); + + private final Clock clock = + new Clock() { + // a mock clock that increments its ticker by 50 ms per call + private long val = 0; + + @Override + public long getTick() { + return val += 50000000; + } + }; + private final Timer timer = new TimerImpl(reservoir, clock, 60); + + // Test for Histogram + private final FakedMetrics.FakedReservoir histogramReservoir = + new FakedMetrics.FakedReservoir(); + private final Histogram histogram = + new HistogramImpl(histogramReservoir, 10, 2, Clock.defaultClock()); + + // Test for Gauge + private final AtomicInteger value = new AtomicInteger(0); + private final Gauge gauge = + new CachedGauge(100, TimeUnit.MILLISECONDS) { + @Override + protected Integer loadValue() { + return value.incrementAndGet(); + } + }; + + // Test for Meter + private final FakedMetrics.FakedClock timerClock = new FakedMetrics.FakedClock(); + + private final Meter meter = new MeterImpl(timerClock); + + // Test for Compass + private final Reservoir compassReservoir = new FakedMetrics.FakedReservoir(); + private final Clock compassClock = + new Clock() { + // a mock clock that increments its ticker by 50 ms per call + private long val = 0; + + @Override + public long getTick() { + return val += 50000000; + } + }; + + private CompassImpl compass; + + @Before + public void setUp() throws Exception { + compass = new CompassImpl(ReservoirType.BUCKET, compassClock, 10, 60, 100, 2); + compass.setReservoir(compassReservoir); + } + + @Test + public void testCounterValueUpdate() { + final Counter counter = MetricUtils.getCounter(TEST_GROUP, "testCounter"); + assertEquals(0, counter.getCount()); + counter.inc(); + assertEquals(1, counter.getCount()); + counter.dec(); + counter.dec(); + assertEquals(-1, counter.getCount()); + counter.inc(); + assertEquals(0, counter.getCount()); + + // increase more than 1 + counter.inc(20); + assertEquals(20, counter.getCount()); + // decrease more 
than 1 + counter.dec(40); + assertEquals(-20, counter.getCount()); + assertEquals( + MetricManager.getIMetricManager() + .getCounter(TEST_GROUP, MetricName.build("testCounter")) + .getCount(), + counter.getCount()); + } + + @Test + public void testTimerInitValues() { + assertTrue(MetricManager.getIMetricManager().listMetricGroups().contains(TEST_GROUP)); + assertEquals(0, timer.getCount()); + + assertEquals(0.0, timer.getMeanRate(), 0.01); + + assertEquals(0.0, timer.getOneMinuteRate(), 0.001); + + assertEquals(0.0, timer.getOneMinuteRate(), 0.001); + + assertEquals(0.0, timer.getFifteenMinuteRate(), 0.001); + } + + @Test + public void testTimerValueUpdate() throws Exception { + MetricUtils.registerMetric(TEST_GROUP, "testTimer", timer); + // Test for ignore negative values + timer.update(-1, TimeUnit.SECONDS); + assertEquals(0, timer.getCount()); + + // Test for update values + assertEquals(0, timer.getCount()); + timer.update(1, TimeUnit.SECONDS); + assertEquals(1, timer.getCount()); + timer.time().stop(); + assertEquals(2, timer.getCount()); + + assertEquals( + MetricManager.getIMetricManager() + .getTimer(TEST_GROUP, MetricName.build("testTimer")) + .getCount(), + timer.getCount()); + } + + @Test + public void testTimerTimesCallableInstances() throws Exception { + final String value = timer.time(() -> "one"); + + assertEquals(1, timer.getCount()); + assertEquals("one", value); + } + + @Test + public void testHistogramValueUpdate() { + assertEquals(0, histogram.getCount()); + histogram.update(1); + assertEquals(1, histogram.getCount()); + + // test snapshot + final Snapshot snapshot = new FakedMetrics.FakedSnapshot(); + histogramReservoir.setSnapshot(snapshot); + assertEquals(histogram.getSnapshot(), snapshot); + + histogram.update(1); + assertEquals(2, histogramReservoir.getUpdateCount()); + + // update more values + ManualClock clock = new ManualClock(); + Histogram histogram = new HistogramImpl(ReservoirType.BUCKET, 5, 2, clock); + MetricUtils.registerMetric(TEST_GROUP, "testHistogram", histogram); + clock.addSeconds(10); + histogram.update(10); + histogram.update(20); + Snapshot snapshot1 = histogram.getSnapshot(); + assertEquals(15, snapshot1.getMean(), 0.001); + clock.addSeconds(6); + histogram.update(200); + histogram.update(400); + clock.addSeconds(5); + Snapshot snapshot2 = histogram.getSnapshot(); + assertEquals(300, snapshot2.getMean(), 0.001); + assertEquals( + MetricManager.getIMetricManager() + .getHistogram(TEST_GROUP, MetricName.build("testHistogram")) + .getCount(), + histogram.getCount()); + } + + @Test + public void testGaugeValueUpdate() throws Exception { + MetricUtils.registerMetric(TEST_GROUP, "testGauge", gauge); + assertEquals(1, gauge.getValue().intValue()); + long lastUpdateTime = gauge.lastUpdateTime(); + assertEquals(1, gauge.getValue().intValue()); + assertEquals(lastUpdateTime, gauge.lastUpdateTime()); + Thread.sleep(150); + assertEquals(2, gauge.getValue().intValue()); + assertEquals(2, gauge.getValue().intValue()); + assertEquals( + MetricManager.getIMetricManager() + .getGauges(TEST_GROUP, MetricFilter.ALL) + .get(MetricName.build("testGauge")) + .getValue(), + gauge.getValue()); + } + + @Test + public void testMeterInitValue() { + // Test init value + assertEquals(0, meter.getCount()); + assertEquals(0.0, meter.getMeanRate(), 0.001); + assertEquals(0.0, meter.getOneMinuteRate(), 0.001); + assertEquals(0.0, meter.getFiveMinuteRate(), 0.001); + assertEquals(0.0, meter.getFifteenMinuteRate(), 0.001); + } + + @Test + public void testMeterValueUpdate() { + 
MetricUtils.registerMetric(TEST_GROUP, "testMeter", meter); + // Test mark + meter.mark(); + timerClock.setCurTime(TimeUnit.SECONDS.toNanos(10)); + meter.mark(2); + assertEquals(0.3, meter.getMeanRate(), 0.001); + assertEquals(0.1840, meter.getOneMinuteRate(), 0.001); + assertEquals(0.1966, meter.getFiveMinuteRate(), 0.001); + assertEquals(0.1988, meter.getFifteenMinuteRate(), 0.001); + assertNotNull( + MetricManager.getIMetricManager() + .getMeter(TEST_GROUP, MetricName.build("testMeter"))); + } + + @Test + public void testCompassValueUpdate() { + Compass.Context context = compass.time(); + context.markAddon("hit"); + context.markAddon("loss"); + context.markAddon("goodbye"); + context.stop(); + + assertEquals(1, compass.getCount()); + + assertEquals(1, compass.getAddonCounts().get("hit").getCount()); + assertEquals(1, compass.getAddonCounts().get("loss").getCount()); + assertEquals(compass.getAddonCounts().get("goodbye"), null); + + // Test multiple update values + ManualClock clock = new ManualClock(); + Compass compass = new CompassImpl(ReservoirType.BUCKET, clock, 10, 60, 10, 5); + MetricUtils.registerMetric(TEST_GROUP, "testCompass", compass); + compass.update(10, TimeUnit.MILLISECONDS, true, null, "hit"); + compass.update(15, TimeUnit.MILLISECONDS, true, null, null); + compass.update(20, TimeUnit.MILLISECONDS, false, "error1", null); + clock.addSeconds(60); + assertEquals(compass.getCount(), 3); + assertEquals(compass.getSuccessCount(), 2); + assertEquals(1, compass.getAddonCounts().get("hit").getBucketCounts().get(0L).intValue()); + assertEquals(compass.getSnapshot().getMean(), TimeUnit.MILLISECONDS.toNanos(15), 0.001); + + compass.update(10, TimeUnit.MILLISECONDS, true, null, "hit"); + compass.update(15, TimeUnit.MILLISECONDS, true, null, null); + compass.update(20, TimeUnit.MILLISECONDS, false, "error1", null); + + clock.addSeconds(60); + assertEquals(6, compass.getCount()); + assertEquals(3, compass.getInstantCount().get(60000L).intValue()); + assertEquals(4, compass.getSuccessCount()); + assertEquals(2, compass.getBucketSuccessCount().getBucketCounts().get(60000L).intValue()); + assertEquals( + 1, compass.getAddonCounts().get("hit").getBucketCounts().get(60000L).intValue()); + assertEquals(compass.getSnapshot().getMean(), TimeUnit.MILLISECONDS.toNanos(15), 0.001); + assertEquals( + MetricManager.getIMetricManager() + .getCompass(TEST_GROUP, MetricName.build("testCompass")) + .getSuccessCount(), + compass.getSuccessCount()); + } + + @Test + public void testFastCompassValueUpdate() { + ManualClock clock = new ManualClock(); + FastCompass fastCompass = new FastCompassImpl(60, 10, clock, 10); + MetricUtils.registerMetric(TEST_GROUP, "testFastCompass", fastCompass); + fastCompass.record(10, "success"); + fastCompass.record(20, "error"); + fastCompass.record(15, "success"); + clock.addSeconds(60); + // verify count + assertTrue(fastCompass.getMethodCountPerCategory().containsKey("success")); + assertEquals( + 2, fastCompass.getMethodCountPerCategory(0L).get("success").get(0L).intValue()); + assertTrue(fastCompass.getMethodCountPerCategory().containsKey("error")); + assertEquals(1, fastCompass.getMethodCountPerCategory(0L).get("error").get(0L).intValue()); + // verify rt + assertTrue(fastCompass.getMethodRtPerCategory().containsKey("success")); + assertEquals(25, fastCompass.getMethodRtPerCategory(0L).get("success").get(0L).intValue()); + assertTrue(fastCompass.getMethodRtPerCategory().containsKey("error")); + assertEquals(20, 
fastCompass.getMethodRtPerCategory(0L).get("error").get(0L).intValue()); + // total count + long totalCount = + fastCompass.getMethodCountPerCategory(0L).get("success").get(0L) + + fastCompass.getMethodCountPerCategory(0L).get("error").get(0L); + assertEquals(3, totalCount); + // average rt + long avgRt = + (fastCompass.getMethodRtPerCategory(0L).get("success").get(0L) + + fastCompass.getMethodRtPerCategory(0L).get("error").get(0L)) + / totalCount; + assertEquals(15, avgRt); + // verify count and rt + assertTrue(fastCompass.getCountAndRtPerCategory().containsKey("success")); + assertEquals( + (2L << 38) + 25, + fastCompass.getCountAndRtPerCategory(0L).get("success").get(0L).longValue()); + assertTrue(fastCompass.getCountAndRtPerCategory().containsKey("error")); + assertEquals( + (1L << 38) + 20, + fastCompass.getCountAndRtPerCategory(0L).get("error").get(0L).longValue()); + assertEquals( + MetricManager.getIMetricManager() + .getFastCompass(TEST_GROUP, MetricName.build("testFastCompass")) + .getCountAndRtPerCategory() + .get("success") + .get(0L) + .longValue(), + fastCompass.getCountAndRtPerCategory(0L).get("success").get(0L).longValue()); + } +} diff --git a/shuffle-metrics/src/test/java/com/alibaba/flink/shuffle/metrics/entry/MetricsRegistryTest.java b/shuffle-metrics/src/test/java/com/alibaba/flink/shuffle/metrics/entry/MetricsRegistryTest.java new file mode 100644 index 00000000..b80ebd70 --- /dev/null +++ b/shuffle-metrics/src/test/java/com/alibaba/flink/shuffle/metrics/entry/MetricsRegistryTest.java @@ -0,0 +1,485 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.metrics.entry; + +import com.alibaba.metrics.BucketCounter; +import com.alibaba.metrics.BucketCounterImpl; +import com.alibaba.metrics.Compass; +import com.alibaba.metrics.CompassImpl; +import com.alibaba.metrics.Counter; +import com.alibaba.metrics.FastCompass; +import com.alibaba.metrics.FastCompassImpl; +import com.alibaba.metrics.Gauge; +import com.alibaba.metrics.Histogram; +import com.alibaba.metrics.HistogramImpl; +import com.alibaba.metrics.ManualClock; +import com.alibaba.metrics.Meter; +import com.alibaba.metrics.MeterImpl; +import com.alibaba.metrics.MetricName; +import com.alibaba.metrics.MetricRegistry; +import com.alibaba.metrics.MetricRegistryImpl; +import com.alibaba.metrics.MetricRegistryListener; +import com.alibaba.metrics.ReservoirType; +import com.alibaba.metrics.Timer; +import com.alibaba.metrics.TimerImpl; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.concurrent.TimeUnit; + +import static com.alibaba.metrics.MetricRegistry.name; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** Tests for registering metrics. 
*/ +public class MetricsRegistryTest { + private static final MetricName TIMER2 = MetricName.build("timer"); + private static final MetricName METER2 = MetricName.build("meter"); + private static final MetricName HISTOGRAM2 = MetricName.build("histogram"); + private static final MetricName COUNTER = MetricName.build("counter2"); + private static final MetricName GAUGE = MetricName.build("gauge"); + private static final MetricName GAUGE2 = MetricName.build("gauge2"); + private static final MetricName THING = MetricName.build("something"); + private static final MetricName COMPASS = MetricName.build("compass"); + private static final MetricName FAST_COMPASS = MetricName.build("fast.compass"); + + private final FakedMetricRegistryListener listener = new FakedMetricRegistryListener(); + private final MetricRegistry registry = new MetricRegistryImpl(10); + + private final Gauge gauge = new FakedMetrics.FakedGauge(); + private final Counter counter = new FakedMetrics.FakedCounter(); + private final Histogram histogram = new FakedMetrics.FakedHistogram(); + private final Meter meter = new FakedMetrics.FakedMeter(); + private final Timer timer = new FakedMetrics.FakedTimer(); + private final Compass compass = new FakedMetrics.FakedCompass(); + private final FastCompass fastCompass = new FakedMetrics.FakedFastCompass(); + + @Before + public void setUp() throws Exception { + registry.addListener(listener); + } + + @Test + public void registeringAGaugeTriggerNotification() throws Exception { + assertEquals(registry.register(THING, gauge), gauge); + + checkOnGaugeAddedCalledTimes(1, 0); + } + + @Test + public void removingAGaugeTriggerNotification() throws Exception { + registry.register(THING, gauge); + + assertTrue(registry.remove(THING)); + + checkOnGaugeRemovedCalledTimes(1, 0); + } + + @Test + public void accessingACounterRegistersAndReusesTheCounter() throws Exception { + final Counter counter1 = registry.counter(THING); + final Counter counter2 = registry.counter(THING); + + assertTrue(counter1 == counter2); + + checkOnCounterAddedCalledTimes(1, 0); + } + + @Test + public void removingACounterTriggerNotification() throws Exception { + registry.register(THING, counter); + + assertTrue(registry.remove(THING)); + + checkOnCounterRemovedCalledTimes(1, 0); + } + + @Test + public void registeringAHistogramTriggerNotification() throws Exception { + assertEquals(registry.register(THING, histogram), histogram); + + checkOnHistogramAddedCalledTimes(1, 0); + } + + @Test + public void accessingAHistogramRegistersAndReusesIt() throws Exception { + final Histogram histogram1 = registry.histogram(THING); + final Histogram histogram2 = registry.histogram(THING); + + assertEquals(histogram1, histogram2); + + checkOnHistogramAddedCalledTimes(1, 0); + } + + @Test + public void removingAHistogramTriggerNotification() throws Exception { + registry.register(THING, histogram); + + assertTrue(registry.remove(THING)); + + checkOnHistogramRemovedCalledTimes(1, 0); + } + + @Test + public void registeringAMeterTriggerNotification() throws Exception { + assertEquals(registry.register(THING, meter), meter); + + checkOnMeterAddedCalledTimes(1, 0); + } + + @Test + public void accessingAMeterRegistersAndReusesIt() throws Exception { + final Meter meter1 = registry.meter(THING); + final Meter meter2 = registry.meter(THING); + + assertEquals(meter1, meter2); + + checkOnMeterAddedCalledTimes(1, 0); + } + + @Test + public void removingAMeterTriggerNotification() throws Exception { + registry.register(THING, meter); + + 
assertTrue(registry.remove(THING)); + + checkOnMeterRemovedCalledTimes(1, 0); + } + + @Test + public void registeringATimerTriggerNotification() throws Exception { + assertEquals(registry.register(THING, timer), timer); + + checkOnTimerAddedCalledTimes(1, 0); + } + + @Test + public void accessingATimerRegistersAndReusesIt() throws Exception { + final Timer timer1 = registry.timer(THING); + final Timer timer2 = registry.timer(THING); + + assertEquals(timer1, timer2); + + checkOnTimerAddedCalledTimes(1, 0); + } + + @Test + public void removingATimerTriggerNotification() throws Exception { + registry.register(THING, timer); + + assertTrue(registry.remove(THING)); + + checkOnTimerRemovedCalledTimes(1, 0); + } + + @Test + public void registeringACompassTriggerNotification() throws Exception { + assertEquals(registry.register(THING, compass), compass); + + checkOnCompassAddedCalledTimes(1, 0); + } + + @Test + public void accessingACompassRegistersAndReusesIt() throws Exception { + final Compass compass1 = registry.compass(THING); + final Compass compass2 = registry.compass(THING); + + assertEquals(compass1, compass2); + + checkOnCompassAddedCalledTimes(1, 0); + } + + @Test + public void removingACompassTriggerNotification() throws Exception { + registry.register(THING, compass); + + assertTrue(registry.remove(THING)); + + checkOnCompassRemovedCalledTimes(1, 0); + } + + @Test + public void registeringAFastCompassTriggerNotification() throws Exception { + assertEquals(registry.register(THING, fastCompass), fastCompass); + + checkOnFastCompassAddedCalledTimes(1, 0); + } + + @Test + public void accessingAFastCompassRegistersAndReusesIt() throws Exception { + final FastCompass compass1 = registry.fastCompass(THING); + final FastCompass compass2 = registry.fastCompass(THING); + + assertEquals(compass1, compass2); + + checkOnFastCompassAddedCalledTimes(1, 0); + } + + @Test + public void removingAFastCompassTriggerNotification() throws Exception { + registry.register(THING, fastCompass); + + assertTrue(registry.remove(THING)); + + checkOnFastCompassRemovedCalledTimes(1, 0); + } + + @Test + public void addingAListenerWithExistingMetricsCatchesItUp() throws Exception { + registry.register(GAUGE2, gauge); + registry.register(COUNTER, counter); + registry.register(HISTOGRAM2, histogram); + registry.register(METER2, meter); + registry.register(TIMER2, timer); + registry.register(COMPASS, compass); + registry.register(FAST_COMPASS, fastCompass); + + final MetricRegistryListener other = new FakedMetricRegistryListener(); + registry.addListener(other); + + checkOnGaugeAddedCalledTimes(1, 0); + checkOnCounterAddedCalledTimes(1, 0); + checkOnHistogramAddedCalledTimes(1, 0); + checkOnMeterAddedCalledTimes(1, 0); + checkOnTimerAddedCalledTimes(1, 0); + checkOnCompassAddedCalledTimes(1, 0); + checkOnFastCompassAddedCalledTimes(1, 0); + } + + @Test + public void aRemovedListenerDoesNotReceiveUpdates() throws Exception { + registry.register(GAUGE, gauge); + checkOnGaugeAddedCalledTimes(1, 0); + + registry.removeListener(listener); + registry.register(GAUGE2, gauge); + checkOnGaugeAddedCalledTimes(0, 0); + } + + private void checkOnGaugeAddedCalledTimes(int expect, int resetTo) { + assertEquals(expect, listener.getOnGaugeAddedCalledTimes()); + listener.setOnGaugeAddedCalledTimes(resetTo); + assertEquals(resetTo, listener.getOnGaugeAddedCalledTimes()); + } + + private void checkOnGaugeRemovedCalledTimes(int expect, int resetTo) { + assertEquals(expect, listener.getOnGaugeRemovedCalledTimes()); + 
listener.setOnGaugeRemovedCalledTimes(resetTo); + assertEquals(resetTo, listener.getOnGaugeRemovedCalledTimes()); + } + + private void checkOnCounterAddedCalledTimes(int expect, int resetTo) { + assertEquals(expect, listener.getOnCounterAddedCalledTimes()); + listener.setOnCounterAddedCalledTimes(resetTo); + assertEquals(resetTo, listener.getOnCounterAddedCalledTimes()); + } + + private void checkOnCounterRemovedCalledTimes(int expect, int resetTo) { + assertEquals(expect, listener.getOnCounterRemovedCalledTimes()); + listener.setOnCounterRemovedCalledTimes(resetTo); + assertEquals(resetTo, listener.getOnCounterRemovedCalledTimes()); + } + + private void checkOnHistogramAddedCalledTimes(int expect, int resetTo) { + assertEquals(expect, listener.getOnHistogramAddedCalledTimes()); + listener.setOnHistogramAddedCalledTimes(resetTo); + assertEquals(resetTo, listener.getOnHistogramAddedCalledTimes()); + } + + private void checkOnHistogramRemovedCalledTimes(int expect, int resetTo) { + assertEquals(expect, listener.getOnHistogramRemovedCalledTimes()); + listener.setOnHistogramRemovedCalledTimes(resetTo); + assertEquals(resetTo, listener.getOnHistogramRemovedCalledTimes()); + } + + private void checkOnMeterAddedCalledTimes(int expect, int resetTo) { + assertEquals(expect, listener.getOnMeterAddedCalledTimes()); + listener.setOnMeterAddedCalledTimes(resetTo); + assertEquals(resetTo, listener.getOnMeterAddedCalledTimes()); + } + + private void checkOnMeterRemovedCalledTimes(int expect, int resetTo) { + assertEquals(expect, listener.getOnMeterRemovedCalledTimes()); + listener.setOnMeterRemovedCalledTimes(resetTo); + assertEquals(resetTo, listener.getOnMeterRemovedCalledTimes()); + } + + private void checkOnTimerAddedCalledTimes(int expect, int resetTo) { + assertEquals(expect, listener.getOnTimerAddedCalledTimes()); + listener.setOnTimerAddedCalledTimes(resetTo); + assertEquals(resetTo, listener.getOnTimerAddedCalledTimes()); + } + + private void checkOnTimerRemovedCalledTimes(int expect, int resetTo) { + assertEquals(expect, listener.getOnTimerRemovedCalledTimes()); + listener.setOnTimerRemovedCalledTimes(resetTo); + assertEquals(resetTo, listener.getOnTimerRemovedCalledTimes()); + } + + private void checkOnCompassAddedCalledTimes(int expect, int resetTo) { + assertEquals(expect, listener.getOnCompassAddedCalledTimes()); + listener.setOnCompassAddedCalledTimes(resetTo); + assertEquals(resetTo, listener.getOnCompassAddedCalledTimes()); + } + + private void checkOnCompassRemovedCalledTimes(int expect, int resetTo) { + assertEquals(expect, listener.getOnCompassRemovedCalledTimes()); + listener.setOnCompassRemovedCalledTimes(resetTo); + assertEquals(resetTo, listener.getOnCompassRemovedCalledTimes()); + } + + private void checkOnFastCompassAddedCalledTimes(int expect, int resetTo) { + assertEquals(expect, listener.getOnFastCompassAddedCalledTimes()); + listener.setOnFastCompassAddedCalledTimes(resetTo); + assertEquals(resetTo, listener.getOnFastCompassAddedCalledTimes()); + } + + private void checkOnFastCompassRemovedCalledTimes(int expect, int resetTo) { + assertEquals(expect, listener.getOnFastCompassRemovedCalledTimes()); + listener.setOnFastCompassRemovedCalledTimes(resetTo); + assertEquals(resetTo, listener.getOnFastCompassRemovedCalledTimes()); + } + + @Test + public void concatenatesStringsToFormADottedName() throws Exception { + assertEquals(name("one", "two", "three"), MetricName.build("one.two.three")); + } + + @Test + @SuppressWarnings("NullArgumentToVariableArgMethod") + public void 
elidesNullValuesFromNamesWhenOnlyOneNullPassedIn() throws Exception { + assertEquals(name("one", (String) null), MetricName.build("one")); + } + + @Test + public void elidesNullValuesFromNamesWhenManyNullsPassedIn() throws Exception { + assertEquals(name("one", null, null), MetricName.build("one")); + } + + @Test + public void elidesNullValuesFromNamesWhenNullAndNotNullPassedIn() throws Exception { + assertEquals(name("one", null, "three"), MetricName.build("one.three")); + } + + @Test + public void elidesEmptyStringsFromNames() throws Exception { + assertEquals(name("one", "", "three"), MetricName.build("one.three")); + } + + @Test + public void concatenatesClassesWithoutCanonicalNamesWithStrings() throws Exception { + final Gauge g = + new Gauge() { + @Override + public String getValue() { + return null; + } + + @Override + public long lastUpdateTime() { + return 0; + } + }; + + assertEquals( + name(g.getClass(), "one", "two"), + MetricName.build(g.getClass().getName() + ".one.two")); + } + + @Test + public void testMaxMetricCount() { + MetricRegistry registry = new MetricRegistryImpl(10); + for (int i = 0; i < 20; i++) { + registry.counter(MetricName.build("counter-" + i)); + } + Assert.assertEquals(10, registry.getCounters().keySet().size()); + + registry = new MetricRegistryImpl(10); + for (int i = 0; i < 20; i++) { + registry.meter(MetricName.build("meter-" + i)); + } + Assert.assertEquals(10, registry.getMeters().keySet().size()); + + registry = new MetricRegistryImpl(10); + for (int i = 0; i < 20; i++) { + registry.histogram(MetricName.build("histogram-" + i)); + } + Assert.assertEquals(10, registry.getHistograms().keySet().size()); + + registry = new MetricRegistryImpl(10); + for (int i = 0; i < 20; i++) { + registry.timer(MetricName.build("timer-" + i)); + } + Assert.assertEquals(10, registry.getTimers().keySet().size()); + + registry = new MetricRegistryImpl(10); + for (int i = 0; i < 20; i++) { + registry.compass(MetricName.build("compass-" + i)); + } + Assert.assertEquals(10, registry.getCompasses().keySet().size()); + } + + @Test + public void testIllegalMetricKey() { + Assert.assertNotNull(registry.counter(MetricName.build("[aaa]"))); + } + + @Test + public void testIllegalMetricTagValue() { + registry.counter(MetricName.build("aaa").tagged("bbb", "[ccc]")); + } + + @Test + public void testLastUpdatedTime() { + MetricRegistry registry = new MetricRegistryImpl(); + ManualClock clock = new ManualClock(); + BucketCounter c1 = new BucketCounterImpl(10, 10, clock, false); + Meter m1 = new MeterImpl(clock, 10); + Timer t1 = new TimerImpl(ReservoirType.BUCKET, clock, 10); + Histogram h1 = new HistogramImpl(ReservoirType.BUCKET, 10, 10, clock); + Compass comp1 = new CompassImpl(ReservoirType.BUCKET, clock, 10, 10, 10, 10); + FastCompass fc1 = new FastCompassImpl(10, 10, clock, 10); + registry.register("a", c1); + registry.register("b", m1); + registry.register("c", t1); + registry.register("d", h1); + registry.register("e", comp1); + registry.register("f", fc1); + clock.addSeconds(10); + c1.update(); + clock.addSeconds(10); + Assert.assertEquals(10000L, registry.lastUpdateTime()); + m1.mark(); + clock.addSeconds(10); + Assert.assertEquals(20000L, registry.lastUpdateTime()); + t1.update(1, TimeUnit.SECONDS); + clock.addSeconds(10); + Assert.assertEquals(30000L, registry.lastUpdateTime()); + h1.update(1); + clock.addSeconds(10); + Assert.assertEquals(40000L, registry.lastUpdateTime()); + comp1.update(1, TimeUnit.SECONDS); + clock.addSeconds(10); + Assert.assertEquals(50000L, 
registry.lastUpdateTime());
+        fc1.record(1, "aaa");
+        clock.addSeconds(10);
+        Assert.assertEquals(60000L, registry.lastUpdateTime());
+    }
+}
diff --git a/shuffle-metrics/src/test/java/com/alibaba/flink/shuffle/metrics/reporter/AnotherFakedReporterFactory.java b/shuffle-metrics/src/test/java/com/alibaba/flink/shuffle/metrics/reporter/AnotherFakedReporterFactory.java
new file mode 100644
index 00000000..6f230c4b
--- /dev/null
+++ b/shuffle-metrics/src/test/java/com/alibaba/flink/shuffle/metrics/reporter/AnotherFakedReporterFactory.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.metrics.reporter;
+
+import com.alibaba.metrics.reporter.MetricManagerReporter;
+
+import java.util.Properties;
+
+/**
+ * Another implementation of {@link MetricReporterFactory}. This class can be used when tests need
+ * to cover multiple reporter implementations.
+ */
+public class AnotherFakedReporterFactory extends FakedMetricReporterFactory {
+
+    @Override
+    public MetricManagerReporter createMetricReporter(Properties conf) {
+        methodCallCount++;
+        return null;
+    }
+}
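AnotherFakedReporterFactory inherits the static call counter of FakedMetricReporterFactory (next
file), which lets the tests tell de-duplication apart from repeated instantiation: configuring the
same factory class twice yields one createMetricReporter call, while configuring the two distinct
classes yields two. A sketch of the corresponding configuration value, using the class names from
this patch:

    // Two distinct factory classes: each one is instantiated and invoked once.
    properties.setProperty(
            METRICS_REPORTER_CLASSES.key(),
            "com.alibaba.flink.shuffle.metrics.reporter.FakedMetricReporterFactory;"
                    + "com.alibaba.flink.shuffle.metrics.reporter.AnotherFakedReporterFactory");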
diff --git a/shuffle-metrics/src/test/java/com/alibaba/flink/shuffle/metrics/reporter/FakedMetricReporterFactory.java b/shuffle-metrics/src/test/java/com/alibaba/flink/shuffle/metrics/reporter/FakedMetricReporterFactory.java
new file mode 100644
index 00000000..4ad918bd
--- /dev/null
+++ b/shuffle-metrics/src/test/java/com/alibaba/flink/shuffle/metrics/reporter/FakedMetricReporterFactory.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.metrics.reporter;
+
+import com.alibaba.flink.shuffle.common.config.Configuration;
+
+import com.alibaba.metrics.reporter.MetricManagerReporter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Properties;
+
+/** A faked implementation for {@link MetricReporterFactory}. */
+public class FakedMetricReporterFactory implements MetricReporterFactory {
+    private static final Logger LOG = LoggerFactory.getLogger(FakedMetricReporterFactory.class);
+    static volatile int methodCallCount = 0;
+    private static Configuration conf;
+
+    @Override
+    public MetricManagerReporter createMetricReporter(Properties conf) {
+        methodCallCount++;
+        LOG.info("Faked metric reporter method is called.");
+        FakedMetricReporterFactory.conf = new Configuration(conf);
+        return null;
+    }
+
+    public static Configuration getConf() {
+        return conf;
+    }
+
+    public static int getMethodCallCount() {
+        return methodCallCount;
+    }
+
+    public static void resetMethodCallCount() {
+        methodCallCount = 0;
+    }
+}
diff --git a/shuffle-metrics/src/test/java/com/alibaba/flink/shuffle/metrics/reporter/ReporterSetupTest.java b/shuffle-metrics/src/test/java/com/alibaba/flink/shuffle/metrics/reporter/ReporterSetupTest.java
new file mode 100644
index 00000000..e9d1c67b
--- /dev/null
+++ b/shuffle-metrics/src/test/java/com/alibaba/flink/shuffle/metrics/reporter/ReporterSetupTest.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.metrics.reporter;
+
+import com.alibaba.flink.shuffle.common.config.Configuration;
+
+import org.junit.Test;
+
+import java.util.Properties;
+
+import static com.alibaba.flink.shuffle.core.config.MetricOptions.METRICS_REPORTER_CLASSES;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/** Tests for {@link ReporterSetup}.
*/ +public class ReporterSetupTest { + @Test + public void testInitReporterFromConfiguration() { + FakedMetricReporterFactory.resetMethodCallCount(); + Properties properties = new Properties(); + properties.setProperty( + METRICS_REPORTER_CLASSES.key(), + "com.alibaba.flink.shuffle.metrics.reporter.FakedMetricReporterFactory"); + Configuration conf = new Configuration(properties); + + assertEquals(0, FakedMetricReporterFactory.getMethodCallCount()); + ReporterSetup.fromConfiguration(conf); + assertEquals(1, FakedMetricReporterFactory.getMethodCallCount()); + } + + @Test + public void testMultipleSameReporters() { + FakedMetricReporterFactory.resetMethodCallCount(); + Properties properties = new Properties(); + properties.setProperty( + METRICS_REPORTER_CLASSES.key(), + "com.alibaba.flink.shuffle.metrics.reporter.FakedMetricReporterFactory;" + + "com.alibaba.flink.shuffle.metrics.reporter.FakedMetricReporterFactory"); + Configuration conf = new Configuration(properties); + + assertEquals(0, FakedMetricReporterFactory.getMethodCallCount()); + ReporterSetup.fromConfiguration(conf); + assertEquals(1, FakedMetricReporterFactory.getMethodCallCount()); + } + + @Test + public void testMultipleDifferentReporters() { + FakedMetricReporterFactory.resetMethodCallCount(); + Properties properties = new Properties(); + properties.setProperty( + METRICS_REPORTER_CLASSES.key(), + "com.alibaba.flink.shuffle.metrics.reporter.FakedMetricReporterFactory;" + + "com.alibaba.flink.shuffle.metrics.reporter.AnotherFakedReporterFactory"); + Configuration conf = new Configuration(properties); + + assertEquals(0, FakedMetricReporterFactory.getMethodCallCount()); + ReporterSetup.fromConfiguration(conf); + assertEquals(2, FakedMetricReporterFactory.getMethodCallCount()); + } + + @Test(expected = Exception.class) + public void testLoadNonExistReporter() throws Exception { + ReporterSetup.loadViaReflection("needFailed", new Configuration(new Properties())); + } + + @Test + public void testConfigurationArgsRight() { + FakedMetricReporterFactory.resetMethodCallCount(); + assertTrue( + FakedMetricReporterFactory.getConf() == null + || FakedMetricReporterFactory.getConf().getString("my.k1") == null); + final String reporterKey = METRICS_REPORTER_CLASSES.key(); + final String reporterVal = + "com.alibaba.flink.shuffle.metrics.reporter.FakedMetricReporterFactory"; + + Properties properties = new Properties(); + properties.setProperty("my.k1", "v1"); + properties.setProperty("my.k2", "v2"); + properties.setProperty(reporterKey, reporterVal); + Configuration conf = new Configuration(properties); + ReporterSetup.fromConfiguration(conf); + + // Check args + Configuration config = FakedMetricReporterFactory.getConf(); + assertTrue(config.getString("my.k1").equals("v1")); + assertTrue(config.getString("my.k2").equals("v2")); + assertFalse(config.getString("my.k1").equals("v2")); + assertTrue(config.getString(reporterKey).equals(reporterVal)); + } +} diff --git a/shuffle-metrics/src/test/resources/log4j2-test.properties b/shuffle-metrics/src/test/resources/log4j2-test.properties new file mode 100644 index 00000000..d7fcb327 --- /dev/null +++ b/shuffle-metrics/src/test/resources/log4j2-test.properties @@ -0,0 +1,26 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +# Set root logger level to OFF to not flood build logs +# set manually to INFO for debugging purposes +rootLogger.level=OFF +rootLogger.appenderRef.test.ref=TestLogger +appender.testlogger.name=TestLogger +appender.testlogger.type=CONSOLE +appender.testlogger.target=SYSTEM_ERR +appender.testlogger.layout.type=PatternLayout +appender.testlogger.layout.pattern=%-4r [%t] %-5p %c %x - %m%n diff --git a/shuffle-plugin/pom.xml b/shuffle-plugin/pom.xml new file mode 100644 index 00000000..3d43f710 --- /dev/null +++ b/shuffle-plugin/pom.xml @@ -0,0 +1,181 @@ + + + + + + com.alibaba.flink.shuffle + flink-shuffle-parent + 1.0-SNAPSHOT + + 4.0.0 + + shuffle-plugin + + + + com.alibaba.flink.shuffle + shuffle-common + ${project.version} + + + + com.alibaba.flink.shuffle + shuffle-coordinator + ${project.version} + + + org.apache.flink + * + + + com.alibaba.flink.shuffle + shuffle-metrics + + + commons-cli + commons-cli + + + + + + com.alibaba.flink.shuffle + shuffle-transfer + ${project.version} + + + com.alibaba.flink.shuffle + shuffle-metrics + + + + + + org.apache.flink + flink-runtime + ${flink.version} + provided + + + com.tysafe + * + + + com.typesafe.akka + * + + + commons-cli + commons-cli + + + + + + com.alibaba.flink.shuffle + shuffle-rpc + ${project.version} + + + org.apache.flink + * + + + + + + + com.alibaba.flink.shuffle + shuffle-transfer + ${project.version} + test-jar + test + + + + com.alibaba.flink.shuffle + shuffle-coordinator + ${project.version} + test-jar + test + + + + org.apache.flink + flink-runtime + ${flink.version} + test-jar + test + + + + org.apache.flink + flink-streaming-java_${scala.binary.version} + ${flink.version} + test + + + + + + + org.apache.maven.plugins + maven-shade-plugin + + + shade-remote-shuffle + package + + shade + + + false + false + ${project.artifactId}-${project.version} + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + *.aut + META-INF/maven/** + META-INF/services/*com.fasterxml* + META-INF/proguard/** + OSGI-INF/** + schema/** + *.vm + *.xml + META-INF/jandex.idx + license.header + org.apache.flink:* + org.slf4j:* + flink-rpc-akka.jar + + + + + + + + + + diff --git a/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/RemoteShuffleDescriptor.java b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/RemoteShuffleDescriptor.java new file mode 100644 index 00000000..a0e1b964 --- /dev/null +++ b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/RemoteShuffleDescriptor.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.plugin; + +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleResource; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.plugin.utils.IdMappingUtils; + +import org.apache.flink.runtime.clusterframework.types.ResourceID; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.shuffle.ShuffleDescriptor; + +import java.util.Optional; + +/** {@link ShuffleDescriptor} for the flink remote shuffle. */ +public class RemoteShuffleDescriptor implements ShuffleDescriptor { + + private static final long serialVersionUID = -6333647241057311308L; + + /** The ID of the upstream tasks' result partition. */ + private final ResultPartitionID resultPartitionID; + + /** The ID of the containing job. */ + private final JobID jobId; + + /** The allocated shuffle resource. */ + private final ShuffleResource shuffleResource; + + public RemoteShuffleDescriptor( + ResultPartitionID resultPartitionID, JobID jobId, ShuffleResource shuffleResource) { + this.resultPartitionID = resultPartitionID; + this.jobId = jobId; + this.shuffleResource = shuffleResource; + } + + @Override + public Optional storesLocalResourcesOn() { + return Optional.empty(); + } + + @Override + public ResultPartitionID getResultPartitionID() { + return resultPartitionID; + } + + public JobID getJobId() { + return jobId; + } + + public DataSetID getDataSetId() { + return IdMappingUtils.fromFlinkDataSetId( + resultPartitionID.getPartitionId().getIntermediateDataSetID()); + } + + public DataPartitionID getDataPartitionID() { + return IdMappingUtils.fromFlinkResultPartitionID(resultPartitionID); + } + + public ShuffleResource getShuffleResource() { + return shuffleResource; + } + + @Override + public String toString() { + return "RemoteShuffleDescriptor{" + + "resultPartitionID=" + + resultPartitionID + + ", jobId=" + + jobId + + ", dataSetId=" + + getDataSetId() + + ", dataPartitionID=" + + getDataPartitionID() + + ", shuffleResource=" + + shuffleResource + + '}'; + } +} diff --git a/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/RemoteShuffleEnvironment.java b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/RemoteShuffleEnvironment.java new file mode 100644 index 00000000..7afb38e0 --- /dev/null +++ b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/RemoteShuffleEnvironment.java @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.plugin; + +import com.alibaba.flink.shuffle.plugin.transfer.RemoteShuffleInputGate; +import com.alibaba.flink.shuffle.plugin.transfer.RemoteShuffleInputGateFactory; +import com.alibaba.flink.shuffle.plugin.transfer.RemoteShuffleResultPartition; +import com.alibaba.flink.shuffle.plugin.transfer.RemoteShuffleResultPartitionFactory; +import com.alibaba.flink.shuffle.transfer.ConnectionManager; +import com.alibaba.flink.shuffle.transfer.NettyConfig; + +import org.apache.flink.metrics.MetricGroup; +import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; +import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.executiongraph.PartitionInfo; +import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; +import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; +import org.apache.flink.runtime.io.network.partition.PartitionProducerStateProvider; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.io.network.partition.ResultPartitionManager; +import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate; +import org.apache.flink.runtime.shuffle.ShuffleEnvironment; +import org.apache.flink.runtime.shuffle.ShuffleIOOwnerContext; +import org.apache.flink.util.FlinkRuntimeException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; +import static com.alibaba.flink.shuffle.transfer.ConnectionManager.createReadConnectionManager; +import static com.alibaba.flink.shuffle.transfer.ConnectionManager.createWriteConnectionManager; +import static org.apache.flink.runtime.io.network.metrics.NettyShuffleMetricFactory.METRIC_GROUP_INPUT; +import static org.apache.flink.runtime.io.network.metrics.NettyShuffleMetricFactory.METRIC_GROUP_OUTPUT; +import static org.apache.flink.runtime.io.network.metrics.NettyShuffleMetricFactory.createShuffleIOOwnerMetricGroup; + +/** + * The implementation of {@link ShuffleEnvironment} based on the remote shuffle service, providing + * shuffle environment on flink TM side. + */ +public class RemoteShuffleEnvironment + implements ShuffleEnvironment { + + private static final Logger LOG = LoggerFactory.getLogger(RemoteShuffleEnvironment.class); + + /** {@link ConnectionManager} for shuffle write connection. */ + private final ConnectionManager writeConnectionManager; + + /** {@link ConnectionManager} for shuffle read connection. */ + private final ConnectionManager readConnectionManager; + + /** Network buffer pool for shuffle read and shuffle write. */ + private final NetworkBufferPool networkBufferPool; + + /** A trivial {@link ResultPartitionManager}. 
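+ * It exists to satisfy Flink's shuffle API; no local partitions are expected to be
+ * registered with it, since all shuffle data lives on the remote shuffle workers.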
*/ + private final ResultPartitionManager resultPartitionManager; + + /** Factory class to create {@link RemoteShuffleResultPartition}. */ + private final RemoteShuffleResultPartitionFactory resultPartitionFactory; + + /** Factory class to create {@link RemoteShuffleInputGate}. */ + private final RemoteShuffleInputGateFactory inputGateFactory; + + /** Whether the shuffle environment is closed. */ + private boolean isClosed; + + private final Object lock = new Object(); + + /** + * @param networkBufferPool Network buffer pool for shuffle read and shuffle write. + * @param resultPartitionManager A trivial {@link ResultPartitionManager}. + * @param resultPartitionFactory Factory class to create {@link RemoteShuffleResultPartition}. + * @param inputGateFactory Factory class to create {@link RemoteShuffleInputGate}. + * @param nettyConfig Netty configuration. + */ + public RemoteShuffleEnvironment( + NetworkBufferPool networkBufferPool, + ResultPartitionManager resultPartitionManager, + RemoteShuffleResultPartitionFactory resultPartitionFactory, + RemoteShuffleInputGateFactory inputGateFactory, + NettyConfig nettyConfig) { + + this.networkBufferPool = networkBufferPool; + this.resultPartitionManager = resultPartitionManager; + this.resultPartitionFactory = resultPartitionFactory; + this.inputGateFactory = inputGateFactory; + this.isClosed = false; + this.writeConnectionManager = createWriteConnectionManager(nettyConfig, true); + this.readConnectionManager = createReadConnectionManager(nettyConfig, true); + } + + @Override + public List createResultPartitionWriters( + ShuffleIOOwnerContext ownerContext, + List resultPartitionDeploymentDescriptors) { + + synchronized (lock) { + checkState(!isClosed, "The RemoteShuffleEnvironment has already been shut down."); + + ResultPartitionWriter[] resultPartitions = + new ResultPartitionWriter[resultPartitionDeploymentDescriptors.size()]; + for (int index = 0; index < resultPartitions.length; index++) { + resultPartitions[index] = + resultPartitionFactory.create( + ownerContext.getOwnerName(), index, + resultPartitionDeploymentDescriptors.get(index), + writeConnectionManager); + } + return Arrays.asList(resultPartitions); + } + } + + @Override + public List createInputGates( + ShuffleIOOwnerContext ownerContext, + PartitionProducerStateProvider producerStateProvider, + List inputGateDescriptors) { + + synchronized (lock) { + checkState(!isClosed, "The RemoteShuffleEnvironment has already been shut down."); + + IndexedInputGate[] inputGates = new IndexedInputGate[inputGateDescriptors.size()]; + for (int gateIndex = 0; gateIndex < inputGates.length; gateIndex++) { + InputGateDeploymentDescriptor igdd = inputGateDescriptors.get(gateIndex); + RemoteShuffleInputGate inputGate = + inputGateFactory.create( + ownerContext.getOwnerName(), + gateIndex, + igdd, + readConnectionManager); + inputGates[gateIndex] = inputGate; + } + return Arrays.asList(inputGates); + } + } + + @Override + public void close() { + LOG.info("Close RemoteShuffleEnvironment."); + synchronized (lock) { + try { + writeConnectionManager.shutdown(); + } catch (Throwable t) { + LOG.error("Close RemoteShuffleEnvironment failure.", t); + } + try { + readConnectionManager.shutdown(); + } catch (Throwable t) { + LOG.error("Close RemoteShuffleEnvironment failure.", t); + } + try { + networkBufferPool.destroyAllBufferPools(); + } catch (Throwable t) { + LOG.error("Close RemoteShuffleEnvironment failure.", t); + } + try { + resultPartitionManager.shutdown(); + } catch (Throwable t) { + LOG.error("Close 
RemoteShuffleEnvironment failure.", t); + } + try { + networkBufferPool.destroy(); + } catch (Throwable t) { + LOG.error("Close RemoteShuffleEnvironment failure.", t); + } + isClosed = true; + } + } + + @Override + public int start() throws IOException { + synchronized (lock) { + checkState(!isClosed, "The RemoteShuffleEnvironment has already been shut down."); + LOG.info("Starting the network environment and its components."); + + writeConnectionManager.start(); + readConnectionManager.start(); + // trivial value. + return 1; + } + } + + @Override + public boolean updatePartitionInfo(ExecutionAttemptID consumerID, PartitionInfo partitionInfo) { + throw new FlinkRuntimeException("Not implemented yet."); + } + + @Override + public ShuffleIOOwnerContext createShuffleIOOwnerContext( + String ownerName, ExecutionAttemptID executionAttemptID, MetricGroup parentGroup) { + MetricGroup nettyGroup = createShuffleIOOwnerMetricGroup(checkNotNull(parentGroup)); + return new ShuffleIOOwnerContext( + checkNotNull(ownerName), + checkNotNull(executionAttemptID), + parentGroup, + nettyGroup.addGroup(METRIC_GROUP_OUTPUT), + nettyGroup.addGroup(METRIC_GROUP_INPUT)); + } + + @Override + public void releasePartitionsLocally(Collection partitionIds) { + throw new FlinkRuntimeException("Not implemented yet."); + } + + @Override + public Collection getPartitionsOccupyingLocalResources() { + return new ArrayList<>(); + } + + // For testing. + public NetworkBufferPool getNetworkBufferPool() { + return networkBufferPool; + } +} diff --git a/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/RemoteShuffleMaster.java b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/RemoteShuffleMaster.java new file mode 100644 index 00000000..83ab38a8 --- /dev/null +++ b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/RemoteShuffleMaster.java @@ -0,0 +1,602 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.plugin; + +import com.alibaba.flink.shuffle.client.ShuffleManagerClient; +import com.alibaba.flink.shuffle.client.ShuffleManagerClientConfiguration; +import com.alibaba.flink.shuffle.client.ShuffleManagerClientImpl; +import com.alibaba.flink.shuffle.client.ShuffleWorkerStatusListener; +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatServices; +import com.alibaba.flink.shuffle.coordinator.heartbeat.HeartbeatServicesUtils; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServiceUtils; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices; +import com.alibaba.flink.shuffle.coordinator.manager.DataPartitionCoordinate; +import com.alibaba.flink.shuffle.core.config.WorkerOptions; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.plugin.config.PluginOptions; +import com.alibaba.flink.shuffle.plugin.utils.ConfigurationUtils; +import com.alibaba.flink.shuffle.plugin.utils.IdMappingUtils; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcService; +import com.alibaba.flink.shuffle.rpc.utils.AkkaRpcServiceUtils; + +import org.apache.flink.configuration.AkkaOptions; +import org.apache.flink.configuration.MemorySize; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; +import org.apache.flink.runtime.shuffle.JobShuffleContext; +import org.apache.flink.runtime.shuffle.PartitionDescriptor; +import org.apache.flink.runtime.shuffle.ProducerDescriptor; +import org.apache.flink.runtime.shuffle.ShuffleDescriptor; +import org.apache.flink.runtime.shuffle.ShuffleMaster; +import org.apache.flink.runtime.shuffle.ShuffleMasterContext; +import org.apache.flink.runtime.shuffle.TaskInputsOutputsDescriptor; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +/** The shuffle manager implementation for remote shuffle service plugin. 
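+ *
+ * <p>Flink discovers this master through {@link RemoteShuffleServiceFactory}. A typical
+ * flink-conf.yaml entry (using Flink's shuffle service factory option) would be:
+ *
+ * <pre>{@code
+ * shuffle-service-factory.class: com.alibaba.flink.shuffle.plugin.RemoteShuffleServiceFactory
+ * }</pre>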
*/ +public class RemoteShuffleMaster implements ShuffleMaster { + + private static final Logger LOG = LoggerFactory.getLogger(RemoteShuffleMaster.class); + + private static final int MAX_RETRY_TIMES = 3; + + private final ShuffleMasterContext shuffleMasterContext; + + private final Configuration configuration; + + // Job level configuration will be supported in the future + private final String partitionFactory; + + private final Duration workerRecoverTimeout; + + private final Map shuffleClients = new HashMap<>(); + + private final ScheduledThreadPoolExecutor executor = + new ScheduledThreadPoolExecutor( + 1, runnable -> new Thread(runnable, "remote-shuffle-master-executor")); + + private final RemoteShuffleRpcService rpcService; + + private final HaServices haServices; + + private final AtomicBoolean isClosed = new AtomicBoolean(false); + + public RemoteShuffleMaster(ShuffleMasterContext shuffleMasterContext) { + CommonUtils.checkArgument(shuffleMasterContext != null, "Must be not null."); + + this.shuffleMasterContext = shuffleMasterContext; + this.executor.setRemoveOnCancelPolicy(true); + this.configuration = + ConfigurationUtils.fromFlinkConfiguration(shuffleMasterContext.getConfiguration()); + this.partitionFactory = configuration.getString(PluginOptions.DATA_PARTITION_FACTORY_NAME); + this.workerRecoverTimeout = + configuration.getDuration(WorkerOptions.MAX_WORKER_RECOVER_TIME); + + Throwable error = null; + RemoteShuffleRpcService tmpRpcService = null; + try { + tmpRpcService = createRpcService(); + } catch (Throwable throwable) { + LOG.error("Failed to create the shuffle master RPC service.", throwable); + error = throwable; + } + this.rpcService = tmpRpcService; + + HaServices tmpHAService = null; + try { + tmpHAService = HaServiceUtils.createHAServices(configuration); + } catch (Throwable throwable) { + LOG.error("Failed to create the shuffle master HA service.", throwable); + error = throwable; + } + this.haServices = tmpHAService; + + if (error != null) { + close(); + shuffleMasterContext.onFatalError(error); + throw new ShuffleException("Failed to initialize shuffle master.", error); + } + } + + @Override + public CompletableFuture registerPartitionWithProducer( + org.apache.flink.api.common.JobID jobID, + PartitionDescriptor partitionDescriptor, + ProducerDescriptor producerDescriptor) { + CompletableFuture future = new CompletableFuture<>(); + executor.execute( + () -> { + try { + CommonUtils.checkState(!isClosed.get(), "ShuffleMaster has been closed."); + JobID shuffleJobID = IdMappingUtils.fromFlinkJobId(jobID); + ShuffleClient shuffleClient = + CommonUtils.checkNotNull(shuffleClients.get(shuffleJobID)); + + ResultPartitionID resultPartitionID = + new ResultPartitionID( + partitionDescriptor.getPartitionId(), + producerDescriptor.getProducerExecutionId()); + DataSetID dataSetID = + IdMappingUtils.fromFlinkDataSetId( + partitionDescriptor.getResultId()); + MapPartitionID mapPartitionId = + IdMappingUtils.fromFlinkResultPartitionID(resultPartitionID); + + shuffleClient + .getClient() + .requestShuffleResource( + dataSetID, + mapPartitionId, + partitionDescriptor.getNumberOfSubpartitions(), + partitionFactory) + .whenComplete( + (shuffleResource, throwable) -> { + if (throwable != null) { + future.completeExceptionally(throwable); + return; + } + InstanceID workerID = + shuffleResource + .getMapPartitionLocation() + .getWorkerId(); + future.complete( + new RemoteShuffleDescriptor( + resultPartitionID, + shuffleJobID, + shuffleResource)); + shuffleClient + .getListener() + 
.addPartition(workerID, resultPartitionID); + }); + } catch (Throwable throwable) { + LOG.error("Failed to allocate shuffle resource.", throwable); + future.completeExceptionally(throwable); + } + }); + return future; + } + + @Override + public void releasePartitionExternally(ShuffleDescriptor shuffleDescriptor) { + executor.execute( + () -> { + if (!(shuffleDescriptor instanceof RemoteShuffleDescriptor)) { + LOG.error( + "Only RemoteShuffleDescriptor is supported {}.", + shuffleDescriptor.getClass().getName()); + shuffleMasterContext.onFatalError( + new ShuffleException("Illegal shuffle descriptor type.")); + return; + } + + RemoteShuffleDescriptor descriptor = + (RemoteShuffleDescriptor) shuffleDescriptor; + try { + ShuffleClient shuffleClient = shuffleClients.get(descriptor.getJobId()); + if (shuffleClient != null) { + shuffleClient + .getClient() + .releaseShuffleResource( + descriptor.getDataSetId(), + (MapPartitionID) descriptor.getDataPartitionID()); + } + } catch (Throwable throwable) { + // it is not a problem if we failed to release the target data partition + // because the session timeout mechanism will do the work for us latter + LOG.debug("Failed to release data partition {}.", descriptor, throwable); + } + }); + } + + @Override + public void close() { + if (isClosed.compareAndSet(false, true)) { + executor.execute( + () -> { + for (ShuffleClient clientWithListener : shuffleClients.values()) { + try { + clientWithListener.close(); + } catch (Throwable throwable) { + LOG.error("Failed to close shuffle client.", throwable); + } + } + shuffleClients.clear(); + + try { + if (haServices != null) { + haServices.close(); + } + } catch (Throwable throwable) { + LOG.error("Failed to close HA service.", throwable); + } + + try { + if (rpcService != null) { + rpcService.stopService().get(); + } + } catch (Throwable throwable) { + LOG.error("Failed to close the rpc service.", throwable); + } + + try { + executor.shutdown(); + } catch (Throwable throwable) { + LOG.error("Failed to close the shuffle master executor.", throwable); + } + }); + } + } + + @Override + public void registerJob(JobShuffleContext context) { + CompletableFuture future = new CompletableFuture<>(); + executor.execute( + () -> { + JobID jobID = IdMappingUtils.fromFlinkJobId(context.getJobId()); + if (shuffleClients.containsKey(jobID)) { + future.completeExceptionally( + new ShuffleException("Duplicated job registration.")); + LOG.error("Duplicated job registration {}:{}.", context.getJobId(), jobID); + return; + } + + try { + LOG.info("Registering job {}:{}", context.getJobId(), jobID); + CommonUtils.checkState(!isClosed.get(), "ShuffleMaster has been closed."); + + ShuffleManagerClientConfiguration shuffleManagerClientConfiguration = + ShuffleManagerClientConfiguration.fromConfiguration(configuration); + + HeartbeatServices heartbeatServices = + HeartbeatServicesUtils.createManagerJobHeartbeatServices( + configuration); + ShuffleWorkerStatusListenerImpl listener = + new ShuffleWorkerStatusListenerImpl(context); + ShuffleManagerClient client = + new ShuffleManagerClientImpl( + jobID, + listener, + rpcService, + shuffleMasterContext::onFatalError, + shuffleManagerClientConfiguration, + haServices, + heartbeatServices); + shuffleClients.put(jobID, new ShuffleClient(client, listener)); + client.start(); + future.complete(null); + } catch (Throwable throwable) { + LOG.error("Failed to register job.", throwable); + future.completeExceptionally(throwable); + CommonUtils.runQuietly(() -> unregisterJob(context.getJobId())); + 
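+                        // best-effort rollback so a retried registration starts from a
+                        // clean state; the failure itself is reported through the future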
} + }); + try { + future.get(); + } catch (InterruptedException | ExecutionException exception) { + ExceptionUtils.rethrowAsRuntimeException(exception); + } + } + + @Override + public void unregisterJob(org.apache.flink.api.common.JobID flinkJobID) { + executor.execute( + () -> { + try { + JobID jobID = IdMappingUtils.fromFlinkJobId(flinkJobID); + LOG.info("Unregister job {}:{}", flinkJobID, jobID); + ShuffleClient clientWithListener = shuffleClients.remove(jobID); + if (clientWithListener != null) { + clientWithListener.close(); + } + } catch (Throwable throwable) { + LOG.error( + "Encounter an error when unregistering job {}:{}.", + flinkJobID, + IdMappingUtils.fromFlinkJobId(flinkJobID), + throwable); + } + }); + } + + RemoteShuffleRpcService createRpcService() throws Exception { + org.apache.flink.configuration.Configuration configuration = + new org.apache.flink.configuration.Configuration( + shuffleMasterContext.getConfiguration()); + configuration.set(AkkaOptions.FORK_JOIN_EXECUTOR_PARALLELISM_MIN, 2); + configuration.set(AkkaOptions.FORK_JOIN_EXECUTOR_PARALLELISM_MAX, 2); + configuration.set(AkkaOptions.FORK_JOIN_EXECUTOR_PARALLELISM_FACTOR, 1.0); + + AkkaRpcServiceUtils.AkkaRpcServiceBuilder rpcServiceBuilder = + AkkaRpcServiceUtils.remoteServiceBuilder( + ConfigurationUtils.fromFlinkConfiguration(configuration), null, "0"); + return rpcServiceBuilder.withBindAddress("0.0.0.0").createAndStart(); + } + + @Override + public MemorySize computeShuffleMemorySizeForTask( + TaskInputsOutputsDescriptor taskInputsOutputsDescriptor) { + for (ResultPartitionType partitionType : + taskInputsOutputsDescriptor.getPartitionTypes().values()) { + if (!partitionType.isBlocking()) { + throw new ShuffleException( + "Blocking result partition type expected but found " + partitionType); + } + } + + int numResultPartitions = taskInputsOutputsDescriptor.getSubpartitionNums().size(); + long numBytesPerPartition = + configuration.getMemorySize(PluginOptions.MEMORY_PER_RESULT_PARTITION).getBytes(); + long numBytesForOutput = numBytesPerPartition * numResultPartitions; + + int numInputGates = taskInputsOutputsDescriptor.getInputChannelNums().size(); + long numBytesPerGate = + configuration.getMemorySize(PluginOptions.MEMORY_PER_INPUT_GATE).getBytes(); + long numBytesForInput = numBytesPerGate * numInputGates; + + LOG.debug( + "Announcing number of bytes {} for output and {} for input.", + numBytesForOutput, + numBytesForInput); + + return new MemorySize(numBytesForInput + numBytesForOutput); + } + + private class ShuffleWorkerStatusListenerImpl implements ShuffleWorkerStatusListener { + + private final JobShuffleContext context; + + private final Map> partitions = new HashMap<>(); + + private final Map> problematicWorkers = new HashMap<>(); + + ShuffleWorkerStatusListenerImpl(JobShuffleContext context) { + CommonUtils.checkArgument(context != null, "Must be not null."); + + this.context = context; + } + + private void addPartition(InstanceID workerID, ResultPartitionID partitionID) { + Set ids = + partitions.computeIfAbsent(workerID, (id) -> new HashSet<>()); + ids.add(partitionID); + ScheduledFuture scheduledFuture = problematicWorkers.remove(workerID); + if (scheduledFuture != null) { + scheduledFuture.cancel(false); + } + } + + @Override + public void notifyIrrelevantWorker(InstanceID workerID) { + executor.execute( + () -> { + if (!problematicWorkers.containsKey(workerID)) { + ScheduledFuture scheduledFuture = + executor.schedule( + () -> { + Set partitionIDS = + partitions.remove(workerID); + 
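+                                        // recovery timeout elapsed without the worker
+                                        // coming back; release all of its partitions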
problematicWorkers.remove(workerID); + stopTrackingPartitions( + partitionIDS, + new AtomicInteger(MAX_RETRY_TIMES)); + }, + workerRecoverTimeout.getSeconds(), + TimeUnit.SECONDS); + problematicWorkers.put(workerID, scheduledFuture); + } + }); + } + + @Override + public void notifyRelevantWorker( + InstanceID workerID, Set dataPartitions) { + Set partitionIDs = new HashSet<>(); + for (DataPartitionCoordinate coordinate : dataPartitions) { + partitionIDs.add( + IdMappingUtils.fromMapPartitionID( + (MapPartitionID) coordinate.getDataPartitionId())); + } + + if (partitionIDs.isEmpty()) { + return; + } + + if (partitions.containsKey(workerID)) { + executor.execute( + () -> { + cancelScheduledFuture(problematicWorkers.remove(workerID)); + Set trackedPartitions = partitions.get(workerID); + partitions.put(workerID, partitionIDs); + for (ResultPartitionID partitionID : partitionIDs) { + trackedPartitions.remove(partitionID); + } + stopTrackingPartitions( + trackedPartitions, new AtomicInteger(MAX_RETRY_TIMES)); + }); + } else { + executor.execute( + () -> { + InstanceID oldWorkerID = null; + ResultPartitionID targetPartitionID = partitionIDs.iterator().next(); + for (InstanceID candidate : problematicWorkers.keySet()) { + Set idSet = partitions.get(candidate); + if (idSet != null && idSet.contains(targetPartitionID)) { + oldWorkerID = candidate; + break; + } + } + + if (oldWorkerID != null) { + Set idSet = partitions.get(oldWorkerID); + for (ResultPartitionID partitionID : partitionIDs) { + idSet.remove(partitionID); + } + if (idSet.isEmpty()) { + partitions.remove(oldWorkerID); + } + partitions.put(workerID, partitionIDs); + } + }); + } + } + + private void stopTrackingPartitions( + Set partitionIDS, AtomicInteger remainingRetries) { + if (partitionIDS == null || partitionIDS.isEmpty()) { + return; + } + + int count = remainingRetries.decrementAndGet(); + try { + CompletableFuture future = + context.stopTrackingAndReleasePartitions(partitionIDS); + future.whenCompleteAsync( + (ignored, throwable) -> { + if (throwable == null) { + return; + } + + if (count == 0) { + LOG.error( + "Failed to stop tracking partitions {}.", + Arrays.toString(partitionIDS.toArray())); + return; + } + stopTrackingPartitions(partitionIDS, remainingRetries); + }, + executor); + } catch (Throwable throwable) { + if (count == 0) { + LOG.error( + "Failed to stop tracking partitions {}.", + Arrays.toString(partitionIDS.toArray())); + return; + } + stopTrackingPartitions(partitionIDS, remainingRetries); + } + } + + public JobShuffleContext getContext() { + return context; + } + + public Map> getPartitions() { + return partitions; + } + + public Map> getProblematicWorkers() { + return problematicWorkers; + } + } + + private static void cancelScheduledFuture(ScheduledFuture scheduledFuture) { + try { + if (scheduledFuture != null && !scheduledFuture.cancel(false)) { + LOG.error("Failed to cancel the scheduled future, may already run."); + } + } catch (Throwable throwable) { + LOG.error("Error encountered when cancel the scheduled future.", throwable); + throw throwable; + } + } + + private static class ShuffleClient implements AutoCloseable { + + private final ShuffleManagerClient client; + + private final ShuffleWorkerStatusListenerImpl listener; + + ShuffleClient(ShuffleManagerClient client, ShuffleWorkerStatusListenerImpl listener) { + CommonUtils.checkArgument(client != null, "Must be not null."); + CommonUtils.checkArgument(listener != null, "Must be not null."); + + this.client = client; + this.listener = listener; + } + 
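+
+        // Closing releases everything the job still tracks: partitions known to the
+        // listener are stopped and released through the job shuffle context, pending
+        // recovery timeouts are cancelled, and the underlying client is closed last.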
+ @Override + public void close() throws Exception { + Throwable error = null; + + for (Set ids : listener.getPartitions().values()) { + try { + listener.getContext().stopTrackingAndReleasePartitions(ids); + } catch (Throwable throwable) { + error = error == null ? throwable : error; + LOG.error( + "Failed to stop tracking partitions {}.", + Arrays.toString(ids.toArray()), + throwable); + } + } + listener.getPartitions().clear(); + + for (ScheduledFuture scheduledFuture : listener.getProblematicWorkers().values()) { + try { + cancelScheduledFuture(scheduledFuture); + } catch (Throwable throwable) { + error = error == null ? throwable : error; + LOG.error("Failed to cancel scheduled future.", throwable); + } + } + listener.getProblematicWorkers().clear(); + + try { + client.close(); + } catch (Throwable throwable) { + error = error == null ? throwable : error; + LOG.error("Failed to close shuffle client.", throwable); + } + + if (error != null) { + ExceptionUtils.rethrowException(error); + } + } + + public ShuffleManagerClient getClient() { + return client; + } + + public ShuffleWorkerStatusListenerImpl getListener() { + return listener; + } + } +} diff --git a/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/RemoteShuffleServiceFactory.java b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/RemoteShuffleServiceFactory.java new file mode 100644 index 00000000..a7c1a971 --- /dev/null +++ b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/RemoteShuffleServiceFactory.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.plugin; + +import com.alibaba.flink.shuffle.plugin.transfer.RemoteShuffleInputGateFactory; +import com.alibaba.flink.shuffle.plugin.transfer.RemoteShuffleResultPartitionFactory; +import com.alibaba.flink.shuffle.plugin.utils.ConfigurationUtils; +import com.alibaba.flink.shuffle.transfer.NettyConfig; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.MemorySize; +import org.apache.flink.configuration.NettyShuffleEnvironmentOptions; +import org.apache.flink.metrics.MetricGroup; +import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; +import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; +import org.apache.flink.runtime.io.network.partition.ResultPartitionManager; +import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate; +import org.apache.flink.runtime.shuffle.ShuffleEnvironment; +import org.apache.flink.runtime.shuffle.ShuffleEnvironmentContext; +import org.apache.flink.runtime.shuffle.ShuffleMaster; +import org.apache.flink.runtime.shuffle.ShuffleMasterContext; +import org.apache.flink.runtime.shuffle.ShuffleServiceFactory; +import org.apache.flink.runtime.util.ConfigurationParserUtils; + +import java.time.Duration; + +import static org.apache.flink.runtime.io.network.metrics.NettyShuffleMetricFactory.registerShuffleMetrics; + +/** Flink remote shuffle service implementation. */ +public class RemoteShuffleServiceFactory + implements ShuffleServiceFactory< + RemoteShuffleDescriptor, ResultPartitionWriter, IndexedInputGate> { + + @Override + public ShuffleMaster createShuffleMaster( + ShuffleMasterContext shuffleMasterContext) { + return new RemoteShuffleMaster(shuffleMasterContext); + } + + @Override + public ShuffleEnvironment createShuffleEnvironment( + ShuffleEnvironmentContext context) { + Configuration configuration = context.getConfiguration(); + int bufferSize = ConfigurationParserUtils.getPageSize(configuration); + final int numBuffers = + calculateNumberOfNetworkBuffers(context.getNetworkMemorySize(), bufferSize); + + ResultPartitionManager resultPartitionManager = new ResultPartitionManager(); + MetricGroup metricGroup = context.getParentMetricGroup(); + + int numPreferredClientThreads = 2 * ConfigurationParserUtils.getSlot(configuration); + NettyConfig nettyConfig = + new NettyConfig( + ConfigurationUtils.fromFlinkConfiguration(configuration), + numPreferredClientThreads); + + Duration requestSegmentsTimeout = + Duration.ofMillis( + configuration.getLong( + NettyShuffleEnvironmentOptions + .NETWORK_EXCLUSIVE_BUFFERS_REQUEST_TIMEOUT_MILLISECONDS)); + NetworkBufferPool networkBufferPool = + new NetworkBufferPool(numBuffers, bufferSize, requestSegmentsTimeout); + + registerShuffleMetrics(metricGroup, networkBufferPool); + + String compressionCodec = + configuration.getString(NettyShuffleEnvironmentOptions.SHUFFLE_COMPRESSION_CODEC); + RemoteShuffleResultPartitionFactory resultPartitionFactory = + new RemoteShuffleResultPartitionFactory( + ConfigurationUtils.fromFlinkConfiguration(configuration), + resultPartitionManager, + networkBufferPool, + bufferSize, + compressionCodec); + + RemoteShuffleInputGateFactory inputGateFactory = + new RemoteShuffleInputGateFactory( + ConfigurationUtils.fromFlinkConfiguration(configuration), + networkBufferPool, + bufferSize, + compressionCodec); + + return new RemoteShuffleEnvironment( + networkBufferPool, + resultPartitionManager, + resultPartitionFactory, + inputGateFactory, + nettyConfig); + } + + 
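+    // For example, 1 GB of network memory with Flink's default 32 KB memory segment
+    // size yields 1024 * 1024 * 1024 / 32768 = 32768 buffers; sizes amounting to more
+    // than MAX_INT pages are rejected.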
private static int calculateNumberOfNetworkBuffers(MemorySize memorySize, int bufferSize) { + long numBuffersLong = memorySize.getBytes() / bufferSize; + if (numBuffersLong > Integer.MAX_VALUE) { + throw new IllegalArgumentException( + "The given number of memory bytes (" + + memorySize.getBytes() + + ") corresponds to more than MAX_INT pages."); + } + return (int) numBuffersLong; + } +} diff --git a/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/config/PluginOptions.java b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/config/PluginOptions.java new file mode 100644 index 00000000..573e4852 --- /dev/null +++ b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/config/PluginOptions.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.plugin.config; + +import com.alibaba.flink.shuffle.common.config.ConfigOption; +import com.alibaba.flink.shuffle.common.config.MemorySize; + +/** Config options for shuffle jobs using the remote shuffle service. */ +public class PluginOptions { + + public static final MemorySize MIN_MEMORY_PER_PARTITION = MemorySize.parse("8m"); + + public static final MemorySize MIN_MEMORY_PER_GATE = MemorySize.parse("8m"); + + /** + * The maximum number of remote shuffle channels to open and read concurrently per input gate. + */ + public static final ConfigOption NUM_CONCURRENT_READINGS = + new ConfigOption("remote-shuffle.job.concurrent-readings-per-gate") + .defaultValue(Integer.MAX_VALUE) + .description( + "The maximum number of remote shuffle channels to open and read " + + "concurrently per input gate."); + + /** + * The size of network buffers required per result partition. The minimum valid value is 8M. + * Usually, several hundreds of megabytes memory is enough for large scale batch jobs. + */ + public static final ConfigOption MEMORY_PER_RESULT_PARTITION = + new ConfigOption("remote-shuffle.job.memory-per-partition") + .defaultValue(MemorySize.parse("64m")) + .description( + "The size of network buffers required per result partition. The " + + "minimum valid value is 8M. Usually, several hundreds of " + + "megabytes memory is enough for large scale batch jobs."); + + /** + * The size of network buffers required per input gate. The minimum valid value is 8m. Usually, + * several hundreds of megabytes memory is enough for large scale batch jobs. + */ + public static final ConfigOption MEMORY_PER_INPUT_GATE = + new ConfigOption("remote-shuffle.job.memory-per-gate") + .defaultValue(MemorySize.parse("32m")) + .description( + "The size of network buffers required per input gate. The minimum " + + "valid value is 8m. 
Usually, several hundreds of megabytes " + + "memory is enough for large scale batch jobs."); + + /** + * Defines the factory used to create new data partitions. According to the specified data + * partition factory from the client side, the shuffle manager will return corresponding + * resources and the shuffle worker will create the corresponding partitions. + */ + public static final ConfigOption DATA_PARTITION_FACTORY_NAME = + new ConfigOption("remote-shuffle.job.data-partition-factory-name") + .defaultValue( + "com.alibaba.flink.shuffle.storage.partition.LocalFileMapPartitionFactory") + .description( + "Defines the factory used to create new data partitions. According to " + + "the specified data partition factory from the client side, " + + "the shuffle manager will return corresponding resources and" + + " the shuffle worker will create the corresponding partitions."); + + /** + * Whether to enable shuffle data compression. Usually, enabling data compression can save the + * storage space and achieve better performance. + */ + public static final ConfigOption ENABLE_DATA_COMPRESSION = + new ConfigOption("remote-shuffle.job.enable-data-compression") + .defaultValue(true) + .description( + "Whether to enable shuffle data compression. Usually, enabling data " + + "compression can save the storage space and achieve better " + + "performance."); + + /** + * Whether to shuffle the reading channels for better load balance of the downstream consumer + * tasks. + */ + public static final ConfigOption SHUFFLE_READING_CHANNELS = + new ConfigOption("remote-shuffle.job.shuffle-reading-channels") + .defaultValue(true) + .description( + "Whether to shuffle the reading channels for better load balance of the" + + " downstream consumer tasks."); +} diff --git a/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/BufferHeader.java b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/BufferHeader.java new file mode 100644 index 00000000..3b16d1a6 --- /dev/null +++ b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/BufferHeader.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.plugin.transfer; + +import org.apache.flink.runtime.io.network.buffer.Buffer; + +/** Header information for a {@link org.apache.flink.runtime.io.network.buffer.Buffer}. 
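+ * It describes one packed buffer: its data type, whether it is compressed, and its payload
+ * size. {@link BufferPacker} parses these headers via {@code BufferUtils#getBufferHeader}
+ * when unpacking a merged buffer back into individual buffers.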
*/ +public class BufferHeader { + + private final Buffer.DataType dataType; + + private final boolean isCompressed; + + private final int size; + + public BufferHeader(Buffer.DataType dataType, boolean isCompressed, int size) { + this.dataType = dataType; + this.isCompressed = isCompressed; + this.size = size; + } + + public Buffer.DataType getDataType() { + return dataType; + } + + public boolean isCompressed() { + return isCompressed; + } + + public int getSize() { + return size; + } +} diff --git a/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/BufferPacker.java b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/BufferPacker.java new file mode 100644 index 00000000..f2dc7964 --- /dev/null +++ b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/BufferPacker.java @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.plugin.transfer; + +import com.alibaba.flink.shuffle.common.functions.BiConsumerWithException; +import com.alibaba.flink.shuffle.plugin.utils.BufferUtils; + +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.BufferRecycler; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBufAllocator; + +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.Queue; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; + +/** Harness used to pack multiple partial buffers together as a full one. 
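+ *
+ * <p>A minimal usage sketch; {@code send} stands for whatever ships the packed buffer
+ * downstream and is not part of this class:
+ *
+ * <pre>{@code
+ * BufferPacker packer = new BufferPacker((byteBuf, subIdx) -> send(byteBuf, subIdx));
+ * packer.process(buffer, subpartitionIndex); // may be merged into the cached buffer
+ * packer.drain();                            // flush whatever is still cached
+ * packer.close();
+ * }</pre>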
*/ +public class BufferPacker { + + private final BiConsumerWithException ripeBufferHandler; + + private Buffer cachedBuffer; + + private int currentSubIdx = -1; + + public BufferPacker( + BiConsumerWithException ripeBufferHandler) { + this.ripeBufferHandler = ripeBufferHandler; + } + + public void process(Buffer buffer, int subIdx) throws InterruptedException { + if (buffer == null) { + return; + } + + if (buffer.readableBytes() == 0) { + buffer.recycleBuffer(); + return; + } + + if (cachedBuffer == null) { + cachedBuffer = buffer; + currentSubIdx = subIdx; + } else if (currentSubIdx != subIdx) { + Buffer dumpedBuffer = cachedBuffer; + cachedBuffer = buffer; + int targetSubIdx = currentSubIdx; + currentSubIdx = subIdx; + handleRipeBuffer(dumpedBuffer, targetSubIdx); + } else { + if (cachedBuffer.readableBytes() + buffer.readableBytes() + <= cachedBuffer.getMaxCapacity()) { + cachedBuffer.asByteBuf().writeBytes(buffer.asByteBuf()); + buffer.recycleBuffer(); + } else { + Buffer dumpedBuffer = cachedBuffer; + cachedBuffer = buffer; + handleRipeBuffer(dumpedBuffer, currentSubIdx); + } + } + } + + public void drain() throws InterruptedException { + if (cachedBuffer != null) { + handleRipeBuffer(cachedBuffer, currentSubIdx); + } + cachedBuffer = null; + currentSubIdx = -1; + } + + private void handleRipeBuffer(Buffer buffer, int subIdx) throws InterruptedException { + buffer.setCompressed(false); + ripeBufferHandler.accept(buffer.asByteBuf(), subIdx); + } + + public static Queue unpack(ByteBuf byteBuf) { + Queue buffers = new ArrayDeque<>(); + try { + checkState(byteBuf instanceof Buffer, "Illegal buffer type."); + + Buffer buffer = (Buffer) byteBuf; + int position = 0; + int totalBytes = buffer.readableBytes(); + while (position < totalBytes) { + BufferHeader bufferHeader = BufferUtils.getBufferHeader(buffer, position); + position += BufferUtils.HEADER_LENGTH; + + Buffer slice = buffer.readOnlySlice(position, bufferHeader.getSize()); + position += bufferHeader.getSize(); + + buffers.add( + new UnpackSlicedBuffer( + slice, + bufferHeader.getDataType(), + bufferHeader.isCompressed(), + bufferHeader.getSize())); + slice.retainBuffer(); + } + return buffers; + } catch (Throwable throwable) { + buffers.forEach(Buffer::recycleBuffer); + throw throwable; + } finally { + byteBuf.release(); + } + } + + public void close() { + if (cachedBuffer != null) { + cachedBuffer.recycleBuffer(); + cachedBuffer = null; + } + currentSubIdx = -1; + } + + private static class UnpackSlicedBuffer implements Buffer { + + private final Buffer buffer; + + private DataType dataType; + + private boolean isCompressed; + + private final int size; + + UnpackSlicedBuffer(Buffer buffer, DataType dataType, boolean isCompressed, int size) { + this.buffer = buffer; + this.dataType = dataType; + this.isCompressed = isCompressed; + this.size = size; + } + + @Override + public boolean isBuffer() { + return dataType.isBuffer(); + } + + @Override + public MemorySegment getMemorySegment() { + return buffer.getMemorySegment(); + } + + @Override + public int getMemorySegmentOffset() { + return buffer.getMemorySegmentOffset(); + } + + @Override + public BufferRecycler getRecycler() { + return buffer.getRecycler(); + } + + @Override + public void recycleBuffer() { + buffer.recycleBuffer(); + } + + @Override + public boolean isRecycled() { + return buffer.isRecycled(); + } + + @Override + public Buffer retainBuffer() { + return buffer.retainBuffer(); + } + + @Override + public Buffer readOnlySlice() { + return buffer.readOnlySlice(); + } + + 
@Override + public Buffer readOnlySlice(int i, int i1) { + return buffer.readOnlySlice(i, i1); + } + + @Override + public int getMaxCapacity() { + return buffer.getMaxCapacity(); + } + + @Override + public int getReaderIndex() { + return buffer.getReaderIndex(); + } + + @Override + public void setReaderIndex(int i) throws IndexOutOfBoundsException { + buffer.setReaderIndex(i); + } + + @Override + public int getSize() { + return size; + } + + @Override + public void setSize(int i) { + buffer.setSize(i); + } + + @Override + public int readableBytes() { + return buffer.readableBytes(); + } + + @Override + public ByteBuffer getNioBufferReadable() { + return buffer.getNioBufferReadable(); + } + + @Override + public ByteBuffer getNioBuffer(int i, int i1) throws IndexOutOfBoundsException { + return buffer.getNioBuffer(i, i1); + } + + @Override + public void setAllocator(ByteBufAllocator byteBufAllocator) { + buffer.setAllocator(byteBufAllocator); + } + + @Override + public ByteBuf asByteBuf() { + return buffer.asByteBuf(); + } + + @Override + public boolean isCompressed() { + return isCompressed; + } + + @Override + public void setCompressed(boolean b) { + isCompressed = b; + } + + @Override + public DataType getDataType() { + return dataType; + } + + @Override + public void setDataType(DataType dataType) { + this.dataType = dataType; + } + + @Override + public int refCnt() { + return buffer.refCnt(); + } + } +} diff --git a/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/PartitionNotFoundException.java b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/PartitionNotFoundException.java new file mode 100644 index 00000000..091e4218 --- /dev/null +++ b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/PartitionNotFoundException.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.plugin.transfer; + +import org.apache.flink.runtime.io.network.partition.PartitionException; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; + +/** + * Exception for failed shuffle read handshake due to non-existing partitions on the remote. + * Different from {@link org.apache.flink.runtime.io.network.partition.PartitionNotFoundException}, + * this one can hold extra information in addition to {@link ResultPartitionID} and can be displayed + * on Flink UI and guides debugging. 
+ */ +public class PartitionNotFoundException extends PartitionException { + + private static final long serialVersionUID = -7355585001649725463L; + + public PartitionNotFoundException(ResultPartitionID partitionId, String msg) { + super("Partition " + partitionId + " not found -- " + msg, partitionId); + } +} diff --git a/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/PartitionSortedBuffer.java b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/PartitionSortedBuffer.java new file mode 100644 index 00000000..663ad699 --- /dev/null +++ b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/PartitionSortedBuffer.java @@ -0,0 +1,453 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.plugin.transfer; + +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.BufferPool; +import org.apache.flink.runtime.io.network.buffer.BufferRecycler; +import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.util.FlinkRuntimeException; + +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.NotThreadSafe; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; +import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType; + +/** + * A {@link SortBuffer} implementation which sorts all appended records only by subpartition index. + * Records of the same subpartition keep the appended order. + * + *

<p>It maintains a list of {@link MemorySegment}s as a joint buffer. Data will be appended to the
+ * joint buffer sequentially. When writing a record, an index entry will be appended first. An index
+ * entry consists of 4 fields: 4 bytes for record length, 4 bytes for {@link DataType} and 8 bytes
+ * for the address pointing to the next index entry of the same channel, which will be used to index
+ * the next record to read when copying data from this {@link SortBuffer}. For simplicity, no index
+ * entry can span multiple segments. The corresponding record data is stored right after its index
+ * entry and, unlike the index entry, records have variable length and thus may span multiple
+ * segments.
+ */
+@NotThreadSafe
+public class PartitionSortedBuffer implements SortBuffer {
+
+    /**
+     * Size of an index entry: 4 bytes for record length, 4 bytes for data type and 8 bytes for
+     * pointer to next entry.
+     */
+    private static final int INDEX_ENTRY_SIZE = 4 + 4 + 8;
+
+    private final Object lock;
+    /** A buffer pool to request memory segments from. */
+    private final BufferPool bufferPool;
+
+    /** A segment list as a joint buffer which stores all records and index entries. */
+    @GuardedBy("lock")
+    private final ArrayList<MemorySegment> buffers = new ArrayList<>();
+
+    /** Addresses of the first record's index entry for each subpartition. */
+    private final long[] firstIndexEntryAddresses;
+
+    /** Addresses of the last record's index entry for each subpartition. */
+    private final long[] lastIndexEntryAddresses;
+    /** Size of buffers requested from buffer pool. All buffers must be of the same size. */
+    private final int bufferSize;
+    /** Data of different subpartitions in this sort buffer will be read in this order. */
+    private final int[] subpartitionReadOrder;
+
+    // ---------------------------------------------------------------------------------------------
+    // Statistics and states
+    // ---------------------------------------------------------------------------------------------
+    /** Total number of bytes already appended to this sort buffer. */
+    private long numTotalBytes;
+    /** Total number of records already appended to this sort buffer. */
+    private long numTotalRecords;
+    /** Total number of bytes already read from this sort buffer. */
+    private long numTotalBytesRead;
+    /** Whether this sort buffer is finished. One can only read a finished sort buffer. */
+    private boolean isFinished;
+
+    // ---------------------------------------------------------------------------------------------
+    // For writing
+    // ---------------------------------------------------------------------------------------------
+    /** Whether this sort buffer is released. A released sort buffer can not be used. */
+    @GuardedBy("lock")
+    private boolean isReleased;
+    /** Array index in the segment list of the current available buffer for writing. */
+    private int writeSegmentIndex;
+    /** Next position in the current available buffer for writing. */
+    private int writeSegmentOffset;
+
+    // ---------------------------------------------------------------------------------------------
+    // For reading
+    // ---------------------------------------------------------------------------------------------
+    /** Index entry address of the current record or event to be read. */
+    private long readIndexEntryAddress;
+
+    /** Record bytes remaining after last copy, which must be read first in next copy. */
+    private int recordRemainingBytes;
+
+    /** Used to index the current available channel to read data from.
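+     * A value of -1 means reading has not started yet; {@link
+     * #updateReadChannelAndIndexEntryAddress()} advances it through {@link #subpartitionReadOrder}.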
*/ + private int readOrderIndex = -1; + + public PartitionSortedBuffer( + BufferPool bufferPool, + int numSubpartitions, + int bufferSize, + @Nullable int[] customReadOrder) { + checkArgument(bufferSize > INDEX_ENTRY_SIZE, "Buffer size is too small."); + + this.lock = new Object(); + this.bufferPool = checkNotNull(bufferPool); + this.bufferSize = bufferSize; + this.firstIndexEntryAddresses = new long[numSubpartitions]; + this.lastIndexEntryAddresses = new long[numSubpartitions]; + + // initialized with -1 means the corresponding channel has no data. + Arrays.fill(firstIndexEntryAddresses, -1L); + Arrays.fill(lastIndexEntryAddresses, -1L); + + this.subpartitionReadOrder = new int[numSubpartitions]; + if (customReadOrder != null) { + checkArgument(customReadOrder.length == numSubpartitions, "Illegal data read order."); + System.arraycopy(customReadOrder, 0, this.subpartitionReadOrder, 0, numSubpartitions); + } else { + for (int channel = 0; channel < numSubpartitions; ++channel) { + this.subpartitionReadOrder[channel] = channel; + } + } + } + + @Override + public boolean append(ByteBuffer source, int targetChannel, DataType dataType) + throws IOException { + checkArgument(source.hasRemaining(), "Cannot append empty data."); + checkState(!isFinished, "Sort buffer is already finished."); + checkState(!isReleased, "Sort buffer is already released."); + + int totalBytes = source.remaining(); + + // return false directly if it can not allocate enough buffers for the given record + if (!allocateBuffersForRecord(totalBytes)) { + return false; + } + + // write the index entry and record or event data + writeIndex(targetChannel, totalBytes, dataType); + writeRecord(source); + + ++numTotalRecords; + numTotalBytes += totalBytes; + + return true; + } + + private void writeIndex(int channelIndex, int numRecordBytes, DataType dataType) { + MemorySegment segment = buffers.get(writeSegmentIndex); + + // record length takes the high 32 bits and data type takes the low 32 bits + segment.putLong(writeSegmentOffset, ((long) numRecordBytes << 32) | dataType.ordinal()); + + // segment index takes the high 32 bits and segment offset takes the low 32 bits + long indexEntryAddress = ((long) writeSegmentIndex << 32) | writeSegmentOffset; + + long lastIndexEntryAddress = lastIndexEntryAddresses[channelIndex]; + lastIndexEntryAddresses[channelIndex] = indexEntryAddress; + + if (lastIndexEntryAddress >= 0) { + // link the previous index entry of the given channel to the new index entry + segment = buffers.get(getSegmentIndexFromPointer(lastIndexEntryAddress)); + segment.putLong( + getSegmentOffsetFromPointer(lastIndexEntryAddress) + 8, indexEntryAddress); + } else { + firstIndexEntryAddresses[channelIndex] = indexEntryAddress; + } + + // move the writer position forward to write the corresponding record + updateWriteSegmentIndexAndOffset(INDEX_ENTRY_SIZE); + } + + private void writeRecord(ByteBuffer source) { + while (source.hasRemaining()) { + MemorySegment segment = buffers.get(writeSegmentIndex); + int toCopy = Math.min(bufferSize - writeSegmentOffset, source.remaining()); + segment.put(writeSegmentOffset, source, toCopy); + + // move the writer position forward to write the remaining bytes or next record + updateWriteSegmentIndexAndOffset(toCopy); + } + } + + private boolean allocateBuffersForRecord(int numRecordBytes) throws IOException { + int numBytesRequired = INDEX_ENTRY_SIZE + numRecordBytes; + int availableBytes = + writeSegmentIndex == buffers.size() ? 
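+                /* When writeSegmentIndex equals buffers.size(), the write position points past the
+                   last allocated segment (or none has been allocated yet), so no bytes are
+                   available until a new segment is requested from the pool below. */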
0 : bufferSize - writeSegmentOffset;
+
+        // return directly if the currently available bytes are adequate
+        if (availableBytes >= numBytesRequired) {
+            return true;
+        }
+
+        // skip the remaining free space if the available bytes are not enough for an index entry
+        if (availableBytes < INDEX_ENTRY_SIZE) {
+            updateWriteSegmentIndexAndOffset(availableBytes);
+            availableBytes = 0;
+        }
+
+        // allocate exactly enough buffers for the appended record
+        do {
+            MemorySegment segment = requestBufferFromPool();
+            if (segment == null) {
+                // return false if we can not allocate enough buffers for the appended record
+                return false;
+            }
+
+            availableBytes += bufferSize;
+            addBuffer(segment);
+        } while (availableBytes < numBytesRequired);
+
+        return true;
+    }
+
+    private void addBuffer(MemorySegment segment) {
+        synchronized (lock) {
+            if (segment.size() != bufferSize) {
+                bufferPool.recycle(segment);
+                throw new IllegalStateException("Illegal memory segment size.");
+            }
+
+            if (isReleased) {
+                bufferPool.recycle(segment);
+                throw new IllegalStateException("Sort buffer is already released.");
+            }
+
+            buffers.add(segment);
+        }
+    }
+
+    private MemorySegment requestBufferFromPool() throws IOException {
+        try {
+            // request buffers in a blocking way while there is still guaranteed memory
+            if (buffers.size() < bufferPool.getNumberOfRequiredMemorySegments()) {
+                return bufferPool.requestMemorySegmentBlocking();
+            }
+        } catch (InterruptedException e) {
+            throw new IOException("Interrupted while requesting buffer.");
+        }
+
+        return bufferPool.requestMemorySegment();
+    }
+
+    private void updateWriteSegmentIndexAndOffset(int numBytes) {
+        writeSegmentOffset += numBytes;
+
+        // use the next available free buffer if the current one is full
+        if (writeSegmentOffset == bufferSize) {
+            ++writeSegmentIndex;
+            writeSegmentOffset = 0;
+        }
+    }
+
+    @Override
+    public BufferWithChannel copyIntoSegment(
+            MemorySegment target, BufferRecycler recycler, int offset) {
+        synchronized (lock) {
+            checkState(hasRemaining(), "No data remaining.");
+            checkState(isFinished, "Should finish the sort buffer first before copying any data.");
+            checkState(!isReleased, "Sort buffer is already released.");
+
+            int numBytesCopied = 0;
+            DataType bufferDataType = DataType.DATA_BUFFER;
+            int channelIndex = subpartitionReadOrder[readOrderIndex];
+
+            do {
+                int sourceSegmentIndex = getSegmentIndexFromPointer(readIndexEntryAddress);
+                int sourceSegmentOffset = getSegmentOffsetFromPointer(readIndexEntryAddress);
+                MemorySegment sourceSegment = buffers.get(sourceSegmentIndex);
+
+                long lengthAndDataType = sourceSegment.getLong(sourceSegmentOffset);
+                int length = getSegmentIndexFromPointer(lengthAndDataType);
+                DataType dataType =
+                        DataType.values()[getSegmentOffsetFromPointer(lengthAndDataType)];
+
+                // return the data read directly if the next to read is an event
+                if (dataType.isEvent() && numBytesCopied > 0) {
+                    break;
+                }
+                bufferDataType = dataType;
+
+                // get the next index entry address and move the read position forward
+                long nextReadIndexEntryAddress = sourceSegment.getLong(sourceSegmentOffset + 8);
+                sourceSegmentOffset += INDEX_ENTRY_SIZE;
+
+                // throws if the event is too big to be accommodated by a buffer.
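+                /*
+                 * Sketch of the pointer layout used above (illustrative): both kinds of packed
+                 * longs split into a high and a low 32-bit half, so a single getLong() recovers
+                 * two values:
+                 *
+                 *   long address = ((long) segmentIndex << 32) | segmentOffset;
+                 *   int segmentIndex  = (int) (address >>> 32);   // getSegmentIndexFromPointer
+                 *   int segmentOffset = (int) address;            // getSegmentOffsetFromPointer
+                 *
+                 *   long header = ((long) recordLength << 32) | dataType.ordinal();
+                 *   int recordLength = (int) (header >>> 32);
+                 *   DataType type    = DataType.values()[(int) header];
+                 */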
+ if (bufferDataType.isEvent() && target.size() < length) { + throw new FlinkRuntimeException( + "Event is too big to be accommodated by a buffer"); + } + + numBytesCopied += + copyRecordOrEvent( + target, + numBytesCopied + offset, + sourceSegmentIndex, + sourceSegmentOffset, + length); + + if (recordRemainingBytes == 0) { + // move to next channel if the current channel has been finished + if (readIndexEntryAddress == lastIndexEntryAddresses[channelIndex]) { + updateReadChannelAndIndexEntryAddress(); + break; + } + readIndexEntryAddress = nextReadIndexEntryAddress; + } + } while (numBytesCopied < target.size() - offset && bufferDataType.isBuffer()); + + numTotalBytesRead += numBytesCopied; + Buffer buffer = + new NetworkBuffer(target, recycler, bufferDataType, numBytesCopied + offset); + return new BufferWithChannel(buffer, channelIndex); + } + } + + private int copyRecordOrEvent( + MemorySegment targetSegment, + int targetSegmentOffset, + int sourceSegmentIndex, + int sourceSegmentOffset, + int recordLength) { + if (recordRemainingBytes > 0) { + // skip the data already read if there is remaining partial record after the previous + // copy + long position = (long) sourceSegmentOffset + (recordLength - recordRemainingBytes); + sourceSegmentIndex += (position / bufferSize); + sourceSegmentOffset = (int) (position % bufferSize); + } else { + recordRemainingBytes = recordLength; + } + + int targetSegmentSize = targetSegment.size(); + int numBytesToCopy = + Math.min(targetSegmentSize - targetSegmentOffset, recordRemainingBytes); + do { + // move to next data buffer if all data of the current buffer has been copied + if (sourceSegmentOffset == bufferSize) { + ++sourceSegmentIndex; + sourceSegmentOffset = 0; + } + + int sourceRemainingBytes = + Math.min(bufferSize - sourceSegmentOffset, recordRemainingBytes); + int numBytes = Math.min(targetSegmentSize - targetSegmentOffset, sourceRemainingBytes); + MemorySegment sourceSegment = buffers.get(sourceSegmentIndex); + sourceSegment.copyTo(sourceSegmentOffset, targetSegment, targetSegmentOffset, numBytes); + + recordRemainingBytes -= numBytes; + targetSegmentOffset += numBytes; + sourceSegmentOffset += numBytes; + } while (recordRemainingBytes > 0 && targetSegmentOffset < targetSegmentSize); + + return numBytesToCopy; + } + + private void updateReadChannelAndIndexEntryAddress() { + // skip the channels without any data + while (++readOrderIndex < firstIndexEntryAddresses.length) { + int channelIndex = subpartitionReadOrder[readOrderIndex]; + if ((readIndexEntryAddress = firstIndexEntryAddresses[channelIndex]) >= 0) { + break; + } + } + } + + private int getSegmentIndexFromPointer(long value) { + return (int) (value >>> 32); + } + + private int getSegmentOffsetFromPointer(long value) { + return (int) (value); + } + + @Override + public long numRecords() { + return numTotalRecords; + } + + @Override + public long numBytes() { + return numTotalBytes; + } + + @Override + public boolean hasRemaining() { + return numTotalBytesRead < numTotalBytes; + } + + @Override + public void finish() { + checkState( + !isFinished, + "com.alibaba.flink.shuffle.plugin.transfer.SortBuffer is already finished."); + + isFinished = true; + + // prepare for reading + updateReadChannelAndIndexEntryAddress(); + } + + @Override + public boolean isFinished() { + return isFinished; + } + + @Override + public void release() { + // the sort buffer can be released by other threads + synchronized (lock) { + if (isReleased) { + return; + } + + isReleased = true; + + for (MemorySegment 
segment : buffers) { + bufferPool.recycle(segment); + } + buffers.clear(); + + numTotalBytes = 0; + numTotalRecords = 0; + } + } + + @Override + public boolean isReleased() { + synchronized (lock) { + return isReleased; + } + } +} diff --git a/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/RemoteShuffleInputGate.java b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/RemoteShuffleInputGate.java new file mode 100644 index 00000000..e879db37 --- /dev/null +++ b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/RemoteShuffleInputGate.java @@ -0,0 +1,757 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.plugin.transfer; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleWorkerDescriptor; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.plugin.RemoteShuffleDescriptor; +import com.alibaba.flink.shuffle.plugin.utils.BufferUtils; +import com.alibaba.flink.shuffle.transfer.ConnectionManager; +import com.alibaba.flink.shuffle.transfer.ShuffleReadClient; +import com.alibaba.flink.shuffle.transfer.TransferBufferPool; + +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.core.memory.MemorySegmentProvider; +import org.apache.flink.metrics.SimpleCounter; +import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; +import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo; +import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; +import org.apache.flink.runtime.event.AbstractEvent; +import org.apache.flink.runtime.event.TaskEvent; +import org.apache.flink.runtime.io.network.ConnectionID; +import org.apache.flink.runtime.io.network.LocalConnectionManager; +import org.apache.flink.runtime.io.network.api.CheckpointBarrier; +import org.apache.flink.runtime.io.network.api.EndOfData; +import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; +import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.BufferDecompressor; +import org.apache.flink.runtime.io.network.buffer.BufferPool; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; +import 
org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
+import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate;
+import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
+import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
+import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
+import org.apache.flink.util.CloseableIterator;
+import org.apache.flink.util.FlinkRuntimeException;
+import org.apache.flink.util.function.SupplierWithException;
+
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.concurrent.GuardedBy;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Optional;
+import java.util.Queue;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState;
+
+/** An {@link IndexedInputGate} which ingests data from remote shuffle workers. */
+public class RemoteShuffleInputGate extends IndexedInputGate {
+
+    private static final Logger LOG = LoggerFactory.getLogger(RemoteShuffleInputGate.class);
+
+    /** Lock to protect {@link #receivedBuffers}, {@link #cause} and {@link #closed}. */
+    private final Object lock = new Object();
+
+    /** Name of the corresponding computing task. */
+    private final String taskName;
+
+    /** Used to manage physical connections. */
+    private final ConnectionManager connectionManager;
+
+    /** Index of the gate of the corresponding computing task. */
+    private final int gateIndex;
+
+    /** Deployment descriptor for a single input gate instance. */
+    private final InputGateDeploymentDescriptor gateDescriptor;
+
+    /** Number of concurrent readings. */
+    private final int numConcurrentReading;
+
+    /** Buffer pool provider. */
+    private final SupplierWithException<BufferPool, IOException> bufferPoolFactory;
+
+    /** Flink buffer pools to allocate network memory. */
+    private BufferPool bufferPool;
+
+    /** Buffer pool used by the transfer layer. */
+    private final TransferBufferPool transferBufferPool =
+            new TransferBufferPool(Collections.emptySet());
+
+    /** Data decompressor. */
+    private final BufferDecompressor bufferDecompressor;
+
+    /** {@link InputChannelInfo}s to describe channels. */
+    private final List<InputChannelInfo> channelsInfo;
+
+    /** A {@link ShuffleReadClient} corresponds to a reading channel. */
+    private final List<ShuffleReadClient> shuffleReadClients = new ArrayList<>();
+
+    /** Map from channel index to shuffle client index. */
+    private final int[] clientIndexMap;
+
+    /** Map from shuffle client index to channel index. */
+    private final int[] channelIndexMap;
+
+    /** The number of subpartitions that have not been consumed, per channel. */
+    private final int[] numSubPartitionsHasNotConsumed;
+
+    /** The overall number of subpartitions that have not been consumed. */
+    private long numUnconsumedSubpartitions;
+
+    /** Received buffers from the remote shuffle workers. They are consumed by the upper computing task.
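+     * Buffers are added by the reading channels' data listeners and drained by the task thread;
+     * the handoff is guarded by {@link #lock}.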
*/
+    @GuardedBy("lock")
+    private final Queue<Pair<Buffer, InputChannelInfo>> receivedBuffers = new LinkedList<>();
+
+    /** {@link Throwable} set when reading fails. */
+    @GuardedBy("lock")
+    private Throwable cause;
+
+    /** Whether this remote input gate has been closed or not. */
+    @GuardedBy("lock")
+    private boolean closed;
+
+    /** Whether we have opened all initial channels or not. */
+    private boolean initialChannelsOpened;
+
+    /** Number of pending {@link EndOfData} events to be received. */
+    private long pendingEndOfDataEvents;
+
+    public RemoteShuffleInputGate(
+            String taskName,
+            boolean shuffleChannels,
+            int gateIndex,
+            int networkBufferSize,
+            InputGateDeploymentDescriptor gateDescriptor,
+            int numConcurrentReading,
+            ConnectionManager connectionManager,
+            SupplierWithException<BufferPool, IOException> bufferPoolFactory,
+            BufferDecompressor bufferDecompressor) {
+
+        this.taskName = taskName;
+        this.gateIndex = gateIndex;
+        this.gateDescriptor = gateDescriptor;
+        this.numConcurrentReading = numConcurrentReading;
+        this.connectionManager = connectionManager;
+        this.bufferPoolFactory = bufferPoolFactory;
+        this.bufferDecompressor = bufferDecompressor;
+
+        int numChannels = gateDescriptor.getShuffleDescriptors().length;
+        this.clientIndexMap = new int[numChannels];
+        this.channelIndexMap = new int[numChannels];
+        this.numSubPartitionsHasNotConsumed = new int[numChannels];
+        this.channelsInfo = createChannelInfos();
+        this.numUnconsumedSubpartitions =
+                initShuffleReadClients(networkBufferSize, shuffleChannels);
+        this.pendingEndOfDataEvents = numUnconsumedSubpartitions;
+    }
+
+    private long initShuffleReadClients(int bufferSize, boolean shuffleChannels) {
+        int startSubIdx = gateDescriptor.getConsumedSubpartitionIndex();
+        int endSubIdx = gateDescriptor.getConsumedSubpartitionIndex();
+        checkState(endSubIdx >= startSubIdx);
+        int numSubpartitionsPerChannel = endSubIdx - startSubIdx + 1;
+        long numUnconsumedSubpartitions = 0;
+
+        List<Pair<Integer, ShuffleDescriptor>> descriptors =
+                IntStream.range(0, gateDescriptor.getShuffleDescriptors().length)
+                        .mapToObj(i -> Pair.of(i, gateDescriptor.getShuffleDescriptors()[i]))
+                        .collect(Collectors.toList());
+        if (shuffleChannels) {
+            Collections.shuffle(descriptors);
+        }
+
+        int clientIndex = 0;
+        for (Pair<Integer, ShuffleDescriptor> descriptor : descriptors) {
+            RemoteShuffleDescriptor remoteDescriptor =
+                    (RemoteShuffleDescriptor) descriptor.getRight();
+            ShuffleWorkerDescriptor swd =
+                    remoteDescriptor.getShuffleResource().getMapPartitionLocation();
+            InetSocketAddress address =
+                    new InetSocketAddress(swd.getWorkerAddress(), swd.getDataPort());
+            LOG.debug(
+                    "Create DataPartitionReader [dataSetID: {}, resultPartitionID: {}, channelIdx: "
+                            + "{}, mapID: {}, startSubIdx: {}, endSubIdx {}, address: {}]",
+                    remoteDescriptor.getDataSetId(),
+                    remoteDescriptor.getResultPartitionID(),
+                    descriptor.getLeft(),
+                    remoteDescriptor.getDataPartitionID(),
+                    startSubIdx,
+                    endSubIdx,
+                    address);
+            ShuffleReadClient shuffleReadClient =
+                    createShuffleReadClient(
+                            connectionManager,
+                            address,
+                            remoteDescriptor.getDataSetId(),
+                            (MapPartitionID) remoteDescriptor.getDataPartitionID(),
+                            startSubIdx,
+                            endSubIdx,
+                            bufferSize,
+                            transferBufferPool,
+                            getDataListener(descriptor.getLeft()),
+                            getFailureListener(remoteDescriptor.getResultPartitionID()));
+
+            shuffleReadClients.add(shuffleReadClient);
+            numSubPartitionsHasNotConsumed[descriptor.getLeft()] = numSubpartitionsPerChannel;
+            numUnconsumedSubpartitions += numSubpartitionsPerChannel;
+            clientIndexMap[descriptor.getLeft()] = clientIndex;
+            channelIndexMap[clientIndex] = descriptor.getLeft();
+
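+            /* Illustration with assumed values: for 3 channels and a shuffled descriptor order of
+               [2, 0, 1], clients are created in that order, giving clientIndexMap = [1, 2, 0]
+               (channel -> client) and channelIndexMap = [2, 0, 1] (client -> channel); the two
+               arrays stay mutually inverse. */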
++clientIndex;
+        }
+        return numUnconsumedSubpartitions;
+    }
+
+    /** Sets up the gate and builds network connections. */
+    @Override
+    public void setup() throws IOException {
+        long startTime = System.nanoTime();
+
+        bufferPool = bufferPoolFactory.get();
+        BufferUtils.reserveNumRequiredBuffers(
+                bufferPool, RemoteShuffleInputGateFactory.MIN_BUFFERS_PER_GATE);
+
+        try {
+            for (int i = 0; i < gateDescriptor.getShuffleDescriptors().length; i++) {
+                shuffleReadClients.get(i).connect();
+            }
+        } catch (Throwable throwable) {
+            LOG.error("Failed to setup remote input gate.", throwable);
+            ExceptionUtils.rethrowAsRuntimeException(throwable);
+        }
+
+        tryRequestBuffers();
+        // Complete the availability future even though the handshake has not been fired yet, thus
+        // allowing the fetcher to 'pollNext' and fire the handshake to the remote side. This avoids
+        // bookkeeping remote reading resources before the task starts processing data from the
+        // input gate.
+        availabilityHelper.getUnavailableToResetAvailable().complete(null);
+        LOG.info("Set up read gate in {} ms.", (System.nanoTime() - startTime) / 1_000_000);
+    }
+
+    /** Index of the gate of the corresponding computing task. */
+    @Override
+    public int getGateIndex() {
+        return gateIndex;
+    }
+
+    /** Get number of input channels. A channel is a data flow from one shuffle worker. */
+    @Override
+    public int getNumberOfInputChannels() {
+        return channelsInfo.size();
+    }
+
+    /** Whether reading is finished -- all channels are finished and cached buffers are drained. */
+    @Override
+    public boolean isFinished() {
+        synchronized (lock) {
+            return allReadersEOF() && receivedBuffers.isEmpty();
+        }
+    }
+
+    @Override
+    public Optional<BufferOrEvent> getNext() {
+        throw new UnsupportedOperationException("Not implemented (DataSet API is not supported).");
+    }
+
+    /** Poll a received {@link BufferOrEvent}. */
+    @Override
+    public Optional<BufferOrEvent> pollNext() throws IOException {
+        if (!initialChannelsOpened) {
+            tryOpenSomeChannels();
+            initialChannelsOpened = true;
+            // DO NOT return here; 'getReceived' must still run because it manipulates
+            // 'availabilityHelper'.
+        }
+
+        Pair<Buffer, InputChannelInfo> pair = getReceived();
+        Optional<BufferOrEvent> bufferOrEvent = Optional.empty();
+        while (pair != null) {
+            Buffer buffer = pair.getLeft();
+            InputChannelInfo channelInfo = pair.getRight();
+
+            if (buffer.isBuffer()) {
+                bufferOrEvent = transformBuffer(buffer, channelInfo);
+            } else {
+                bufferOrEvent = transformEvent(buffer, channelInfo);
+            }
+
+            if (bufferOrEvent.isPresent()) {
+                break;
+            }
+            pair = getReceived();
+        }
+
+        tryRequestBuffers();
+        return bufferOrEvent;
+    }
+
+    /** Close all reading channels inside this {@link RemoteShuffleInputGate}. */
+    @Override
+    public void close() throws Exception {
+        List<Buffer> buffersToRecycle;
+        Throwable closeException = null;
+        synchronized (lock) {
+            // Do not check the closed flag, so that this method can be called from both the task
+            // thread and the cancel thread.
+            for (ShuffleReadClient shuffleReadClient : shuffleReadClients) {
+                try {
+                    shuffleReadClient.close();
+                } catch (Throwable throwable) {
+                    closeException = closeException == null ? throwable : closeException;
+                }
+            }
+            buffersToRecycle =
+                    receivedBuffers.stream().map(Pair::getLeft).collect(Collectors.toList());
+            receivedBuffers.clear();
+            closed = true;
+        }
+
+        buffersToRecycle.forEach(Buffer::recycleBuffer);
+        transferBufferPool.destroy();
+        if (bufferPool != null) {
+            bufferPool.lazyDestroy();
+        }
+
+        if (closeException != null) {
+            ExceptionUtils.rethrowException(closeException);
+        }
+    }
+
+    /** Get {@link InputChannelInfo}s of this {@link RemoteShuffleInputGate}.
*/
+    @Override
+    public List<InputChannelInfo> getChannelInfos() {
+        return channelsInfo;
+    }
+
+    /** Get all {@link ShuffleReadClient}s inside. Each one corresponds to a reading channel. */
+    public List<ShuffleReadClient> getShuffleReadClients() {
+        return shuffleReadClients;
+    }
+
+    private List<InputChannelInfo> createChannelInfos() {
+        return IntStream.range(0, gateDescriptor.getShuffleDescriptors().length)
+                .mapToObj(i -> new InputChannelInfo(gateIndex, i))
+                .collect(Collectors.toList());
+    }
+
+    ShuffleReadClient createShuffleReadClient(
+            ConnectionManager connectionManager,
+            InetSocketAddress address,
+            DataSetID dataSetID,
+            MapPartitionID mapID,
+            int startSubIdx,
+            int endSubIdx,
+            int bufferSize,
+            TransferBufferPool bufferPool,
+            Consumer<ByteBuf> dataListener,
+            Consumer<Throwable> failureListener) {
+        return new ShuffleReadClient(
+                address,
+                dataSetID,
+                mapID,
+                startSubIdx,
+                endSubIdx,
+                bufferSize,
+                bufferPool,
+                connectionManager,
+                dataListener,
+                failureListener);
+    }
+
+    /** Try to open more readers, up to {@link #numConcurrentReading}. */
+    private void tryOpenSomeChannels() throws IOException {
+        List<ShuffleReadClient> clientsToOpen = new ArrayList<>();
+        synchronized (lock) {
+            if (closed) {
+                throw new IOException("Input gate already closed.");
+            }
+
+            LOG.debug("Trying to open some partition readers.");
+            int numOnGoing = 0;
+            for (int i = 0; i < shuffleReadClients.size(); i++) {
+                ShuffleReadClient shuffleReadClient = shuffleReadClients.get(i);
+                LOG.debug(
+                        "Trying reader: {}, isOpened={}, numSubPartitionsHasNotConsumed={}.",
+                        shuffleReadClient,
+                        shuffleReadClient.isOpened(),
+                        numSubPartitionsHasNotConsumed[channelIndexMap[i]]);
+                if (numOnGoing >= numConcurrentReading) {
+                    break;
+                }
+
+                if (shuffleReadClient.isOpened()
+                        && numSubPartitionsHasNotConsumed[channelIndexMap[i]] > 0) {
+                    numOnGoing++;
+                    continue;
+                }
+
+                if (!shuffleReadClient.isOpened()) {
+                    clientsToOpen.add(shuffleReadClient);
+                    numOnGoing++;
+                }
+            }
+        }
+
+        for (ShuffleReadClient shuffleReadClient : clientsToOpen) {
+            shuffleReadClient.open();
+        }
+    }
+
+    private void tryRequestBuffers() {
+        checkState(bufferPool != null, "Not initialized yet.");
+
+        Buffer buffer;
+        List<ByteBuf> buffers = new ArrayList<>();
+        while ((buffer = bufferPool.requestBuffer()) != null) {
+            buffers.add(buffer.asByteBuf());
+        }
+
+        if (!buffers.isEmpty()) {
+            transferBufferPool.addBuffers(buffers);
+        }
+    }
+
+    private void onBuffer(Buffer buffer, int channelIdx) {
+        synchronized (lock) {
+            if (closed || cause != null) {
+                buffer.recycleBuffer();
+                throw new IllegalStateException("Input gate already closed or failed.");
+            }
+
+            boolean needRecycle = true;
+            try {
+                boolean wasEmpty = receivedBuffers.isEmpty();
+                InputChannelInfo channelInfo = channelsInfo.get(channelIdx);
+                checkState(
+                        channelInfo.getInputChannelIdx() == channelIdx, "Illegal channel index.");
+                receivedBuffers.add(Pair.of(buffer, channelInfo));
+                needRecycle = false;
+                if (wasEmpty) {
+                    availabilityHelper.getUnavailableToResetAvailable().complete(null);
+                }
+            } catch (Throwable throwable) {
+                if (needRecycle) {
+                    buffer.recycleBuffer();
+                }
+                throw throwable;
+            }
+        }
+    }
+
+    private Consumer<ByteBuf> getDataListener(int channelIdx) {
+        return byteBuf -> {
+            Queue<Buffer> unpackedBuffers = null;
+            try {
+                unpackedBuffers = BufferPacker.unpack(byteBuf);
+                while (!unpackedBuffers.isEmpty()) {
+                    onBuffer(unpackedBuffers.poll(), channelIdx);
+                }
+            } catch (Throwable throwable) {
+                synchronized (lock) {
+                    cause = cause == null ?
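+                    /* Only the first failure is kept: a non-null cause is never overwritten, so
+                       errors arriving after the gate has already failed are intentionally
+                       dropped. */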
throwable : cause;
+                    availabilityHelper.getUnavailableToResetAvailable().complete(null);
+                }
+
+                if (unpackedBuffers != null) {
+                    unpackedBuffers.forEach(Buffer::recycleBuffer);
+                }
+                LOG.error("Failed to process the received buffer.", throwable);
+            }
+        };
+    }
+
+    private Consumer<Throwable> getFailureListener(ResultPartitionID rpID) {
+        return throwable -> {
+            synchronized (lock) {
+                if (cause != null) {
+                    return;
+                }
+                Class<?> clazz =
+                        com.alibaba.flink.shuffle.core.exception.PartitionNotFoundException.class;
+                if (throwable.getMessage() != null
+                        && throwable.getMessage().contains(clazz.getName())) {
+                    cause = new PartitionNotFoundException(rpID, throwable.getMessage());
+                } else {
+                    cause = throwable;
+                }
+                availabilityHelper.getUnavailableToResetAvailable().complete(null);
+            }
+        };
+    }
+
+    private Pair<Buffer, InputChannelInfo> getReceived() throws IOException {
+        synchronized (lock) {
+            healthCheck();
+            if (!receivedBuffers.isEmpty()) {
+                return receivedBuffers.poll();
+            } else {
+                if (!allReadersEOF()) {
+                    availabilityHelper.resetUnavailable();
+                }
+                return null;
+            }
+        }
+    }
+
+    private void healthCheck() throws IOException {
+        if (closed) {
+            throw new IOException("Input gate already closed.");
+        }
+        if (cause != null) {
+            if (cause instanceof IOException) {
+                throw (IOException) cause;
+            } else {
+                throw new IOException(cause);
+            }
+        }
+    }
+
+    private boolean allReadersEOF() {
+        return numUnconsumedSubpartitions <= 0;
+    }
+
+    private Optional<BufferOrEvent> transformBuffer(Buffer buf, InputChannelInfo info)
+            throws IOException {
+        return Optional.of(
+                new BufferOrEvent(decompressBufferIfNeeded(buf), info, !isFinished(), false));
+    }
+
+    private Optional<BufferOrEvent> transformEvent(Buffer buffer, InputChannelInfo channelInfo)
+            throws IOException {
+        final AbstractEvent event;
+        try {
+            event = EventSerializer.fromBuffer(buffer, getClass().getClassLoader());
+        } catch (Throwable t) {
+            throw new IOException("Deserialize failure.", t);
+        } finally {
+            buffer.recycleBuffer();
+        }
+
+        if (event.getClass() == EndOfPartitionEvent.class) {
+            checkState(
+                    numSubPartitionsHasNotConsumed[channelInfo.getInputChannelIdx()] > 0,
+                    "BUG -- EndOfPartitionEvent received repeatedly.");
+            numSubPartitionsHasNotConsumed[channelInfo.getInputChannelIdx()]--;
+            numUnconsumedSubpartitions--;
+            // not the real end.
+            if (numSubPartitionsHasNotConsumed[channelInfo.getInputChannelIdx()] != 0) {
+                return Optional.empty();
+            } else {
+                // the real end.
+                shuffleReadClients.get(clientIndexMap[channelInfo.getInputChannelIdx()]).close();
+                tryOpenSomeChannels();
+                if (allReadersEOF()) {
+                    availabilityHelper.getUnavailableToResetAvailable().complete(null);
+                }
+            }
+        } else if (event.getClass() == EndOfData.class) {
+            CommonUtils.checkState(!hasReceivedEndOfData());
+            --pendingEndOfDataEvents;
+        }
+
+        return Optional.of(
+                new BufferOrEvent(
+                        event,
+                        buffer.getDataType().hasPriority(),
+                        channelInfo,
+                        !isFinished(),
+                        buffer.getSize(),
+                        false));
+    }
+
+    private Buffer decompressBufferIfNeeded(Buffer buffer) throws IOException {
+        if (buffer.isCompressed()) {
+            try {
+                checkState(bufferDecompressor != null, "Buffer decompressor not set.");
+                return bufferDecompressor.decompressToIntermediateBuffer(buffer);
+            } catch (Throwable t) {
+                throw new IOException("Decompress failure.", t);
+            } finally {
+                buffer.recycleBuffer();
+            }
+        }
+        return buffer;
+    }
+
+    @Override
+    public void requestPartitions() {
+        // do-nothing
+    }
+
+    @Override
+    public void checkpointStarted(CheckpointBarrier barrier) {
+        // do-nothing.
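+        // Remote shuffle serves batch (BLOCKING) result partitions, which take no part in
+        // checkpointing, so barrier notifications require no action here.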
+    }
+
+    @Override
+    public void checkpointStopped(long cancelledCheckpointId) {
+        // do-nothing.
+    }
+
+    @Override
+    public List<InputChannelInfo> getUnfinishedChannels() {
+        return Collections.emptyList();
+    }
+
+    @Override
+    public int getBuffersInUseCount() {
+        return 0;
+    }
+
+    @Override
+    public boolean hasReceivedEndOfData() {
+        return pendingEndOfDataEvents <= 0;
+    }
+
+    @Override
+    public void announceBufferSize(int bufferSize) {}
+
+    @Override
+    public void finishReadRecoveredState() {
+        // do-nothing.
+    }
+
+    @Override
+    public InputChannel getChannel(int channelIndex) {
+        return new FakedRemoteInputChannel(channelIndex);
+    }
+
+    @Override
+    public void sendTaskEvent(TaskEvent event) {
+        throw new FlinkRuntimeException("Method should not be called.");
+    }
+
+    @Override
+    public void resumeConsumption(InputChannelInfo channelInfo) {
+        throw new FlinkRuntimeException("Method should not be called.");
+    }
+
+    @Override
+    public void acknowledgeAllRecordsProcessed(InputChannelInfo inputChannelInfo) {}
+
+    @Override
+    public CompletableFuture<Void> getStateConsumedFuture() {
+        return CompletableFuture.completedFuture(null);
+    }
+
+    @Override
+    public String toString() {
+        return String.format(
+                "ReadGate [owning task: %s, gate index: %d, descriptor: %s]",
+                taskName, gateIndex, gateDescriptor.toString());
+    }
+
+    /** Accommodation for the incompleteness of Flink pluggable shuffle service. */
+    private class FakedRemoteInputChannel extends RemoteInputChannel {
+        FakedRemoteInputChannel(int channelIndex) {
+            super(
+                    new SingleInputGate(
+                            "",
+                            gateIndex,
+                            new IntermediateDataSetID(),
+                            ResultPartitionType.BLOCKING,
+                            0,
+                            1,
+                            (a, b, c) -> {},
+                            () -> null,
+                            null,
+                            new FakedMemorySegmentProvider(),
+                            0),
+                    channelIndex,
+                    new ResultPartitionID(),
+                    new ConnectionID(new InetSocketAddress("", 0), 0),
+                    new LocalConnectionManager(),
+                    0,
+                    0,
+                    0,
+                    new SimpleCounter(),
+                    new SimpleCounter(),
+                    new FakedChannelStateWriter());
+        }
+    }
+
+    /** Accommodation for the incompleteness of Flink pluggable shuffle service. */
+    private static class FakedMemorySegmentProvider implements MemorySegmentProvider {
+        @Override
+        public Collection<MemorySegment> requestMemorySegments(int i) {
+            return null;
+        }
+
+        @Override
+        public void recycleMemorySegments(Collection<MemorySegment> collection) {}
+    }
+
+    /** Accommodation for the incompleteness of Flink pluggable shuffle service. */
+    private static class FakedChannelStateWriter implements ChannelStateWriter {
+
+        @Override
+        public void start(long cpId, CheckpointOptions checkpointOptions) {}
+
+        @Override
+        public void addInputData(
+                long cpId,
+                InputChannelInfo info,
+                int startSeqNum,
+                CloseableIterator<Buffer> data) {}
+
+        @Override
+        public void addOutputData(
+                long cpId, ResultSubpartitionInfo info, int startSeqNum, Buffer...
data) {}
+
+        @Override
+        public void finishInput(long checkpointId) {}
+
+        @Override
+        public void finishOutput(long checkpointId) {}
+
+        @Override
+        public void abort(long checkpointId, Throwable cause, boolean cleanup) {}
+
+        @Override
+        public ChannelStateWriteResult getAndRemoveWriteResult(long checkpointId) {
+            return null;
+        }
+
+        @Override
+        public void close() {}
+    }
+}
diff --git a/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/RemoteShuffleInputGateFactory.java b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/RemoteShuffleInputGateFactory.java
new file mode 100644
index 00000000..4b3c9758
--- /dev/null
+++ b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/RemoteShuffleInputGateFactory.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.plugin.transfer;
+
+import com.alibaba.flink.shuffle.common.config.Configuration;
+import com.alibaba.flink.shuffle.common.config.MemorySize;
+import com.alibaba.flink.shuffle.common.exception.ConfigurationException;
+import com.alibaba.flink.shuffle.common.utils.CommonUtils;
+import com.alibaba.flink.shuffle.plugin.config.PluginOptions;
+import com.alibaba.flink.shuffle.transfer.ConnectionManager;
+
+import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
+import org.apache.flink.runtime.io.network.buffer.BufferDecompressor;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.BufferPoolFactory;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import org.apache.flink.util.function.SupplierWithException;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+
+/** Factory class to create {@link RemoteShuffleInputGate}. */
+public class RemoteShuffleInputGateFactory {
+
+    public static final int MIN_BUFFERS_PER_GATE = 16;
+
+    private static final Logger LOG = LoggerFactory.getLogger(RemoteShuffleInputGateFactory.class);
+
+    /** Maximum number of concurrently reading channels. */
+    private final int numConcurrentReading;
+
+    /** Codec used for compression / decompression. */
+    private final String compressionCodec;
+
+    /** Network buffer size. */
+    private final int networkBufferSize;
+
+    /**
+     * Network buffer pool used for shuffle read buffers. {@link BufferPool}s will be created from
+     * it and each of them will be used by a channel exclusively.
+     */
+    private final NetworkBufferPool networkBufferPool;
+
+    /** Total number of buffers per input gate. */
+    private final int numBuffersPerGate;
+
+    /** Whether to shuffle input channels before reading.
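+     * (Randomizing the channel order can help spread concurrent reads across shuffle workers.)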
*/
+    private final boolean shuffleChannels;
+
+    public RemoteShuffleInputGateFactory(
+            Configuration configuration,
+            NetworkBufferPool networkBufferPool,
+            int networkBufferSize,
+            String compressionCodec) {
+        MemorySize configuredMemorySize =
+                configuration.getMemorySize(PluginOptions.MEMORY_PER_INPUT_GATE);
+        if (configuredMemorySize.getBytes() < PluginOptions.MIN_MEMORY_PER_GATE.getBytes()) {
+            throw new ConfigurationException(
+                    String.format(
+                            "Insufficient network memory per input gate, please increase %s to at "
+                                    + "least %s.",
+                            PluginOptions.MEMORY_PER_INPUT_GATE.key(),
+                            PluginOptions.MIN_MEMORY_PER_GATE.toHumanReadableString()));
+        }
+
+        this.numBuffersPerGate =
+                CommonUtils.checkedDownCast(configuredMemorySize.getBytes() / networkBufferSize);
+        if (numBuffersPerGate < MIN_BUFFERS_PER_GATE) {
+            throw new ConfigurationException(
+                    String.format(
+                            "Insufficient network memory per input gate, please increase %s to at "
+                                    + "least %d bytes.",
+                            PluginOptions.MEMORY_PER_INPUT_GATE.key(),
+                            networkBufferSize * MIN_BUFFERS_PER_GATE));
+        }
+
+        this.shuffleChannels = configuration.getBoolean(PluginOptions.SHUFFLE_READING_CHANNELS);
+        this.compressionCodec = compressionCodec;
+        this.networkBufferSize = networkBufferSize;
+        this.numConcurrentReading = configuration.getInteger(PluginOptions.NUM_CONCURRENT_READINGS);
+        this.networkBufferPool = networkBufferPool;
+    }
+
+    /** Create {@link RemoteShuffleInputGate} from {@link InputGateDeploymentDescriptor}. */
+    public RemoteShuffleInputGate create(
+            String owningTaskName,
+            int gateIndex,
+            InputGateDeploymentDescriptor igdd,
+            ConnectionManager connectionManager) {
+        LOG.info(
+                "Create input gate -- number of buffers per input gate={}, "
+                        + "number of concurrent readings={}.",
+                numBuffersPerGate,
+                numConcurrentReading);
+
+        SupplierWithException<BufferPool, IOException> bufferPoolFactory =
+                createBufferPoolFactory(networkBufferPool, numBuffersPerGate);
+        BufferDecompressor bufferDecompressor =
+                new BufferDecompressor(networkBufferSize, compressionCodec);
+
+        return createInputGate(
+                owningTaskName,
+                shuffleChannels,
+                gateIndex,
+                igdd,
+                numConcurrentReading,
+                connectionManager,
+                bufferPoolFactory,
+                bufferDecompressor);
+    }
+
+    // For testing.
+    RemoteShuffleInputGate createInputGate(
+            String owningTaskName,
+            boolean shuffleChannels,
+            int gateIndex,
+            InputGateDeploymentDescriptor igdd,
+            int numConcurrentReading,
+            ConnectionManager connectionManager,
+            SupplierWithException<BufferPool, IOException> bufferPoolFactory,
+            BufferDecompressor bufferDecompressor) {
+        return new RemoteShuffleInputGate(
+                owningTaskName,
+                shuffleChannels,
+                gateIndex,
+                networkBufferSize,
+                igdd,
+                numConcurrentReading,
+                connectionManager,
+                bufferPoolFactory,
+                bufferDecompressor);
+    }
+
+    private SupplierWithException<BufferPool, IOException> createBufferPoolFactory(
+            BufferPoolFactory bufferPoolFactory, int numBuffers) {
+        return () -> bufferPoolFactory.createBufferPool(numBuffers, numBuffers);
+    }
+}
diff --git a/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/RemoteShuffleOutputGate.java b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/RemoteShuffleOutputGate.java
new file mode 100644
index 00000000..33425c05
--- /dev/null
+++ b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/RemoteShuffleOutputGate.java
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.plugin.transfer;
+
+import com.alibaba.flink.shuffle.common.utils.CommonUtils;
+import com.alibaba.flink.shuffle.coordinator.manager.DefaultShuffleResource;
+import com.alibaba.flink.shuffle.coordinator.manager.ShuffleWorkerDescriptor;
+import com.alibaba.flink.shuffle.core.ids.DataSetID;
+import com.alibaba.flink.shuffle.core.ids.JobID;
+import com.alibaba.flink.shuffle.core.ids.MapPartitionID;
+import com.alibaba.flink.shuffle.plugin.RemoteShuffleDescriptor;
+import com.alibaba.flink.shuffle.plugin.utils.BufferUtils;
+import com.alibaba.flink.shuffle.transfer.ConnectionManager;
+import com.alibaba.flink.shuffle.transfer.ShuffleWriteClient;
+
+import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
+import org.apache.flink.util.function.SupplierWithException;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+
+/**
+ * A transportation gate used to spill buffers from {@link ResultPartitionWriter} to remote shuffle
+ * worker.
+ */
+public class RemoteShuffleOutputGate {
+
+    /** A {@link ShuffleDescriptor} which describes shuffle meta and shuffle worker address. */
+    private final RemoteShuffleDescriptor shuffleDesc;
+
+    /** Number of subpartitions of the corresponding {@link ResultPartitionWriter}. */
+    protected final int numSubs;
+
+    /** Used to transport data to a shuffle worker. */
+    private final ShuffleWriteClient shuffleWriteClient;
+
+    /** Used to consolidate buffers. */
+    private final BufferPacker bufferPacker;
+
+    /** {@link BufferPool} provider. */
+    protected final SupplierWithException<BufferPool, IOException> bufferPoolFactory;
+
+    /** Provides buffers to hold data to send online by Netty layer. */
+    protected BufferPool bufferPool;
+
+    /**
+     * @param shuffleDesc Describes shuffle meta and shuffle worker address.
+     * @param numSubs Number of subpartitions of the corresponding {@link ResultPartitionWriter}.
+     * @param bufferSize Size of buffers used for data transmission.
+     * @param dataPartitionFactoryName Name of the remote data partition factory to use.
+     * @param bufferPoolFactory {@link BufferPool} provider.
+     * @param connectionManager Manages physical connections.
+     */
+    public RemoteShuffleOutputGate(
+            RemoteShuffleDescriptor shuffleDesc,
+            int numSubs,
+            int bufferSize,
+            String dataPartitionFactoryName,
+            SupplierWithException<BufferPool, IOException> bufferPoolFactory,
+            ConnectionManager connectionManager) {
+
+        this.shuffleDesc = shuffleDesc;
+        this.numSubs = numSubs;
+        this.bufferPoolFactory = bufferPoolFactory;
+        this.shuffleWriteClient =
+                createWriteClient(bufferSize, dataPartitionFactoryName, connectionManager);
+        this.bufferPacker = new BufferPacker(shuffleWriteClient::write);
+    }
+
+    /** Initialize transportation gate.
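+     *
+     * <p>For reference, a typical write lifecycle of this gate is sketched below (caller code is
+     * illustrative; the gate and buffer are assumed to be configured elsewhere):
+     *
+     * <pre>{@code
+     * gate.setup();                // reserve buffers and open the shuffle write client
+     * gate.regionStart(false);     // begin a non-broadcast region
+     * gate.write(buffer, subIdx);  // spill the region's buffers
+     * gate.regionFinish();         // drain pending buffers and close the region
+     * gate.finish();               // mark the whole partition as written
+     * gate.close();                // release buffers and the connection
+     * }</pre>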
*/
+    public void setup() throws IOException, InterruptedException {
+        bufferPool = CommonUtils.checkNotNull(bufferPoolFactory.get());
+        CommonUtils.checkArgument(
+                bufferPool.getNumberOfRequiredMemorySegments() >= 2,
+                "Too few buffers for transfer, at least 2 buffers are required.");
+
+        // guarantee that we have at least one buffer
+        BufferUtils.reserveNumRequiredBuffers(bufferPool, 1);
+
+        shuffleWriteClient.open();
+    }
+
+    /** Get transportation buffer pool. */
+    public BufferPool getBufferPool() {
+        return bufferPool;
+    }
+
+    /** Writes a {@link Buffer} to a subpartition. */
+    public void write(Buffer buffer, int subIdx) throws InterruptedException {
+        bufferPacker.process(buffer, subIdx);
+    }
+
+    /**
+     * Indicates the start of a region. A region of buffers guarantees that the records inside are
+     * complete.
+     *
+     * @param isBroadcast Whether it's a broadcast region.
+     */
+    public void regionStart(boolean isBroadcast) {
+        shuffleWriteClient.regionStart(isBroadcast);
+    }
+
+    /**
+     * Indicates the finish of a region. A region is always bounded by a pair of region-start and
+     * region-finish.
+     */
+    public void regionFinish() throws InterruptedException {
+        bufferPacker.drain();
+        shuffleWriteClient.regionFinish();
+    }
+
+    /** Indicates the writing/spilling is finished. */
+    public void finish() throws InterruptedException {
+        shuffleWriteClient.finish();
+    }
+
+    /** Close the transportation gate. */
+    public void close() throws IOException {
+        if (bufferPool != null) {
+            bufferPool.lazyDestroy();
+        }
+        bufferPacker.close();
+        shuffleWriteClient.close();
+    }
+
+    /** Returns shuffle descriptor. */
+    public RemoteShuffleDescriptor getShuffleDesc() {
+        return shuffleDesc;
+    }
+
+    private ShuffleWriteClient createWriteClient(
+            int bufferSize, String dataPartitionFactoryName, ConnectionManager connectionManager) {
+        JobID jobID = shuffleDesc.getJobId();
+        DataSetID dataSetID = shuffleDesc.getDataSetId();
+        MapPartitionID mapID = (MapPartitionID) shuffleDesc.getDataPartitionID();
+        ShuffleWorkerDescriptor swd =
+                ((DefaultShuffleResource) shuffleDesc.getShuffleResource())
+                        .getMapPartitionLocation();
+        InetSocketAddress address =
+                new InetSocketAddress(swd.getWorkerAddress(), swd.getDataPort());
+        return new ShuffleWriteClient(
+                address,
+                jobID,
+                dataSetID,
+                mapID,
+                numSubs,
+                bufferSize,
+                dataPartitionFactoryName,
+                connectionManager);
+    }
+}
diff --git a/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/RemoteShuffleResultPartition.java b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/RemoteShuffleResultPartition.java
new file mode 100644
index 00000000..7413ed43
--- /dev/null
+++ b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/RemoteShuffleResultPartition.java
@@ -0,0 +1,405 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.plugin.transfer;
+
+import com.alibaba.flink.shuffle.common.exception.ShuffleException;
+import com.alibaba.flink.shuffle.common.utils.ExceptionUtils;
+import com.alibaba.flink.shuffle.plugin.utils.BufferUtils;
+
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.event.AbstractEvent;
+import org.apache.flink.runtime.io.network.api.EndOfData;
+import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.Buffer.DataType;
+import org.apache.flink.runtime.io.network.buffer.BufferCompressor;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener;
+import org.apache.flink.runtime.io.network.partition.ResultPartition;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView;
+import org.apache.flink.util.FlinkRuntimeException;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.function.SupplierWithException;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.concurrent.CompletableFuture;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull;
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState;
+
+/**
+ * A {@link ResultPartition} which appends records and events to a {@link SortBuffer}. Once the
+ * {@link SortBuffer} is full, all of its data is copied and spilled to the remote shuffle service
+ * sequentially, in subpartition index order. Large records that can not be appended to an empty
+ * {@link SortBuffer} will be spilled directly.
+ */
+public class RemoteShuffleResultPartition extends ResultPartition {
+
+    private static final Logger LOG = LoggerFactory.getLogger(RemoteShuffleResultPartition.class);
+
+    /** Size of network buffer and write buffer. */
+    private final int networkBufferSize;
+
+    /** {@link SortBuffer} for records sent by {@link #broadcastRecord(ByteBuffer)}. */
+    private SortBuffer broadcastSortBuffer;
+
+    /** {@link SortBuffer} for records sent by {@link #emitRecord(ByteBuffer, int)}. */
+    private SortBuffer unicastSortBuffer;
+
+    /** Utility to spill data to shuffle workers. */
+    private final RemoteShuffleOutputGate outputGate;
+
+    /** Whether {@link #notifyEndOfData()} has been called or not.
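+     * Guarantees that {@link EndOfData} is broadcast at most once.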
*/
+    private boolean endOfDataNotified;
+
+    public RemoteShuffleResultPartition(
+            String owningTaskName,
+            int partitionIndex,
+            ResultPartitionID partitionId,
+            ResultPartitionType partitionType,
+            int numSubpartitions,
+            int numTargetKeyGroups,
+            int networkBufferSize,
+            ResultPartitionManager partitionManager,
+            @Nullable BufferCompressor bufferCompressor,
+            SupplierWithException<BufferPool, IOException> bufferPoolFactory,
+            RemoteShuffleOutputGate outputGate) {
+
+        super(
+                owningTaskName,
+                partitionIndex,
+                partitionId,
+                partitionType,
+                numSubpartitions,
+                numTargetKeyGroups,
+                partitionManager,
+                bufferCompressor,
+                bufferPoolFactory);
+
+        this.networkBufferSize = networkBufferSize;
+        this.outputGate = outputGate;
+    }
+
+    @Override
+    public void setup() throws IOException {
+        LOG.info("Setup {}", this);
+        super.setup();
+        BufferUtils.reserveNumRequiredBuffers(bufferPool, 1);
+        try {
+            outputGate.setup();
+        } catch (Throwable throwable) {
+            LOG.error("Failed to setup remote output gate.", throwable);
+            ExceptionUtils.rethrowAsRuntimeException(throwable);
+        }
+    }
+
+    @Override
+    public void emitRecord(ByteBuffer record, int targetSubpartition) throws IOException {
+        emit(record, targetSubpartition, DataType.DATA_BUFFER, false);
+    }
+
+    @Override
+    public void broadcastRecord(ByteBuffer record) throws IOException {
+        broadcast(record, DataType.DATA_BUFFER);
+    }
+
+    @Override
+    public void broadcastEvent(AbstractEvent event, boolean isPriorityEvent) throws IOException {
+        Buffer buffer = EventSerializer.toBuffer(event, isPriorityEvent);
+        try {
+            ByteBuffer serializedEvent = buffer.getNioBufferReadable();
+            broadcast(serializedEvent, buffer.getDataType());
+        } finally {
+            buffer.recycleBuffer();
+        }
+    }
+
+    private void broadcast(ByteBuffer record, DataType dataType) throws IOException {
+        emit(record, 0, dataType, true);
+    }
+
+    private void emit(
+            ByteBuffer record, int targetSubpartition, DataType dataType, boolean isBroadcast)
+            throws IOException {
+
+        checkInProduceState();
+        if (isBroadcast) {
+            Preconditions.checkState(
+                    targetSubpartition == 0,
+                    "Target subpartition index can only be 0 when broadcast.");
+        }
+
+        SortBuffer sortBuffer = isBroadcast ?
getBroadcastSortBuffer() : getUnicastSortBuffer(); + if (sortBuffer.append(record, targetSubpartition, dataType)) { + return; + } + + try { + if (!sortBuffer.hasRemaining()) { + // the record can not be appended to the free sort buffer because it is too large + sortBuffer.finish(); + sortBuffer.release(); + writeLargeRecord(record, targetSubpartition, dataType, isBroadcast); + return; + } + flushSortBuffer(sortBuffer, isBroadcast); + } catch (InterruptedException e) { + LOG.error("Failed to flush the sort buffer.", e); + ExceptionUtils.rethrowAsRuntimeException(e); + } + emit(record, targetSubpartition, dataType, isBroadcast); + } + + private void releaseSortBuffer(SortBuffer sortBuffer) { + if (sortBuffer != null) { + sortBuffer.release(); + } + } + + private SortBuffer getUnicastSortBuffer() throws IOException { + flushBroadcastSortBuffer(); + + if (unicastSortBuffer != null && !unicastSortBuffer.isFinished()) { + return unicastSortBuffer; + } + + unicastSortBuffer = + new PartitionSortedBuffer(bufferPool, numSubpartitions, networkBufferSize, null); + return unicastSortBuffer; + } + + private SortBuffer getBroadcastSortBuffer() throws IOException { + flushUnicastSortBuffer(); + + if (broadcastSortBuffer != null && !broadcastSortBuffer.isFinished()) { + return broadcastSortBuffer; + } + + broadcastSortBuffer = + new PartitionSortedBuffer(bufferPool, numSubpartitions, networkBufferSize, null); + return broadcastSortBuffer; + } + + private void flushBroadcastSortBuffer() throws IOException { + flushSortBuffer(broadcastSortBuffer, true); + } + + private void flushUnicastSortBuffer() throws IOException { + flushSortBuffer(unicastSortBuffer, false); + } + + private void flushSortBuffer(SortBuffer sortBuffer, boolean isBroadcast) throws IOException { + if (sortBuffer == null || sortBuffer.isReleased()) { + return; + } + sortBuffer.finish(); + if (sortBuffer.hasRemaining()) { + try { + outputGate.regionStart(isBroadcast); + while (sortBuffer.hasRemaining()) { + MemorySegment segment = + outputGate.getBufferPool().requestMemorySegmentBlocking(); + SortBuffer.BufferWithChannel bufferWithChannel; + try { + bufferWithChannel = + sortBuffer.copyIntoSegment( + segment, + outputGate.getBufferPool(), + BufferUtils.HEADER_LENGTH); + } catch (Throwable t) { + outputGate.getBufferPool().recycle(segment); + throw new FlinkRuntimeException("Shuffle write failure.", t); + } + + Buffer buffer = bufferWithChannel.getBuffer(); + int subpartitionIndex = bufferWithChannel.getChannelIndex(); + updateStatistics(bufferWithChannel.getBuffer()); + writeCompressedBufferIfPossible(buffer, subpartitionIndex); + } + outputGate.regionFinish(); + } catch (InterruptedException e) { + throw new IOException( + "Failed to flush the sort buffer, broadcast=" + isBroadcast, e); + } + } + releaseSortBuffer(sortBuffer); + } + + private void writeCompressedBufferIfPossible(Buffer buffer, int targetSubpartition) + throws InterruptedException { + Buffer compressedBuffer = null; + try { + if (canBeCompressed(buffer)) { + Buffer dataBuffer = + buffer.readOnlySlice( + BufferUtils.HEADER_LENGTH, + buffer.getSize() - BufferUtils.HEADER_LENGTH); + compressedBuffer = + checkNotNull(bufferCompressor).compressToIntermediateBuffer(dataBuffer); + } + BufferUtils.setCompressedDataWithHeader(buffer, compressedBuffer); + } catch (Throwable throwable) { + buffer.recycleBuffer(); + throw new ShuffleException("Shuffle write failure.", throwable); + } finally { + if (compressedBuffer != null && compressedBuffer.isCompressed()) { + 
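+                // The compressor returns a reused intermediate buffer whose bytes have already
+                // been copied into 'buffer' together with the header, so rewind its reader index
+                // and recycle it for the next compression.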
compressedBuffer.setReaderIndex(0); + compressedBuffer.recycleBuffer(); + } + } + outputGate.write(buffer, targetSubpartition); + } + + private void updateStatistics(Buffer buffer) { + numBuffersOut.inc(); + numBytesOut.inc(buffer.readableBytes() - BufferUtils.HEADER_LENGTH); + } + + /** Spills the large record into {@link RemoteShuffleOutputGate}. */ + private void writeLargeRecord( + ByteBuffer record, int targetSubpartition, DataType dataType, boolean isBroadcast) + throws InterruptedException { + + outputGate.regionStart(isBroadcast); + while (record.hasRemaining()) { + MemorySegment writeBuffer = outputGate.getBufferPool().requestMemorySegmentBlocking(); + int toCopy = + Math.min(record.remaining(), writeBuffer.size() - BufferUtils.HEADER_LENGTH); + writeBuffer.put(BufferUtils.HEADER_LENGTH, record, toCopy); + NetworkBuffer buffer = + new NetworkBuffer( + writeBuffer, + outputGate.getBufferPool(), + dataType, + toCopy + BufferUtils.HEADER_LENGTH); + + updateStatistics(buffer); + writeCompressedBufferIfPossible(buffer, targetSubpartition); + } + outputGate.regionFinish(); + } + + @Override + public void finish() throws IOException { + checkState(!isReleased(), "Result partition is already released."); + broadcastEvent(EndOfPartitionEvent.INSTANCE, false); + checkState( + unicastSortBuffer == null || unicastSortBuffer.isReleased(), + "The unicast sort buffer should be either null or released."); + flushBroadcastSortBuffer(); + try { + outputGate.finish(); + } catch (InterruptedException e) { + throw new IOException("Output gate fails to finish.", e); + } + super.finish(); + } + + @Override + public synchronized void close() { + releaseSortBuffer(unicastSortBuffer); + releaseSortBuffer(broadcastSortBuffer); + super.close(); + try { + outputGate.close(); + } catch (Exception e) { + ExceptionUtils.rethrowAsRuntimeException(e); + } + } + + @Override + protected void releaseInternal() { + // no-op + } + + @Override + public void flushAll() { + try { + flushUnicastSortBuffer(); + flushBroadcastSortBuffer(); + } catch (Throwable t) { + LOG.error("Failed to flush the current sort buffer.", t); + ExceptionUtils.rethrowAsRuntimeException(t); + } + } + + @Override + public void flush(int subpartitionIndex) { + flushAll(); + } + + @Override + public CompletableFuture getAvailableFuture() { + return AVAILABLE; + } + + @Override + public int getNumberOfQueuedBuffers() { + return 0; + } + + @Override + public int getNumberOfQueuedBuffers(int targetSubpartition) { + return 0; + } + + @Override + public ResultSubpartitionView createSubpartitionView( + int index, BufferAvailabilityListener availabilityListener) { + throw new UnsupportedOperationException("Not supported."); + } + + @Override + public void notifyEndOfData() throws IOException { + if (!endOfDataNotified) { + broadcastEvent(EndOfData.INSTANCE, false); + endOfDataNotified = true; + } + } + + @Override + public CompletableFuture getAllDataProcessedFuture() { + return CompletableFuture.completedFuture(null); + } + + @Override + public String toString() { + return "ResultPartition " + + partitionId.toString() + + " [" + + partitionType + + ", " + + numSubpartitions + + " subpartitions, shuffle-descriptor: " + + outputGate.getShuffleDesc() + + "]"; + } +} diff --git a/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/RemoteShuffleResultPartitionFactory.java b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/RemoteShuffleResultPartitionFactory.java new file mode 100644 index 00000000..3909a9d5 --- 
/dev/null +++ b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/RemoteShuffleResultPartitionFactory.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.plugin.transfer; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.config.MemorySize; +import com.alibaba.flink.shuffle.common.exception.ConfigurationException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.plugin.RemoteShuffleDescriptor; +import com.alibaba.flink.shuffle.plugin.config.PluginOptions; +import com.alibaba.flink.shuffle.transfer.ConnectionManager; + +import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; +import org.apache.flink.runtime.io.network.buffer.BufferCompressor; +import org.apache.flink.runtime.io.network.buffer.BufferPool; +import org.apache.flink.runtime.io.network.buffer.BufferPoolFactory; +import org.apache.flink.runtime.io.network.partition.ResultPartition; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.io.network.partition.ResultPartitionManager; +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; +import org.apache.flink.runtime.shuffle.ShuffleDescriptor; +import org.apache.flink.util.function.SupplierWithException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** Factory class to create {@link RemoteShuffleResultPartition}. */ +public class RemoteShuffleResultPartitionFactory { + + private static final Logger LOG = + LoggerFactory.getLogger(RemoteShuffleResultPartitionFactory.class); + + public static final int MIN_BUFFERS_PER_PARTITION = 16; + + /** Not used; kept only for compatibility with the Flink pluggable shuffle service. */ + private final ResultPartitionManager partitionManager; + + /** Network buffer pool used for shuffle write buffers. */ + private final BufferPoolFactory bufferPoolFactory; + + /** Network buffer size. */ + private final int networkBufferSize; + + /** Name of the remote data partition factory. */ + private final String dataPartitionFactoryName; + + /** Whether compression is enabled. */ + private final boolean compressionEnabled; + + /** Codec used for compression / decompression. */ + private final String compressionCodec; + + /** + * Configured number of buffers for shuffle write; it consists of two parts: sorting buffers and + * transportation buffers.
+ */ + private final int numBuffersPerPartition; + + public RemoteShuffleResultPartitionFactory( + Configuration configuration, + ResultPartitionManager partitionManager, + BufferPoolFactory bufferPoolFactory, + int networkBufferSize, + String compressionCodec) { + MemorySize configuredMemorySize = + configuration.getMemorySize(PluginOptions.MEMORY_PER_RESULT_PARTITION); + if (configuredMemorySize.getBytes() < PluginOptions.MIN_MEMORY_PER_PARTITION.getBytes()) { + throw new ConfigurationException( + String.format( + "Insufficient network memory per result partition, please increase %s " + + "to at least %s.", + PluginOptions.MEMORY_PER_RESULT_PARTITION.key(), + PluginOptions.MIN_MEMORY_PER_PARTITION.toHumanReadableString())); + } + + this.numBuffersPerPartition = + CommonUtils.checkedDownCast(configuredMemorySize.getBytes() / networkBufferSize); + if (numBuffersPerPartition < MIN_BUFFERS_PER_PARTITION) { + throw new ConfigurationException( + String.format( + "Insufficient network memory per partition, please increase %s to at " + + "least %d bytes.", + PluginOptions.MEMORY_PER_RESULT_PARTITION.key(), + networkBufferSize * MIN_BUFFERS_PER_PARTITION)); + } + + this.compressionEnabled = configuration.getBoolean(PluginOptions.ENABLE_DATA_COMPRESSION); + this.dataPartitionFactoryName = + configuration.getString(PluginOptions.DATA_PARTITION_FACTORY_NAME); + + this.partitionManager = partitionManager; + this.bufferPoolFactory = bufferPoolFactory; + this.networkBufferSize = networkBufferSize; + this.compressionCodec = compressionCodec; + } + + public ResultPartition create( + String taskNameWithSubtaskAndId, + int partitionIndex, + ResultPartitionDeploymentDescriptor desc, + ConnectionManager connectionManager) { + LOG.info( + "Create result partition -- number of buffers per result partition={}, " + + "number of subpartitions={}.", + numBuffersPerPartition, + desc.getNumberOfSubpartitions()); + + return create( + taskNameWithSubtaskAndId, + partitionIndex, + desc.getShuffleDescriptor().getResultPartitionID(), + desc.getPartitionType(), + desc.getNumberOfSubpartitions(), + desc.getMaxParallelism(), + createBufferPoolFactory(), + desc.getShuffleDescriptor(), + connectionManager); + } + + private ResultPartition create( + String taskNameWithSubtaskAndId, + int partitionIndex, + ResultPartitionID id, + ResultPartitionType type, + int numSubpartitions, + int maxParallelism, + List> bufferPoolFactories, + ShuffleDescriptor shuffleDescriptor, + ConnectionManager connectionManager) { + + final BufferCompressor bufferCompressor; + if (compressionEnabled) { + bufferCompressor = new BufferCompressor(networkBufferSize, compressionCodec); + } else { + bufferCompressor = null; + } + RemoteShuffleDescriptor rsd = (RemoteShuffleDescriptor) shuffleDescriptor; + ResultPartition partition = + new RemoteShuffleResultPartition( + taskNameWithSubtaskAndId, + partitionIndex, + id, + type, + numSubpartitions, + maxParallelism, + networkBufferSize, + partitionManager, + bufferCompressor, + bufferPoolFactories.get(0), + new RemoteShuffleOutputGate( + rsd, + numSubpartitions, + networkBufferSize, + dataPartitionFactoryName, + bufferPoolFactories.get(1), + connectionManager)); + LOG.debug("{}: Initialized {}", taskNameWithSubtaskAndId, this); + return partition; + } + + /** + * Used to create 2 buffer pools -- sorting buffer pool (7/8), transportation buffer pool (1/8). 
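To make the sizing above concrete, the factory's buffer budgeting reduces to plain integer arithmetic over the configured memory; a minimal standalone sketch with illustrative numbers (the real values come from PluginOptions, not the constants hard-coded here):

    // Illustrative buffer budgeting (sketch, not part of the patch).
    long memoryPerResultPartition = 64 * 1024 * 1024; // e.g. 64 MB configured per partition
    int networkBufferSize = 32 * 1024;                // e.g. 32 KB network buffers

    int numBuffersPerPartition = (int) (memoryPerResultPartition / networkBufferSize); // 2048
    if (numBuffersPerPartition < 16) { // MIN_BUFFERS_PER_PARTITION
        throw new IllegalArgumentException("Insufficient network memory per partition.");
    }

    // Split implemented by createBufferPoolFactory() below: 7/8 for sorting, 1/8 for transport.
    int numForResultPartition = numBuffersPerPartition * 7 / 8;            // 1792
    int numForOutputGate = numBuffersPerPartition - numForResultPartition; // 256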
+ */ + private List<SupplierWithException<BufferPool, IOException>> createBufferPoolFactory() { + int numForResultPartition = numBuffersPerPartition * 7 / 8; + int numForOutputGate = numBuffersPerPartition - numForResultPartition; + + List<SupplierWithException<BufferPool, IOException>> factories = new ArrayList<>(); + factories.add( + () -> + bufferPoolFactory.createBufferPool( + numForResultPartition, numForResultPartition)); + factories.add(() -> bufferPoolFactory.createBufferPool(numForOutputGate, numForOutputGate)); + return factories; + } +} diff --git a/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/SortBuffer.java b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/SortBuffer.java new file mode 100644 index 00000000..285e7045 --- /dev/null +++ b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/transfer/SortBuffer.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.plugin.transfer; + +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.BufferRecycler; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** + * Data of different channels can be appended to a {@link SortBuffer}; after appending is finished, + * data can be copied from it in channel index order. + */ +public interface SortBuffer { + + /** + * Appends data of the specified channel to this {@link SortBuffer} and returns true if all + * bytes of the source buffer are copied to this {@link SortBuffer} successfully; otherwise, if + * it returns false, nothing will be copied. + */ + boolean append(ByteBuffer source, int targetChannel, Buffer.DataType dataType) + throws IOException; + + /** + * Copies data from this {@link SortBuffer} to the target {@link MemorySegment} in channel index + * order and returns {@link BufferWithChannel} which contains the copied data and the + * corresponding channel index. + */ + BufferWithChannel copyIntoSegment(MemorySegment target, BufferRecycler recycler, int offset); + + /** Returns the number of records written to this {@link SortBuffer}. */ + long numRecords(); + + /** Returns the number of bytes written to this {@link SortBuffer}. */ + long numBytes(); + + /** Returns true if there is still data that can be consumed in this {@link SortBuffer}. */ + boolean hasRemaining(); + + /** Finishes this {@link SortBuffer}, which means no more records can be appended. */ + void finish(); + + /** Whether this {@link SortBuffer} is finished or not. */ + boolean isFinished(); + + /** Releases this {@link SortBuffer} which releases all resources. */ + void release(); + + /** Whether this {@link SortBuffer} is released or not.
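The write path drives this interface in a fixed append/finish/drain cycle; a condensed sketch of that producer protocol, mirroring RemoteShuffleResultPartition#flushSortBuffer above (pool handling and error paths omitted, names illustrative):

    // Sketch of the SortBuffer drain protocol (illustrative, not part of the patch).
    static void drain(SortBuffer sortBuffer, BufferPool pool) throws Exception {
        sortBuffer.finish(); // seal: no more appends accepted
        while (sortBuffer.hasRemaining()) {
            MemorySegment segment = pool.requestMemorySegmentBlocking();
            SortBuffer.BufferWithChannel bwc =
                    sortBuffer.copyIntoSegment(segment, pool, 6 /* BufferUtils.HEADER_LENGTH */);
            // ship bwc.getBuffer() to subpartition bwc.getChannelIndex(), then recycle
            bwc.getBuffer().recycleBuffer();
        }
        sortBuffer.release();
    }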
*/ + boolean isReleased(); + + /** Buffer and the corresponding channel index returned to reader. */ + class BufferWithChannel { + + private final Buffer buffer; + + private final int channelIndex; + + BufferWithChannel(Buffer buffer, int channelIndex) { + this.buffer = checkNotNull(buffer); + this.channelIndex = channelIndex; + } + + /** Get {@link Buffer}. */ + public Buffer getBuffer() { + return buffer; + } + + /** Get channel index. */ + public int getChannelIndex() { + return channelIndex; + } + } +} diff --git a/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/utils/BufferUtils.java b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/utils/BufferUtils.java new file mode 100644 index 00000000..7f393d02 --- /dev/null +++ b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/utils/BufferUtils.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.plugin.utils; + +import com.alibaba.flink.shuffle.plugin.transfer.BufferHeader; + +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.BufferPool; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument; + +/** Utility methods to process flink buffers. */ +public class BufferUtils { + + // dataType(1) + isCompressed(1) + bufferSize(4) + public static final int HEADER_LENGTH = 1 + 1 + 4; + + /** + * Copies the data of the compressed buffer and the corresponding buffer header to the origin + * buffer. The origin buffer must reserve the {@link #HEADER_LENGTH} space for the header data. + */ + public static void setCompressedDataWithHeader(Buffer buffer, Buffer compressedBuffer) { + checkArgument(buffer != null, "Must be not null."); + checkArgument(buffer.getReaderIndex() == 0, "Illegal reader index."); + + boolean isCompressed = compressedBuffer != null && compressedBuffer.isCompressed(); + int dataLength = + isCompressed + ? 
compressedBuffer.readableBytes() + : buffer.readableBytes() - HEADER_LENGTH; + ByteBuf byteBuf = buffer.asByteBuf(); + setBufferHeader(byteBuf, buffer.getDataType(), isCompressed, dataLength); + + if (isCompressed) { + byteBuf.writeBytes(compressedBuffer.asByteBuf()); + } + buffer.setSize(dataLength + HEADER_LENGTH); + } + + public static void setBufferHeader( + ByteBuf byteBuf, Buffer.DataType dataType, boolean isCompressed, int dataLength) { + byteBuf.writerIndex(0); + byteBuf.writeByte(dataType.ordinal()); + byteBuf.writeBoolean(isCompressed); + byteBuf.writeInt(dataLength); + } + + public static BufferHeader getBufferHeader(Buffer buffer, int position) { + ByteBuf byteBuf = buffer.asByteBuf(); + byteBuf.readerIndex(position); + return new BufferHeader( + Buffer.DataType.values()[byteBuf.readByte()], + byteBuf.readBoolean(), + byteBuf.readInt()); + } + + public static void reserveNumRequiredBuffers(BufferPool bufferPool, int numRequiredBuffers) + throws IOException { + long startTime = System.nanoTime(); + List buffers = new ArrayList<>(numRequiredBuffers); + try { + // guarantee that we have at least the minimal number of buffers + while (buffers.size() < numRequiredBuffers) { + MemorySegment segment = bufferPool.requestMemorySegment(); + if (segment != null) { + buffers.add(segment); + continue; + } + + Thread.sleep(10); + if ((System.nanoTime() - startTime) > 3L * 60 * 1000_000_000) { + throw new IOException( + "Could not allocate the required number of buffers in 3 minutes."); + } + } + } catch (Throwable throwable) { + throw new IOException(throwable); + } finally { + buffers.forEach(bufferPool::recycle); + } + } +} diff --git a/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/utils/ConfigurationUtils.java b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/utils/ConfigurationUtils.java new file mode 100644 index 00000000..5ec53f16 --- /dev/null +++ b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/utils/ConfigurationUtils.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.plugin.utils; + +import com.alibaba.flink.shuffle.common.config.Configuration; + +import java.util.Properties; + +/** Utils for configurations. */ +public class ConfigurationUtils { + + /** Convert {@link org.apache.flink.configuration.Configuration} to {@link Configuration}. */ + public static Configuration fromFlinkConfiguration( + org.apache.flink.configuration.Configuration configuration) { + Properties properties = new Properties(); + properties.putAll(configuration.toMap()); + return new Configuration(properties); + } + + /** Convert {@link Configuration} to {@link org.apache.flink.configuration.Configuration}. 
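Both directions of the conversion here go through a flat map of string key/value pairs (Configuration#toMap), so a round trip preserves every entry. A minimal sketch (the configuration key is hypothetical, not part of the patch):

    org.apache.flink.configuration.Configuration flinkConf =
            new org.apache.flink.configuration.Configuration();
    flinkConf.setString("remote-shuffle.manager.rpc-address", "localhost"); // hypothetical key
    Configuration shuffleConf = ConfigurationUtils.fromFlinkConfiguration(flinkConf);
    org.apache.flink.configuration.Configuration restored =
            ConfigurationUtils.toFlinkConfiguration(shuffleConf);
    // restored.toMap() equals flinkConf.toMap()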
*/ + public static org.apache.flink.configuration.Configuration toFlinkConfiguration( + Configuration configuration) { + return org.apache.flink.configuration.Configuration.fromMap(configuration.toMap()); + } +} diff --git a/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/utils/IdMappingUtils.java b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/utils/IdMappingUtils.java new file mode 100644 index 00000000..32d8f4e0 --- /dev/null +++ b/shuffle-plugin/src/main/java/com/alibaba/flink/shuffle/plugin/utils/IdMappingUtils.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.plugin.utils; + +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; + +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.util.AbstractID; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled; + +/** Utilities for mapping Flink IDs to the corresponding remote shuffle IDs.
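Because the conversions below are purely byte-level copies, mapping a Flink ResultPartitionID into the shuffle cluster's ID space and back must be lossless; a round-trip sketch (illustrative, not part of the patch):

    ResultPartitionID original = new ResultPartitionID(); // randomly generated IDs
    MapPartitionID mapped = IdMappingUtils.fromFlinkResultPartitionID(original);
    ResultPartitionID restored = IdMappingUtils.fromMapPartitionID(mapped);
    // original.equals(restored) must hold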
*/ +public class IdMappingUtils { + + public static JobID fromFlinkJobId(AbstractID flinkJobId) { + return new JobID(flinkJobId.getBytes()); + } + + public static DataSetID fromFlinkDataSetId(IntermediateDataSetID flinkDataSetId) { + return new DataSetID(flinkDataSetId.getBytes()); + } + + public static MapPartitionID fromFlinkResultPartitionID(ResultPartitionID resultPartitionID) { + ByteBuf byteBuf = Unpooled.buffer(); + resultPartitionID.getPartitionId().writeTo(byteBuf); + resultPartitionID.getProducerId().writeTo(byteBuf); + + byte[] bytes = new byte[byteBuf.readableBytes()]; + byteBuf.readBytes(bytes); + byteBuf.release(); + + return new MapPartitionID(bytes); + } + + public static ResultPartitionID fromMapPartitionID(MapPartitionID mapPartitionID) { + ByteBuf byteBuf = Unpooled.buffer(); + byteBuf.writeBytes(mapPartitionID.getId()); + + IntermediateResultPartitionID partitionID = + IntermediateResultPartitionID.fromByteBuf(byteBuf); + ExecutionAttemptID attemptID = ExecutionAttemptID.fromByteBuf(byteBuf); + byteBuf.release(); + + return new ResultPartitionID(partitionID, attemptID); + } +} diff --git a/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/RemoteShuffleMasterTest.java b/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/RemoteShuffleMasterTest.java new file mode 100644 index 00000000..c66e471d --- /dev/null +++ b/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/RemoteShuffleMasterTest.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.plugin; + +import com.alibaba.flink.shuffle.coordinator.manager.DefaultShuffleResource; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleResource; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleWorkerDescriptor; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.InstanceID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.storage.DataPartition; +import com.alibaba.flink.shuffle.plugin.config.PluginOptions; +import com.alibaba.flink.shuffle.plugin.utils.ConfigurationUtils; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcService; +import com.alibaba.flink.shuffle.rpc.message.Acknowledge; + +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.MemorySize; +import org.apache.flink.runtime.clusterframework.types.ResourceID; +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.shuffle.JobShuffleContext; +import org.apache.flink.runtime.shuffle.PartitionDescriptor; +import org.apache.flink.runtime.shuffle.ProducerDescriptor; +import org.apache.flink.runtime.shuffle.ShuffleMasterContext; +import org.apache.flink.runtime.shuffle.TaskInputsOutputsDescriptor; + +import org.junit.Assert; +import org.junit.Test; + +import java.net.InetAddress; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertEquals; + +/** Tests the behavior of {@link RemoteShuffleMaster}. 
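As testShuffleMemoryAnnouncing below verifies, the shuffle master's memory announcement is a plain sum over input gates and result partitions. With the 8m values this patch uses in its IT cases (illustrative here), a task with 2 input gates and 3 result partitions announces:

    // Sketch of the expected computeShuffleMemorySizeForTask result (not part of the patch).
    long perGate = MemorySize.parse("8m").getBytes();      // PluginOptions.MEMORY_PER_INPUT_GATE
    long perPartition = MemorySize.parse("8m").getBytes(); // PluginOptions.MEMORY_PER_RESULT_PARTITION
    MemorySize expected = new MemorySize(2 * perGate + 3 * perPartition); // 40 MB total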
*/ +public class RemoteShuffleMasterTest extends RemoteShuffleShuffleTestBase { + + @Test + public void testResourceAllocateAndRelease() throws Exception { + // Resource request + CompletableFuture> resourceRequestFuture = + new CompletableFuture<>(); + ShuffleResource shuffleResource = + new DefaultShuffleResource( + new ShuffleWorkerDescriptor[] { + new ShuffleWorkerDescriptor(new InstanceID("worker1"), "worker1", 20480) + }, + DataPartition.DataPartitionType.MAP_PARTITION); + smGateway.setAllocateShuffleResourceConsumer( + (jobID, dataSetID, mapPartitionID, numberOfSubpartitions) -> { + resourceRequestFuture.complete( + new Tuple4<>(jobID, dataSetID, mapPartitionID, numberOfSubpartitions)); + return CompletableFuture.completedFuture(shuffleResource); + }); + + // Resource release + CompletableFuture> resourceReleaseFuture = + new CompletableFuture<>(); + smGateway.setReleaseShuffleResourceConsumer( + (jobID, dataSetID, mapPartitionID) -> { + resourceReleaseFuture.complete(new Tuple3<>(jobID, dataSetID, mapPartitionID)); + return CompletableFuture.completedFuture(Acknowledge.get()); + }); + + org.apache.flink.api.common.JobID jobID = new org.apache.flink.api.common.JobID(); + IntermediateDataSetID intermediateDataSetId = new IntermediateDataSetID(); + IntermediateResultPartitionID intermediateResultPartitionId = + new IntermediateResultPartitionID(intermediateDataSetId, 0); + ExecutionAttemptID executionAttemptId = new ExecutionAttemptID(); + ResultPartitionID resultPartitionId = + new ResultPartitionID(intermediateResultPartitionId, executionAttemptId); + PartitionDescriptor partitionDescriptor = + new PartitionDescriptor( + intermediateDataSetId, + 10, + intermediateResultPartitionId, + ResultPartitionType.BLOCKING, + 5, + 1); + ProducerDescriptor producerDescriptor = + new ProducerDescriptor( + new ResourceID("tm1"), + executionAttemptId, + InetAddress.getLocalHost(), + 50000); + try (RemoteShuffleMaster shuffleMaster = createAndInitializeShuffleMaster(jobID)) { + CompletableFuture shuffleDescriptorFuture = + shuffleMaster.registerPartitionWithProducer( + jobID, partitionDescriptor, producerDescriptor); + shuffleDescriptorFuture.join(); + RemoteShuffleDescriptor shuffleDescriptor = shuffleDescriptorFuture.get(); + assertEquals(resultPartitionId, shuffleDescriptor.getResultPartitionID()); + Assert.assertEquals(shuffleResource, shuffleDescriptor.getShuffleResource()); + + shuffleMaster.releasePartitionExternally(shuffleDescriptor); + assertEquals( + new Tuple3<>( + shuffleDescriptor.getJobId(), + shuffleDescriptor.getDataSetId(), + shuffleDescriptor.getDataPartitionID()), + resourceReleaseFuture.get(TIMEOUT, TimeUnit.MILLISECONDS)); + } + } + + @Test + public void testShuffleMemoryAnnouncing() throws Exception { + try (RemoteShuffleMaster shuffleMaster = + createAndInitializeShuffleMaster(new org.apache.flink.api.common.JobID())) { + Map numberOfInputGateChannels = new HashMap<>(); + Map numbersOfResultSubpartitions = new HashMap<>(); + Map resultPartitionTypes = new HashMap<>(); + IntermediateDataSetID inputDataSetID0 = new IntermediateDataSetID(); + IntermediateDataSetID inputDataSetID1 = new IntermediateDataSetID(); + IntermediateDataSetID outputDataSetID0 = new IntermediateDataSetID(); + IntermediateDataSetID outputDataSetID1 = new IntermediateDataSetID(); + IntermediateDataSetID outputDataSetID2 = new IntermediateDataSetID(); + Random random = new Random(); + numberOfInputGateChannels.put(inputDataSetID0, random.nextInt(1000)); + numberOfInputGateChannels.put(inputDataSetID1, 
random.nextInt(1000)); + numbersOfResultSubpartitions.put(outputDataSetID0, random.nextInt(1000)); + numbersOfResultSubpartitions.put(outputDataSetID1, random.nextInt(1000)); + numbersOfResultSubpartitions.put(outputDataSetID2, random.nextInt(1000)); + resultPartitionTypes.put(outputDataSetID0, ResultPartitionType.BLOCKING); + resultPartitionTypes.put(outputDataSetID1, ResultPartitionType.BLOCKING); + resultPartitionTypes.put(outputDataSetID2, ResultPartitionType.BLOCKING); + MemorySize calculated = + shuffleMaster.computeShuffleMemorySizeForTask( + TaskInputsOutputsDescriptor.from( + numberOfInputGateChannels, + numbersOfResultSubpartitions, + resultPartitionTypes)); + + long numBytesPerGate = + configuration.getMemorySize(PluginOptions.MEMORY_PER_INPUT_GATE).getBytes(); + long expectedInput = 2 * numBytesPerGate; + + long numBytesPerResultPartition = + configuration + .getMemorySize(PluginOptions.MEMORY_PER_RESULT_PARTITION) + .getBytes(); + long expectedOutput = 3 * numBytesPerResultPartition; + MemorySize expected = new MemorySize(expectedInput + expectedOutput); + + assertEquals(expected, calculated); + } + } + + private RemoteShuffleMaster createAndInitializeShuffleMaster( + org.apache.flink.api.common.JobID jobID) throws Exception { + RemoteShuffleMaster shuffleMaster = + new RemoteShuffleMaster( + new ShuffleMasterContext() { + @Override + public Configuration getConfiguration() { + return ConfigurationUtils.toFlinkConfiguration(configuration); + } + + @Override + public void onFatalError(Throwable throwable) { + System.exit(-100); + } + }) { + @Override + protected RemoteShuffleRpcService createRpcService() throws Exception { + return rpcService; + } + }; + + shuffleMaster.start(); + shuffleMaster.registerJob( + new JobShuffleContext() { + @Override + public org.apache.flink.api.common.JobID getJobId() { + return jobID; + } + + @Override + public CompletableFuture stopTrackingAndReleasePartitions( + Collection collection) { + return CompletableFuture.completedFuture(null); + } + }); + + return shuffleMaster; + } +} diff --git a/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/RemoteShuffleShuffleTestBase.java b/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/RemoteShuffleShuffleTestBase.java new file mode 100644 index 00000000..d796bed3 --- /dev/null +++ b/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/RemoteShuffleShuffleTestBase.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.plugin; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServices; +import com.alibaba.flink.shuffle.coordinator.highavailability.HaServicesFactory; +import com.alibaba.flink.shuffle.coordinator.highavailability.LeaderInformation; +import com.alibaba.flink.shuffle.coordinator.highavailability.TestingHaServices; +import com.alibaba.flink.shuffle.coordinator.leaderretrieval.SettableLeaderRetrievalService; +import com.alibaba.flink.shuffle.coordinator.utils.TestingShuffleManagerGateway; +import com.alibaba.flink.shuffle.core.config.HighAvailabilityOptions; +import com.alibaba.flink.shuffle.rpc.test.TestingRpcService; +import com.alibaba.flink.shuffle.rpc.utils.RpcUtils; + +import org.apache.flink.runtime.testutils.DirectScheduledExecutorService; +import org.apache.flink.runtime.util.TestingFatalErrorHandler; + +import org.junit.After; +import org.junit.Before; + +import java.util.concurrent.Executor; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** Base class for the remote shuffle test. */ +public class RemoteShuffleShuffleTestBase { + + public static final long TIMEOUT = 10000L; + + protected TestingRpcService rpcService; + + protected Configuration configuration; + + protected TestingFatalErrorHandler testingFatalErrorHandler; + + protected Executor mainThreadExecutor; + + protected TestingShuffleManagerGateway smGateway; + + @Before + public void setup() throws Exception { + rpcService = new TestingRpcService(); + + configuration = new Configuration(); + configuration.setString( + HighAvailabilityOptions.HA_MODE, TestingHaServiceFactory.class.getName()); + + testingFatalErrorHandler = new TestingFatalErrorHandler(); + mainThreadExecutor = new DirectScheduledExecutorService(); + + smGateway = new TestingShuffleManagerGateway(); + rpcService.registerGateway(smGateway.getAddress(), smGateway); + + TestingHaServiceFactory.shuffleManagerLeaderRetrieveService = + new SettableLeaderRetrievalService(); + TestingHaServiceFactory.shuffleManagerLeaderRetrieveService.notifyListener( + new LeaderInformation(smGateway.getFencingToken(), smGateway.getAddress())); + } + + @After + public void teardown() throws Exception { + if (rpcService != null) { + RpcUtils.terminateRpcService(rpcService, TIMEOUT); + rpcService = null; + } + } + + /** A testing {@link HaServicesFactory} implementation. */ + public static class TestingHaServiceFactory implements HaServicesFactory { + static SettableLeaderRetrievalService shuffleManagerLeaderRetrieveService; + + @Override + public HaServices createHAServices(Configuration configuration) throws Exception { + checkNotNull(shuffleManagerLeaderRetrieveService); + TestingHaServices haServices = new TestingHaServices(); + haServices.setShuffleManagerLeaderRetrieveService(shuffleManagerLeaderRetrieveService); + return haServices; + } + } +} diff --git a/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/itcase/BatchJobITCaseBase.java b/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/itcase/BatchJobITCaseBase.java new file mode 100644 index 00000000..6fd70830 --- /dev/null +++ b/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/itcase/BatchJobITCaseBase.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.plugin.itcase; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.core.config.ManagerOptions; +import com.alibaba.flink.shuffle.core.config.StorageOptions; +import com.alibaba.flink.shuffle.core.config.WorkerOptions; +import com.alibaba.flink.shuffle.minicluster.ShuffleMiniCluster; +import com.alibaba.flink.shuffle.minicluster.ShuffleMiniClusterConfiguration; +import com.alibaba.flink.shuffle.plugin.RemoteShuffleServiceFactory; +import com.alibaba.flink.shuffle.plugin.config.PluginOptions; + +import org.apache.flink.api.common.RuntimeExecutionMode; +import org.apache.flink.configuration.ExecutionOptions; +import org.apache.flink.configuration.JobManagerOptions; +import org.apache.flink.configuration.MemorySize; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.configuration.TaskManagerOptions; +import org.apache.flink.runtime.highavailability.HighAvailabilityServices; +import org.apache.flink.runtime.minicluster.MiniCluster; +import org.apache.flink.runtime.minicluster.TestingMiniCluster; +import org.apache.flink.runtime.minicluster.TestingMiniClusterConfiguration; +import org.apache.flink.runtime.shuffle.ShuffleServiceOptions; +import org.apache.flink.util.ExceptionUtils; + +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.TemporaryFolder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.InetAddress; +import java.util.function.Supplier; + +/** A base class for batch job IT cases that use the remote shuffle service.
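Stripped of the mini-cluster plumbing set up in before() below, pointing a Flink job at the remote shuffle service comes down to a handful of settings; a minimal sketch ("shuffle-manager-host" is a placeholder, the memory values are the ones these tests use):

    org.apache.flink.configuration.Configuration conf =
            new org.apache.flink.configuration.Configuration();
    conf.setString(
            ShuffleServiceOptions.SHUFFLE_SERVICE_FACTORY_CLASS,
            RemoteShuffleServiceFactory.class.getName());
    conf.setString(ManagerOptions.RPC_ADDRESS.key(), "shuffle-manager-host"); // placeholder
    conf.setString(PluginOptions.MEMORY_PER_INPUT_GATE.key(), "8m");
    conf.setString(PluginOptions.MEMORY_PER_RESULT_PARTITION.key(), "8m");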
*/ +public abstract class BatchJobITCaseBase { + + protected final Logger log = LoggerFactory.getLogger(getClass()); + + @Rule public final TemporaryFolder temporaryFolder = new TemporaryFolder(); + + protected int numShuffleWorkers = 4; + + protected int numTaskManagers = 4; + + protected int numSlotsPerTaskManager = 4; + + protected final Configuration configuration = new Configuration(); + + protected final org.apache.flink.configuration.Configuration flinkConfiguration = + new org.apache.flink.configuration.Configuration(); + + protected MiniCluster flinkCluster; + + protected ShuffleMiniCluster shuffleCluster; + + protected Supplier highAvailabilityServicesSupplier = null; + + @Before + public void before() throws Exception { + // basic configuration + String address = InetAddress.getLocalHost().getHostAddress(); + configuration.setString( + StorageOptions.STORAGE_LOCAL_DATA_DIRS, + temporaryFolder.getRoot().getAbsolutePath()); + configuration.setString(ManagerOptions.RPC_ADDRESS, address); + configuration.setString(ManagerOptions.RPC_BIND_ADDRESS, address); + configuration.setString(WorkerOptions.BIND_HOST, address); + configuration.setString(WorkerOptions.HOST, address); + configuration.setInteger(ManagerOptions.RPC_PORT, ManagerOptions.RPC_PORT.defaultValue()); + configuration.setInteger( + ManagerOptions.RPC_BIND_PORT, ManagerOptions.RPC_PORT.defaultValue()); + + // flink basic configuration. + flinkConfiguration.set(ExecutionOptions.RUNTIME_MODE, RuntimeExecutionMode.BATCH); + flinkConfiguration.setString( + ShuffleServiceOptions.SHUFFLE_SERVICE_FACTORY_CLASS, + RemoteShuffleServiceFactory.class.getName()); + flinkConfiguration.setString(ManagerOptions.RPC_ADDRESS.key(), address); + flinkConfiguration.setLong(JobManagerOptions.SLOT_REQUEST_TIMEOUT, 5000L); + flinkConfiguration.setString(RestOptions.BIND_PORT, "0"); + flinkConfiguration.set(TaskManagerOptions.TOTAL_PROCESS_MEMORY, MemorySize.parse("1g")); + flinkConfiguration.set(TaskManagerOptions.NETWORK_MEMORY_FRACTION, 0.4F); + flinkConfiguration.setString(PluginOptions.MEMORY_PER_INPUT_GATE.key(), "8m"); + flinkConfiguration.setString(PluginOptions.MEMORY_PER_RESULT_PARTITION.key(), "8m"); + + // setup special config. + setup(); + + ShuffleMiniClusterConfiguration clusterConf = + new ShuffleMiniClusterConfiguration.Builder() + .setConfiguration(configuration) + .setNumShuffleWorkers(numShuffleWorkers) + .setCommonBindAddress(address) + .build(); + shuffleCluster = new ShuffleMiniCluster(clusterConf); + shuffleCluster.start(); + + TestingMiniClusterConfiguration miniClusterConfiguration = + TestingMiniClusterConfiguration.newBuilder() + .setConfiguration(flinkConfiguration) + .setNumTaskManagers(numTaskManagers) + .setNumSlotsPerTaskManager(numSlotsPerTaskManager) + .build(); + + flinkCluster = + new TestingMiniCluster(miniClusterConfiguration, highAvailabilityServicesSupplier); + flinkCluster.start(); + } + + @After + public void after() { + Throwable exception = null; + + try { + flinkCluster.close(); + } catch (Throwable throwable) { + exception = throwable; + } + + try { + shuffleCluster.close(); + } catch (Throwable throwable) { + exception = exception != null ? 
exception : throwable; + } + + if (exception != null) { + ExceptionUtils.rethrow(exception); + } + } + + abstract void setup(); +} diff --git a/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/itcase/WordCountITCase.java b/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/itcase/WordCountITCase.java new file mode 100644 index 00000000..396c5d88 --- /dev/null +++ b/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/itcase/WordCountITCase.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.plugin.itcase; + +import org.apache.flink.api.common.ExecutionMode; +import org.apache.flink.api.common.InputDependencyConstraint; +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.runtime.jobgraph.JobGraph; +import org.apache.flink.runtime.jobgraph.JobType; +import org.apache.flink.runtime.jobmaster.JobResult; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.SinkFunction; +import org.apache.flink.streaming.api.graph.GlobalStreamExchangeMode; +import org.apache.flink.streaming.api.graph.StreamGraph; +import org.apache.flink.streaming.api.graph.StreamingJobGraphGenerator; +import org.apache.flink.util.Collector; + +import org.junit.Test; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument; +import static org.junit.Assert.assertEquals; + +/** A simple word-count integration test. 
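The VerifySink expectation in the test below follows directly from the dataflow: every generated word is broadcast to each of the parallelism subtasks, and each subtask's flat map emits it WORD_COUNT times, so every per-word sum must equal parallelism * WORD_COUNT:

    // Expected per-word count (illustrative arithmetic, not part of the patch).
    int parallelism = 4 * 4;                         // numTaskManagers * numSlotsPerTaskManager
    long expectedPerWord = (long) parallelism * 200; // parallelism * WORD_COUNT = 3200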
*/ +public class WordCountITCase extends BatchJobITCaseBase { + + private static final int NUM_WORDS = 20; + + private static final int WORD_COUNT = 200; + + @Override + public void setup() {} + + @Test + public void testWordCount() throws Exception { + StreamExecutionEnvironment env = + StreamExecutionEnvironment.getExecutionEnvironment(flinkConfiguration); + + int parallelism = numTaskManagers * numSlotsPerTaskManager; + env.getConfig().setExecutionMode(ExecutionMode.BATCH); + env.getConfig().setParallelism(parallelism); + env.getConfig().setDefaultInputDependencyConstraint(InputDependencyConstraint.ALL); + env.disableOperatorChaining(); + + DataStream> words = + env.fromSequence(0, NUM_WORDS) + .broadcast() + .map(new WordsMapper()) + .flatMap(new WordsFlatMapper(WORD_COUNT)); + words.keyBy(value -> value.f0) + .sum(1) + .map((MapFunction, Long>) wordCount -> wordCount.f1) + .addSink(new VerifySink((long) parallelism * WORD_COUNT)); + + StreamGraph streamGraph = env.getStreamGraph(); + streamGraph.setGlobalStreamExchangeMode(GlobalStreamExchangeMode.ALL_EDGES_BLOCKING); + streamGraph.setJobType(JobType.BATCH); + JobGraph jobGraph = StreamingJobGraphGenerator.createJobGraph(streamGraph); + + JobID jobID = flinkCluster.submitJob(jobGraph).get().getJobID(); + JobResult jobResult = flinkCluster.requestJobResult(jobID).get(); + if (jobResult.getSerializedThrowable().isPresent()) { + throw new AssertionError(jobResult.getSerializedThrowable().get()); + } + } + + private static class WordsMapper implements MapFunction { + + private static final long serialVersionUID = -896627105414186948L; + + private static final String WORD_SUFFIX_1K = getWordSuffix1k(); + + private static String getWordSuffix1k() { + StringBuilder builder = new StringBuilder(); + builder.append("-"); + for (int i = 0; i < 1024; ++i) { + builder.append("0"); + } + return builder.toString(); + } + + @Override + public String map(Long value) { + return "WORD-" + value + WORD_SUFFIX_1K; + } + } + + private static class WordsFlatMapper implements FlatMapFunction> { + + private static final long serialVersionUID = 7873046672795114433L; + + private final int wordsCount; + + public WordsFlatMapper(int wordsCount) { + checkArgument(wordsCount > 0, "Must be positive."); + this.wordsCount = wordsCount; + } + + @Override + public void flatMap(String word, Collector> collector) { + for (int i = 0; i < wordsCount; ++i) { + collector.collect(new Tuple2<>(word, 1L)); + } + } + } + + private static class VerifySink implements SinkFunction { + + private static final long serialVersionUID = -1975623991098131708L; + + private final Long wordCount; + + public VerifySink(long wordCount) { + this.wordCount = wordCount; + } + + @Override + public void invoke(Long value, Context context) { + assertEquals(wordCount, value); + } + } +} diff --git a/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/transfer/BufferPackerTest.java b/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/transfer/BufferPackerTest.java new file mode 100644 index 00000000..78340a5d --- /dev/null +++ b/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/transfer/BufferPackerTest.java @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.plugin.transfer; + +import com.alibaba.flink.shuffle.common.functions.BiConsumerWithException; +import com.alibaba.flink.shuffle.plugin.utils.BufferUtils; + +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.BufferPool; +import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; + +import org.apache.commons.lang3.tuple.Pair; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType.DATA_BUFFER; +import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType.EVENT_BUFFER; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** Test for {@link BufferPacker}. */ +public class BufferPackerTest { + + private static final int BUFFER_SIZE = 20; + + private NetworkBufferPool networkBufferPool; + + private BufferPool bufferPool; + + @Before + public void setup() throws Exception { + networkBufferPool = new NetworkBufferPool(10, BUFFER_SIZE); + bufferPool = networkBufferPool.createBufferPool(10, 10); + } + + @After + public void tearDown() { + bufferPool.lazyDestroy(); + assertEquals(10, networkBufferPool.getNumberOfAvailableMemorySegments()); + networkBufferPool.destroy(); + } + + @Test + public void testPackEmptyBuffers() throws Exception { + List buffers = requestBuffers(3); + setCompressed(buffers, true, true, false); + setDataType(buffers, EVENT_BUFFER, DATA_BUFFER, DATA_BUFFER); + + Integer subIdx = 2; + + List output = new ArrayList<>(); + BiConsumerWithException ripeBufferHandler = + (ripe, sub) -> { + assertEquals(subIdx, sub); + output.add(ripe); + }; + + BufferPacker packer = new BufferPacker(ripeBufferHandler); + packer.process(buffers.get(0), subIdx); + packer.process(buffers.get(1), subIdx); + packer.process(buffers.get(2), subIdx); + assertTrue(output.isEmpty()); + + packer.drain(); + assertEquals(0, output.size()); + } + + @Test + public void testPartialBuffersForSameSubIdx() throws Exception { + List buffers = requestBuffers(3); + setCompressed(buffers, true, true, false); + setDataType(buffers, EVENT_BUFFER, DATA_BUFFER, DATA_BUFFER); + + List> output = new ArrayList<>(); + BiConsumerWithException ripeBufferHandler = + (ripe, sub) -> output.add(Pair.of(ripe, sub)); + BufferPacker packer = new BufferPacker(ripeBufferHandler); + fillBuffers(buffers, 0, 1, 2); + + packer.process(buffers.get(0), 2); + packer.process(buffers.get(1), 2); + assertEquals(0, output.size()); + + packer.process(buffers.get(2), 2); + assertEquals(1, output.size()); + + packer.drain(); + assertEquals(2, output.size()); + + List unpacked = new ArrayList<>(); + output.forEach( + pair -> { + assertEquals(Integer.valueOf(2), pair.getRight()); + 
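// Reviewer note (illustrative): with BUFFER_SIZE = 20 and each filled input buffer
// holding a 6-byte header plus a 4-byte payload, two 10-byte buffers fit into one
// packed output buffer -- hence the single ripe buffer after the third process()
// call above and the second one emitted by drain().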
unpacked.addAll(BufferPacker.unpack(pair.getLeft())); + }); + checkIfCompressed(unpacked, true, true, false); + checkDataType(unpacked, EVENT_BUFFER, DATA_BUFFER, DATA_BUFFER); + verifyBuffers(unpacked, 0, 1, 2); + unpacked.forEach(Buffer::recycleBuffer); + } + + @Test + public void testPartialBuffersForMultipleSubIdx() throws Exception { + List buffers = requestBuffers(3); + setCompressed(buffers, true, true, false); + setDataType(buffers, EVENT_BUFFER, DATA_BUFFER, DATA_BUFFER); + + List> output = new ArrayList<>(); + BiConsumerWithException ripeBufferHandler = + (ripe, sub) -> output.add(Pair.of(ripe, sub)); + BufferPacker packer = new BufferPacker(ripeBufferHandler); + fillBuffers(buffers, 0, 1, 2); + + packer.process(buffers.get(0), 0); + packer.process(buffers.get(1), 1); + assertEquals(1, output.size()); + + packer.process(buffers.get(2), 1); + assertEquals(1, output.size()); + + packer.drain(); + assertEquals(2, output.size()); + + List unpacked = new ArrayList<>(); + for (int i = 0; i < output.size(); i++) { + Pair pair = output.get(i); + assertEquals(Integer.valueOf(i), pair.getRight()); + unpacked.addAll(BufferPacker.unpack(pair.getLeft())); + } + + checkIfCompressed(unpacked, true, true, false); + checkDataType(unpacked, EVENT_BUFFER, DATA_BUFFER, DATA_BUFFER); + verifyBuffers(unpacked, 0, 1, 2); + unpacked.forEach(Buffer::recycleBuffer); + } + + @Test + public void testUnpackedBuffers() throws Exception { + List buffers = requestBuffers(3); + setCompressed(buffers, true, true, false); + setDataType(buffers, EVENT_BUFFER, DATA_BUFFER, DATA_BUFFER); + + List> output = new ArrayList<>(); + BiConsumerWithException ripeBufferHandler = + (ripe, sub) -> output.add(Pair.of(ripe, sub)); + BufferPacker packer = new BufferPacker(ripeBufferHandler); + fillBuffers(buffers, 0, 1, 2); + + packer.process(buffers.get(0), 0); + packer.process(buffers.get(1), 1); + assertEquals(1, output.size()); + + packer.process(buffers.get(2), 2); + assertEquals(2, output.size()); + + packer.drain(); + assertEquals(3, output.size()); + + List unpacked = new ArrayList<>(); + for (int i = 0; i < output.size(); i++) { + Pair pair = output.get(i); + assertEquals(Integer.valueOf(i), pair.getRight()); + unpacked.addAll(BufferPacker.unpack(pair.getLeft())); + } + + checkIfCompressed(unpacked, true, true, false); + checkDataType(unpacked, EVENT_BUFFER, DATA_BUFFER, DATA_BUFFER); + verifyBuffers(unpacked, 0, 1, 2); + unpacked.forEach(Buffer::recycleBuffer); + } + + private List requestBuffers(int n) { + List buffers = new ArrayList<>(); + for (int i = 0; i < n; i++) { + Buffer buffer = bufferPool.requestBuffer(); + buffers.add(buffer); + } + return buffers; + } + + private void setCompressed(List buffers, boolean... values) { + for (int i = 0; i < buffers.size(); i++) { + buffers.get(i).setCompressed(values[i]); + } + } + + private void setDataType(List buffers, Buffer.DataType... values) { + for (int i = 0; i < buffers.size(); i++) { + buffers.get(i).setDataType(values[i]); + } + } + + private void checkIfCompressed(List buffers, boolean... values) { + for (int i = 0; i < buffers.size(); i++) { + assertEquals(values[i], buffers.get(i).isCompressed()); + } + } + + private void checkDataType(List buffers, Buffer.DataType... values) { + for (int i = 0; i < buffers.size(); i++) { + assertEquals(values[i], buffers.get(i).getDataType()); + } + } + + private void fillBuffers(List buffers, int... 
ints) { + for (int i = 0; i < buffers.size(); i++) { + Buffer buffer = buffers.get(i); + ByteBuf target = buffer.asByteBuf(); + BufferUtils.setBufferHeader(target, buffer.getDataType(), buffer.isCompressed(), 4); + target.writerIndex(BufferUtils.HEADER_LENGTH); + target.writeInt(ints[i]); + } + } + + private void verifyBuffers(List buffers, int... expects) { + for (int i = 0; i < buffers.size(); i++) { + ByteBuf actual = buffers.get(i).asByteBuf(); + assertEquals(expects[i], actual.getInt(0)); + } + } +} diff --git a/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/transfer/PartitionSortedBufferTest.java b/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/transfer/PartitionSortedBufferTest.java new file mode 100644 index 00000000..8e3bd0a0 --- /dev/null +++ b/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/transfer/PartitionSortedBufferTest.java @@ -0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.plugin.transfer; + +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.BufferPool; +import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; + +import org.junit.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Queue; +import java.util.Random; + +import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** Tests for {@link PartitionSortedBuffer}. 
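In miniature, the invariant these tests check is that bytes appended per channel come back byte-identical and grouped by channel index. A hedged sketch using the same calls as the tests below (assumes a bufferPool like the one the tests construct in createSortBuffer):

    SortBuffer sortBuffer = new PartitionSortedBuffer(bufferPool, 2, 1024, null);
    sortBuffer.append(ByteBuffer.wrap(new byte[] {1, 2, 3}), 1, Buffer.DataType.DATA_BUFFER);
    sortBuffer.finish();
    MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(1024);
    SortBuffer.BufferWithChannel bwc = sortBuffer.copyIntoSegment(segment, ignore -> {}, 0);
    // bwc.getChannelIndex() == 1 and bwc.getBuffer() holds exactly the three bytes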
*/ +public class PartitionSortedBufferTest { + + @Test + public void testWriteAndReadSortBuffer() throws Exception { + int numSubpartitions = 10; + int bufferSize = 1024; + int bufferPoolSize = 1000; + Random random = new Random(1111); + + // used to store data written to and read from sort buffer for correctness check + Queue<DataAndType>[] dataWritten = new Queue[numSubpartitions]; + Queue<Buffer>[] buffersRead = new Queue[numSubpartitions]; + for (int i = 0; i < numSubpartitions; ++i) { + dataWritten[i] = new ArrayDeque<>(); + buffersRead[i] = new ArrayDeque<>(); + } + + int[] numBytesWritten = new int[numSubpartitions]; + int[] numBytesRead = new int[numSubpartitions]; + Arrays.fill(numBytesWritten, 0); + Arrays.fill(numBytesRead, 0); + + // fill the sort buffer with randomly generated data + int totalBytesWritten = 0; + SortBuffer sortBuffer = + createSortBuffer( + bufferPoolSize, + bufferSize, + numSubpartitions, + getRandomSubpartitionOrder(numSubpartitions)); + while (true) { + // record size may be larger than buffer size so a record may span multiple segments + int recordSize = random.nextInt(bufferSize * 4 - 1) + 1; + byte[] bytes = new byte[recordSize]; + + // fill record with random value + random.nextBytes(bytes); + ByteBuffer record = ByteBuffer.wrap(bytes); + + // select a random subpartition to write + int subpartition = random.nextInt(numSubpartitions); + + // select a random data type + boolean isBuffer = random.nextBoolean() || recordSize > bufferSize; + DataType dataType = isBuffer ? DataType.DATA_BUFFER : DataType.EVENT_BUFFER; + if (!sortBuffer.append(record, subpartition, dataType)) { + sortBuffer.finish(); + break; + } + record.rewind(); + dataWritten[subpartition].add(new DataAndType(record, dataType)); + numBytesWritten[subpartition] += recordSize; + totalBytesWritten += recordSize; + } + + // read all data from the sort buffer + while (sortBuffer.hasRemaining()) { + MemorySegment readBuffer = MemorySegmentFactory.allocateUnpooledSegment(bufferSize); + SortBuffer.BufferWithChannel bufferAndChannel = + sortBuffer.copyIntoSegment(readBuffer, ignore -> {}, 0); + int subpartition = bufferAndChannel.getChannelIndex(); + buffersRead[subpartition].add(bufferAndChannel.getBuffer()); + numBytesRead[subpartition] += bufferAndChannel.getBuffer().readableBytes(); + } + + assertEquals(totalBytesWritten, sortBuffer.numBytes()); + checkWriteReadResult( + numSubpartitions, numBytesWritten, numBytesRead, dataWritten, buffersRead); + } + + public static void checkWriteReadResult( + int numSubpartitions, + int[] numBytesWritten, + int[] numBytesRead, + Queue<DataAndType>[] dataWritten, + Collection<Buffer>[] buffersRead) { + for (int subpartitionIndex = 0; subpartitionIndex < numSubpartitions; ++subpartitionIndex) { + assertEquals(numBytesWritten[subpartitionIndex], numBytesRead[subpartitionIndex]); + + List<DataAndType> eventsWritten = new ArrayList<>(); + List<Buffer> eventsRead = new ArrayList<>(); + + ByteBuffer subpartitionDataWritten = + ByteBuffer.allocate(numBytesWritten[subpartitionIndex]); + for (DataAndType dataAndType : dataWritten[subpartitionIndex]) { + subpartitionDataWritten.put(dataAndType.data); + dataAndType.data.rewind(); + if (dataAndType.dataType.isEvent()) { + eventsWritten.add(dataAndType); + } + } + + ByteBuffer subpartitionDataRead = ByteBuffer.allocate(numBytesRead[subpartitionIndex]); + for (Buffer buffer : buffersRead[subpartitionIndex]) { + subpartitionDataRead.put(buffer.getNioBufferReadable()); + if (!buffer.isBuffer()) { + eventsRead.add(buffer); + } + } + + subpartitionDataWritten.flip(); +
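// ByteBuffer.equals compares only the bytes between position and limit, so both
// aggregates must be flipped from write mode back to read mode before they can
// be compared below. A JDK-only illustration of why the flips matter:
//   ByteBuffer a = ByteBuffer.allocate(8);   a.put((byte) 1).put((byte) 2);
//   ByteBuffer b = ByteBuffer.allocate(16);  b.put((byte) 1).put((byte) 2);
//   a.flip(); b.flip();     // position = 0, limit = 2 for both
//   a.equals(b);            // true: capacities differ, but the remaining bytes match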
subpartitionDataRead.flip(); + assertEquals(subpartitionDataWritten, subpartitionDataRead); + + assertEquals(eventsWritten.size(), eventsRead.size()); + for (int i = 0; i < eventsWritten.size(); ++i) { + assertEquals(eventsWritten.get(i).dataType, eventsRead.get(i).getDataType()); + assertEquals(eventsWritten.get(i).data, eventsRead.get(i).getNioBufferReadable()); + } + } + } + + @Test + public void testWriteReadWithEmptyChannel() throws Exception { + int bufferPoolSize = 10; + int bufferSize = 1024; + int numSubpartitions = 5; + + ByteBuffer[] subpartitionRecords = { + ByteBuffer.allocate(128), + null, + ByteBuffer.allocate(1536), + null, + ByteBuffer.allocate(1024) + }; + + SortBuffer sortBuffer = createSortBuffer(bufferPoolSize, bufferSize, numSubpartitions); + for (int subpartition = 0; subpartition < numSubpartitions; ++subpartition) { + ByteBuffer record = subpartitionRecords[subpartition]; + if (record != null) { + sortBuffer.append(record, subpartition, Buffer.DataType.DATA_BUFFER); + record.rewind(); + } + } + sortBuffer.finish(); + + checkReadResult(sortBuffer, subpartitionRecords[0], 0, bufferSize); + + ByteBuffer expected1 = subpartitionRecords[2].duplicate(); + expected1.limit(bufferSize); + checkReadResult(sortBuffer, expected1.slice(), 2, bufferSize); + + ByteBuffer expected2 = subpartitionRecords[2].duplicate(); + expected2.position(bufferSize); + checkReadResult(sortBuffer, expected2.slice(), 2, bufferSize); + + checkReadResult(sortBuffer, subpartitionRecords[4], 4, bufferSize); + } + + private void checkReadResult( + SortBuffer sortBuffer, ByteBuffer expectedBuffer, int expectedChannel, int bufferSize) { + MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(bufferSize); + SortBuffer.BufferWithChannel bufferWithChannel = + sortBuffer.copyIntoSegment(segment, ignore -> {}, 0); + assertEquals(expectedChannel, bufferWithChannel.getChannelIndex()); + assertEquals(expectedBuffer, bufferWithChannel.getBuffer().getNioBufferReadable()); + } + + @Test(expected = IllegalArgumentException.class) + public void testWriteEmptyData() throws Exception { + int bufferSize = 1024; + + SortBuffer sortBuffer = createSortBuffer(1, bufferSize, 1); + + ByteBuffer record = ByteBuffer.allocate(1); + record.position(1); + + sortBuffer.append(record, 0, Buffer.DataType.DATA_BUFFER); + } + + @Test(expected = IllegalStateException.class) + public void testWriteFinishedSortBuffer() throws Exception { + int bufferSize = 1024; + + SortBuffer sortBuffer = createSortBuffer(1, bufferSize, 1); + sortBuffer.finish(); + + sortBuffer.append(ByteBuffer.allocate(1), 0, Buffer.DataType.DATA_BUFFER); + } + + @Test(expected = IllegalStateException.class) + public void testWriteReleasedSortBuffer() throws Exception { + int bufferSize = 1024; + + SortBuffer sortBuffer = createSortBuffer(1, bufferSize, 1); + sortBuffer.release(); + + sortBuffer.append(ByteBuffer.allocate(1), 0, Buffer.DataType.DATA_BUFFER); + } + + @Test + public void testWriteMoreDataThanCapacity() throws Exception { + int bufferPoolSize = 10; + int bufferSize = 1024; + + SortBuffer sortBuffer = createSortBuffer(bufferPoolSize, bufferSize, 1); + + for (int i = 1; i < bufferPoolSize; ++i) { + appendAndCheckResult(sortBuffer, bufferSize, true, bufferSize * i, i, true); + } + + // append should fail for insufficient capacity + int numRecords = bufferPoolSize - 1; + appendAndCheckResult( + sortBuffer, bufferSize, false, bufferSize * numRecords, numRecords, true); + } + + @Test + public void testWriteLargeRecord() throws Exception { + int 
bufferPoolSize = 10; + int bufferSize = 1024; + + SortBuffer sortBuffer = createSortBuffer(bufferPoolSize, bufferSize, 1); + // append should fail for insufficient capacity + appendAndCheckResult(sortBuffer, bufferPoolSize * bufferSize, false, 0, 0, false); + } + + private void appendAndCheckResult( + SortBuffer sortBuffer, + int recordSize, + boolean isSuccessful, + long numBytes, + long numRecords, + boolean hasRemaining) + throws IOException { + ByteBuffer largeRecord = ByteBuffer.allocate(recordSize); + + assertEquals(isSuccessful, sortBuffer.append(largeRecord, 0, Buffer.DataType.DATA_BUFFER)); + assertEquals(numBytes, sortBuffer.numBytes()); + assertEquals(numRecords, sortBuffer.numRecords()); + assertEquals(hasRemaining, sortBuffer.hasRemaining()); + } + + @Test(expected = IllegalStateException.class) + public void testReadUnfinishedSortBuffer() throws Exception { + int bufferSize = 1024; + + SortBuffer sortBuffer = createSortBuffer(1, bufferSize, 1); + sortBuffer.append(ByteBuffer.allocate(1), 0, Buffer.DataType.DATA_BUFFER); + + assertTrue(sortBuffer.hasRemaining()); + sortBuffer.copyIntoSegment( + MemorySegmentFactory.allocateUnpooledSegment(bufferSize), ignore -> {}, 0); + } + + @Test(expected = IllegalStateException.class) + public void testReadReleasedSortBuffer() throws Exception { + int bufferSize = 1024; + + SortBuffer sortBuffer = createSortBuffer(1, bufferSize, 1); + sortBuffer.append(ByteBuffer.allocate(1), 0, Buffer.DataType.DATA_BUFFER); + sortBuffer.finish(); + assertTrue(sortBuffer.hasRemaining()); + + sortBuffer.release(); + assertFalse(sortBuffer.hasRemaining()); + + sortBuffer.copyIntoSegment( + MemorySegmentFactory.allocateUnpooledSegment(bufferSize), ignore -> {}, 0); + } + + @Test(expected = IllegalStateException.class) + public void testReadEmptySortBuffer() throws Exception { + int bufferSize = 1024; + + SortBuffer sortBuffer = createSortBuffer(1, bufferSize, 1); + sortBuffer.finish(); + + assertFalse(sortBuffer.hasRemaining()); + sortBuffer.copyIntoSegment( + MemorySegmentFactory.allocateUnpooledSegment(bufferSize), ignore -> {}, 0); + } + + @Test + public void testReleaseSortBuffer() throws Exception { + int bufferPoolSize = 10; + int bufferSize = 1024; + int recordSize = (bufferPoolSize - 1) * bufferSize; + + NetworkBufferPool globalPool = new NetworkBufferPool(bufferPoolSize, bufferSize); + BufferPool bufferPool = globalPool.createBufferPool(bufferPoolSize, bufferPoolSize); + + SortBuffer sortBuffer = new PartitionSortedBuffer(bufferPool, 1, bufferSize, null); + sortBuffer.append(ByteBuffer.allocate(recordSize), 0, Buffer.DataType.DATA_BUFFER); + + assertEquals(bufferPoolSize, bufferPool.bestEffortGetNumOfUsedBuffers()); + assertTrue(sortBuffer.hasRemaining()); + assertEquals(1, sortBuffer.numRecords()); + assertEquals(recordSize, sortBuffer.numBytes()); + + // should release all data and resources + sortBuffer.release(); + assertEquals(0, bufferPool.bestEffortGetNumOfUsedBuffers()); + assertFalse(sortBuffer.hasRemaining()); + assertEquals(0, sortBuffer.numRecords()); + assertEquals(0, sortBuffer.numBytes()); + } + + private SortBuffer createSortBuffer(int bufferPoolSize, int bufferSize, int numSubpartitions) + throws IOException { + return createSortBuffer(bufferPoolSize, bufferSize, numSubpartitions, null); + } + + private SortBuffer createSortBuffer( + int bufferPoolSize, int bufferSize, int numSubpartitions, int[] customReadOrder) + throws IOException { + NetworkBufferPool globalPool = new NetworkBufferPool(bufferPoolSize, bufferSize); + BufferPool 
bufferPool = globalPool.createBufferPool(bufferPoolSize, bufferPoolSize); + + return new PartitionSortedBuffer(bufferPool, numSubpartitions, bufferSize, customReadOrder); + } + + public static int[] getRandomSubpartitionOrder(int numSubpartitions) { + Random random = new Random(1111); + int[] subpartitionReadOrder = new int[numSubpartitions]; + int shift = random.nextInt(numSubpartitions); + for (int i = 0; i < numSubpartitions; ++i) { + subpartitionReadOrder[i] = (i + shift) % numSubpartitions; + } + return subpartitionReadOrder; + } + + /** Data written and its {@link Buffer.DataType}. */ + public static class DataAndType { + private final ByteBuffer data; + private final Buffer.DataType dataType; + + DataAndType(ByteBuffer data, Buffer.DataType dataType) { + this.data = data; + this.dataType = dataType; + } + } +} diff --git a/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/transfer/ReadingIntegrationTest.java b/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/transfer/ReadingIntegrationTest.java new file mode 100644 index 00000000..636df51d --- /dev/null +++ b/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/transfer/ReadingIntegrationTest.java @@ -0,0 +1,407 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
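A note on getRandomSubpartitionOrder above: it yields a rotation rather than an arbitrary shuffle, so every subpartition index appears exactly once and relative order is preserved modulo the wrap-around. For instance:

    int numSubpartitions = 5;
    int shift = 2;                          // what random.nextInt(numSubpartitions) returned
    int[] order = new int[numSubpartitions];
    for (int i = 0; i < numSubpartitions; ++i) {
        order[i] = (i + shift) % numSubpartitions;
    }
    // order == [2, 3, 4, 0, 1]: a non-trivial read order that still covers every index once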
+ */ + +package com.alibaba.flink.shuffle.plugin.transfer; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.coordinator.manager.DefaultShuffleResource; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleResource; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleWorkerDescriptor; +import com.alibaba.flink.shuffle.core.config.TransferOptions; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.storage.DataPartition; +import com.alibaba.flink.shuffle.core.storage.DataPartitionReadingView; +import com.alibaba.flink.shuffle.core.storage.ReadingViewContext; +import com.alibaba.flink.shuffle.plugin.RemoteShuffleDescriptor; +import com.alibaba.flink.shuffle.plugin.utils.BufferUtils; +import com.alibaba.flink.shuffle.transfer.AbstractNettyTest; +import com.alibaba.flink.shuffle.transfer.ConnectionManager; +import com.alibaba.flink.shuffle.transfer.FakedDataPartitionReadingView; +import com.alibaba.flink.shuffle.transfer.NettyConfig; +import com.alibaba.flink.shuffle.transfer.NettyServer; +import com.alibaba.flink.shuffle.transfer.TestTransferBufferPool; +import com.alibaba.flink.shuffle.transfer.utils.NoOpPartitionedDataStore; + +import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; +import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; +import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.BufferPool; +import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; +import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.shuffle.ShuffleDescriptor; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; + +import org.apache.commons.lang3.tuple.Pair; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** Integration test for reading process. */ +@RunWith(Parameterized.class) +public class ReadingIntegrationTest { + + private String localhost; + + private Random random; + + private List dataStores; + + private List nettyServers; + + private List workerDescs; + + private NetworkBufferPool networkBufferPool; + + private final int upstreamParallelism; + + private final int numConcurrentReading; + + // Num of input gates share a ConnectionManager. + private final int groupSize; + + private final int downStreamParallelism; + + private final int numShuffleWorkers; + + // Num of buffers per reading channel. 
+ private final int dataScalePerReadingView; + + private final int buffersPerClientChannelBufferPool; + + public ReadingIntegrationTest( + int upstreamParallelism, + int numConcurrentReading, + int groupSize, + int downStreamParallelism, + int numShuffleWorkers, + int dataScalePerReadingView, + int buffersPerClientChannelBufferPool) { + this.upstreamParallelism = upstreamParallelism; + this.numConcurrentReading = numConcurrentReading; + this.groupSize = groupSize; + this.downStreamParallelism = downStreamParallelism; + this.numShuffleWorkers = numShuffleWorkers; + this.dataScalePerReadingView = dataScalePerReadingView; + this.buffersPerClientChannelBufferPool = buffersPerClientChannelBufferPool; + } + + @Parameterized.Parameters + public static Collection data() { + return Arrays.asList( + new Object[][] { + {10, 10, 1, 8, 3, 100, 10}, + {10, 10, 1, 8, 3, 0, 10}, + {10, 10, 2, 8, 3, 100, 10}, + {10, 10, 8, 8, 3, 100, 10}, + {1200, 10, 4, 8, 3, 100, 10}, + {1200, 10, 4, 8, 3, 100, 50}, + {1200, 10, 1, 8, 3, 100, 100}, + {1200, 10, 4, 8, 3, 100, 100} + }); + } + + @Before + public void setup() throws Exception { + random = new Random(); + networkBufferPool = new NetworkBufferPool(20000, 32); + if (localhost == null) { + localhost = InetAddress.getLocalHost().getHostAddress(); + } + } + + @After + public void tearDown() { + assertEquals(20000, networkBufferPool.getNumberOfAvailableMemorySegments()); + networkBufferPool.destroy(); + nettyServers.forEach(NettyServer::shutdown); + } + + @Test(timeout = 600_000) + public void test() throws Exception { + startServers(numShuffleWorkers, dataScalePerReadingView); + + List connMgrs = new ArrayList<>(); + List readingThreads = new ArrayList<>(); + AtomicInteger buffersCounter = new AtomicInteger(0); + AtomicReference cause = new AtomicReference<>(null); + for (int i = 0; i < downStreamParallelism; i += groupSize) { + List bufferPools = new ArrayList<>(groupSize); + for (int j = 0; j < groupSize; j++) { + int numBuffers = buffersPerClientChannelBufferPool * numConcurrentReading; + bufferPools.add(networkBufferPool.createBufferPool(numBuffers, numBuffers)); + } + Pair> pair = + createInputGateGroup( + bufferPools, numConcurrentReading, upstreamParallelism, workerDescs); + connMgrs.add(pair.getLeft()); + for (RemoteShuffleInputGate gate : pair.getRight()) { + readingThreads.add(new ReadingThread(gate, buffersCounter, cause)); + } + } + readingThreads.forEach(Thread::start); + for (Thread t : readingThreads) { + t.join(); + } + for (ConnectionManager connMgr : connMgrs) { + assertEquals(0, connMgr.numPhysicalConnections()); + connMgr.shutdown(); + } + int expectTotalBuffers = 0; + for (FakedPartitionDataStore dataStore : dataStores) { + expectTotalBuffers += dataStore.totalBuffersToSend.get(); + } + assertEquals(expectTotalBuffers, buffersCounter.get()); + assertNull(cause.get()); + } + + private static class ReadingThread extends Thread { + + private final RemoteShuffleInputGate gate; + + private final AtomicInteger buffersCounter; + + private final AtomicReference cause; + + ReadingThread( + RemoteShuffleInputGate gate, + AtomicInteger buffersCounter, + AtomicReference cause) + throws Exception { + this.gate = gate; + this.buffersCounter = buffersCounter; + this.cause = cause; + } + + @Override + public void run() { + try { + gate.setup(); + while (!gate.isFinished()) { + while (true) { + Optional got = gate.pollNext(); + if (got.isPresent()) { + if (got.get().isBuffer()) { + got.get().getBuffer().recycleBuffer(); + buffersCounter.incrementAndGet(); + } 
else if (got.get().getEvent() instanceof EndOfPartitionEvent) { + break; + } + } else { + Thread.sleep(5); + } + } + } + } catch (Exception e) { + e.printStackTrace(); + } finally { + try { + gate.close(); + } catch (Throwable t) { + cause.set(t); + } + } + } + } + + private static class FeedingThread extends Thread { + + private final FakedDataPartitionReadingView readingView; + private final int numBuffersToSend; + + FeedingThread(FakedDataPartitionReadingView readingView, int numBuffersToSend) { + this.readingView = readingView; + this.numBuffersToSend = numBuffersToSend; + } + + @Override + public void run() { + TestTransferBufferPool serverBufferPool = new TestTransferBufferPool(20, 64); + try { + readingView.notifyBacklog(numBuffersToSend + 1); + for (int i = 0; i <= numBuffersToSend; i++) { + if (readingView.getError() != null) { + break; + } + ByteBuf buffer = serverBufferPool.requestBufferBlocking(); + if (i == numBuffersToSend) { + fillAsEOF(buffer); + readingView.setNoMoreData(true); + } else { + fillAsBuffer(buffer); + } + readingView.notifyBuffer(buffer); + } + } catch (Exception e) { + e.printStackTrace(); + } finally { + serverBufferPool.destroy(); + } + } + + private void fillAsBuffer(ByteBuf buffer) { + BufferUtils.setBufferHeader(buffer, Buffer.DataType.DATA_BUFFER, false, 4); + buffer.writeInt(0); + } + + private void fillAsEOF(ByteBuf buffer) throws IOException { + ByteBuf serialized = + EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE, false).asByteBuf(); + BufferUtils.setBufferHeader( + buffer, Buffer.DataType.EVENT_BUFFER, false, serialized.readableBytes()); + buffer.writeBytes(serialized); + } + } + + private Pair> createInputGateGroup( + List clientBufferPools, + int numConcurrentReading, + int upstreamParallelism, + List descs) + throws Exception { + NettyConfig nettyConfig = new NettyConfig(new Configuration()); + ConnectionManager connMgr = + ConnectionManager.createReadConnectionManager(nettyConfig, true); + connMgr.start(); + List gates = new ArrayList<>(); + for (BufferPool bufferPool : clientBufferPools) { + RemoteShuffleInputGate gate = + createInputGate( + numConcurrentReading, connMgr, upstreamParallelism, bufferPool, descs); + gates.add(gate); + } + return Pair.of(connMgr, gates); + } + + private RemoteShuffleInputGate createInputGate( + int numConcurrentReading, + ConnectionManager connMgr, + int upstreamParallelism, + BufferPool bufferPool, + List descs) + throws IOException { + String taskName = "ReadingIntegrationTest"; + IntermediateDataSetID intermediateDataSetID = new IntermediateDataSetID(); + + InputGateDeploymentDescriptor gateDesc = + new InputGateDeploymentDescriptor( + intermediateDataSetID, + ResultPartitionType.BLOCKING, + 0, + createShuffleDescs(upstreamParallelism, descs)); + return new RemoteShuffleInputGate( + taskName, + true, + 0, + Integer.MAX_VALUE, + gateDesc, + numConcurrentReading, + connMgr, + () -> bufferPool, + null); + } + + private ShuffleDescriptor[] createShuffleDescs( + int upstreamParallelism, List descs) { + ShuffleDescriptor[] ret = new ShuffleDescriptor[upstreamParallelism]; + JobID jobID = new JobID(CommonUtils.randomBytes(32)); + for (int i = 0; i < upstreamParallelism; i++) { + ResultPartitionID rID = new ResultPartitionID(); + int randIdx = random.nextInt(descs.size()); + ShuffleResource shuffleResource = + new DefaultShuffleResource( + new ShuffleWorkerDescriptor[] {descs.get(randIdx)}, + DataPartition.DataPartitionType.MAP_PARTITION); + ret[i] = new RemoteShuffleDescriptor(rID, jobID, shuffleResource); + } 
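// Wiring note: each upstream result partition above is assigned to exactly one
// randomly chosen shuffle worker (descs.get(randIdx)), so a downstream gate's
// channels end up spread across all of the fake servers started by this test.
// The DefaultShuffleResource consequently carries a one-element
// ShuffleWorkerDescriptor array per MAP_PARTITION.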
+ return ret; + } + + private void startServers(int numShuffleWorkers, int dataScalePerReadingView) throws Exception { + dataStores = new ArrayList<>(); + nettyServers = new ArrayList<>(); + workerDescs = new ArrayList<>(); + int[] ports = AbstractNettyTest.getAvailablePorts(numShuffleWorkers); + for (int i = 0; i < numShuffleWorkers; i++) { + int dataPort = ports[i]; + workerDescs.add( + new ShuffleWorkerDescriptor( + null, InetAddress.getLocalHost().getHostAddress(), dataPort)); + FakedPartitionDataStore dataStore = + new FakedPartitionDataStore(dataScalePerReadingView); + dataStores.add(dataStore); + + Configuration config = new Configuration(); + config.setInteger(TransferOptions.SERVER_DATA_PORT, dataPort); + NettyConfig nettyConfig = new NettyConfig(config); + NettyServer server = new NettyServer(dataStore, nettyConfig); + nettyServers.add(server); + server.start(); + } + } + + private class FakedPartitionDataStore extends NoOpPartitionedDataStore { + + private final int dataScalePerReadingView; + private final AtomicInteger totalBuffersToSend = new AtomicInteger(0); + + FakedPartitionDataStore(int dataScalePerReadingView) { + this.dataScalePerReadingView = dataScalePerReadingView; + } + + @Override + public DataPartitionReadingView createDataPartitionReadingView(ReadingViewContext context) { + FakedDataPartitionReadingView ret = + new FakedDataPartitionReadingView( + context.getDataListener(), + context.getBacklogListener(), + context.getFailureListener()); + final int buffersToSend; + if (dataScalePerReadingView == 0) { + buffersToSend = 0; + } else { + buffersToSend = dataScalePerReadingView + random.nextInt(dataScalePerReadingView); + } + totalBuffersToSend.addAndGet(buffersToSend); + new FeedingThread(ret, buffersToSend).start(); + return ret; + } + } +} diff --git a/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/transfer/RemoteShuffleInputGateFactoryTest.java b/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/transfer/RemoteShuffleInputGateFactoryTest.java new file mode 100644 index 00000000..1c593337 --- /dev/null +++ b/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/transfer/RemoteShuffleInputGateFactoryTest.java @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
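startServers above binds each fake shuffle worker to a free port returned by AbstractNettyTest.getAvailablePorts. That helper's implementation is not part of this diff; a common way to build such a utility (an assumption, not necessarily what AbstractNettyTest does) is to bind to port 0 and let the kernel pick:

    import java.net.ServerSocket;

    static int[] getFreePorts(int n) throws Exception {
        int[] ports = new int[n];
        for (int i = 0; i < n; i++) {
            try (ServerSocket socket = new ServerSocket(0)) { // port 0: OS assigns a free port
                socket.setReuseAddress(true);
                ports[i] = socket.getLocalPort();
            }
        }
        return ports;
    }

Helpers like this are best-effort: another process can in principle grab the port between closing the probe socket and the server binding it, which is acceptable for tests but not for production listeners.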
+ */ + +package com.alibaba.flink.shuffle.plugin.transfer; + +import com.alibaba.flink.shuffle.common.exception.ConfigurationException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.coordinator.manager.DefaultShuffleResource; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleResource; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleWorkerDescriptor; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.storage.DataPartition; +import com.alibaba.flink.shuffle.plugin.RemoteShuffleDescriptor; +import com.alibaba.flink.shuffle.plugin.config.PluginOptions; +import com.alibaba.flink.shuffle.plugin.utils.ConfigurationUtils; +import com.alibaba.flink.shuffle.transfer.ConnectionManager; +import com.alibaba.flink.shuffle.transfer.ShuffleReadClient; +import com.alibaba.flink.shuffle.transfer.TransferBufferPool; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; +import org.apache.flink.runtime.io.network.buffer.BufferDecompressor; +import org.apache.flink.runtime.io.network.buffer.BufferPool; +import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.shuffle.ShuffleDescriptor; +import org.apache.flink.util.function.SupplierWithException; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; + +import org.junit.Test; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.function.Consumer; + +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +/** Test for {@link RemoteShuffleInputGateFactory}. 
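The failure tests in this class pin down a simple sizing rule: PluginOptions.MEMORY_PER_INPUT_GATE divided by the network buffer size gives the gate's buffer count, which must meet a minimum (MIN_BUFFERS_PER_GATE, a factory constant referenced by the neighboring gate test) and must also be satisfiable by the global NetworkBufferPool. A back-of-the-envelope check of the cases below, under those assumptions:

    int bufferSize = 1024 * 1024;                              // 1 MiB buffers, as configured below
    int buffersAt16m = (int) (16L * 1024 * 1024 / bufferSize); // 16 -> accepted
    int buffersAt2m = (int) (2L * 1024 * 1024 / bufferSize);   // 2 -> below the minimum,
                                                               //      ConfigurationException
    // Even an accepted 16-buffer gate fails later with "Insufficient number of
    // network buffers" when the global pool was created with only 8 segments.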
*/ +public class RemoteShuffleInputGateFactoryTest { + + @Test + public void testBasicRoutine() throws Exception { + testCreateInputGateAndBufferPool(16, "16m"); + } + + @Test + public void testLessThanMinMemoryPerGate() { + Exception e = + assertThrows( + ConfigurationException.class, + () -> testCreateInputGateAndBufferPool(16, "2m")); + String msg = "Insufficient network memory per input gate, please increase"; + assertTrue(e.getMessage().contains(msg)); + } + + @Test + public void testLessThanMinBuffersPerGate() { + Exception e = + assertThrows( + ConfigurationException.class, + () -> testCreateInputGateAndBufferPool(8, "8m")); + String msg = "Insufficient network memory per input gate, please increase"; + assertTrue(e.getMessage().contains(msg)); + } + + @Test + public void testNetworkBufferShortage() { + IOException e = + assertThrows(IOException.class, () -> testCreateInputGateAndBufferPool(8, "16m")); + assertTrue(e.getMessage().contains("Insufficient number of network buffers")); + } + + private void testCreateInputGateAndBufferPool(int totalNumNetworkBuffers, String inputMemory) + throws Exception { + int bufferSize = 1024 * 1024; + NetworkBufferPool networkBufferPool = + new NetworkBufferPool(totalNumNetworkBuffers, bufferSize); + Configuration conf = new Configuration(); + conf.setString(PluginOptions.MEMORY_PER_INPUT_GATE.key(), inputMemory); + RemoteShuffleInputGateFactory factory = + new TestingRemoteShuffleInputGateFactory( + ConfigurationUtils.fromFlinkConfiguration(conf), + networkBufferPool, + bufferSize); + TestingRemoteShuffleInputGate gate = + (TestingRemoteShuffleInputGate) + factory.create("", 0, createGateDescriptor(100), null); + BufferPool bp = gate.bufferPoolFactory.get(); + bp.lazyDestroy(); + networkBufferPool.destroy(); + } + + private InputGateDeploymentDescriptor createGateDescriptor(int numShuffleDescs) + throws Exception { + int subIdx = 99; + return new InputGateDeploymentDescriptor( + new IntermediateDataSetID(), + ResultPartitionType.BLOCKING, + subIdx, + createShuffleDescriptors(numShuffleDescs)); + } + + private ShuffleDescriptor[] createShuffleDescriptors(int num) throws Exception { + JobID jID = new JobID(CommonUtils.randomBytes(8)); + RemoteShuffleDescriptor[] ret = new RemoteShuffleDescriptor[num]; + for (int i = 0; i < num; i++) { + ResultPartitionID rID = new ResultPartitionID(); + ShuffleResource resource = + new DefaultShuffleResource( + new ShuffleWorkerDescriptor[] { + new ShuffleWorkerDescriptor( + null, InetAddress.getLocalHost().getHostAddress(), 0) + }, + DataPartition.DataPartitionType.MAP_PARTITION); + ret[i] = new RemoteShuffleDescriptor(rID, jID, resource); + } + return ret; + } + + private static class TestingRemoteShuffleInputGateFactory + extends RemoteShuffleInputGateFactory { + + private final int bufferSize; + + TestingRemoteShuffleInputGateFactory( + com.alibaba.flink.shuffle.common.config.Configuration configuration, + NetworkBufferPool networkBufferPool, + int bufferSize) { + super(configuration, networkBufferPool, bufferSize, "LZ4"); + this.bufferSize = bufferSize; + } + + @Override + RemoteShuffleInputGate createInputGate( + String owningTaskName, + boolean shuffleChannels, + int gateIndex, + InputGateDeploymentDescriptor igdd, + int numConcurrentReading, + ConnectionManager connectionManager, + SupplierWithException bufferPoolFactory, + BufferDecompressor bufferDecompressor) { + return new TestingRemoteShuffleInputGate( + owningTaskName, + shuffleChannels, + gateIndex, + bufferSize, + igdd, + numConcurrentReading, + 
connectionManager, + bufferPoolFactory, + bufferDecompressor); + } + } + + private static class TestingRemoteShuffleInputGate extends RemoteShuffleInputGate { + + private final SupplierWithException bufferPoolFactory; + + TestingRemoteShuffleInputGate( + String taskName, + boolean shuffleChannels, + int gateIndex, + int networkBufferSize, + InputGateDeploymentDescriptor gateDescriptor, + int numConcurrentReading, + ConnectionManager connectionManager, + SupplierWithException bufferPoolFactory, + BufferDecompressor bufferDecompressor) { + super( + taskName, + shuffleChannels, + gateIndex, + networkBufferSize, + gateDescriptor, + numConcurrentReading, + connectionManager, + bufferPoolFactory, + bufferDecompressor); + this.bufferPoolFactory = bufferPoolFactory; + } + + @Override + ShuffleReadClient createShuffleReadClient( + ConnectionManager connectionManager, + InetSocketAddress address, + DataSetID dataSetID, + MapPartitionID mapID, + int startSubIdx, + int endSubIdx, + int bufferSize, + TransferBufferPool bufferPool, + Consumer dataListener, + Consumer failureListener) { + return null; + } + } +} diff --git a/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/transfer/RemoteShuffleInputGateTest.java b/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/transfer/RemoteShuffleInputGateTest.java new file mode 100644 index 00000000..8930af18 --- /dev/null +++ b/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/transfer/RemoteShuffleInputGateTest.java @@ -0,0 +1,485 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
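Both this factory test and the gate test that follows lean on the same seam: a package-private factory method (createInputGate above, createShuffleReadClient below) that a testing subclass overrides to return a stub instead of a real network client. The shape of the pattern, with hypothetical names (Client, Transport) standing in for the real types:

    interface Client {
        void connect();
    }

    class Transport {
        void start() {
            createClient().connect();       // production path
        }

        Client createClient() {             // package-private seam for tests
            throw new UnsupportedOperationException("real network client goes here");
        }
    }

    class TestingTransport extends Transport {
        @Override
        Client createClient() {
            return () -> { /* record the call; no network I/O */ };
        }
    }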
+ */ + +package com.alibaba.flink.shuffle.plugin.transfer; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.coordinator.manager.DefaultShuffleResource; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleResource; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleWorkerDescriptor; +import com.alibaba.flink.shuffle.core.exception.PartitionNotFoundException; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.storage.DataPartition; +import com.alibaba.flink.shuffle.plugin.RemoteShuffleDescriptor; +import com.alibaba.flink.shuffle.plugin.utils.BufferUtils; +import com.alibaba.flink.shuffle.transfer.ConnectionManager; +import com.alibaba.flink.shuffle.transfer.ShuffleReadClient; +import com.alibaba.flink.shuffle.transfer.TransferBufferPool; + +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; +import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; +import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.BufferPool; +import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; +import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; +import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.shuffle.ShuffleDescriptor; +import org.apache.flink.util.function.SupplierWithException; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.function.Consumer; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** Test for {@link RemoteShuffleInputGate}. 
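State assertions against the reader thread in this class are timing-dependent, so the class polls instead of asserting once: its check helper (defined further down) retries an assertion with a short sleep between attempts. The same idea as a standalone utility, a sketch that surfaces the failure on the final attempt rather than calling fail() as the helper below does:

    static void eventually(Runnable assertion) throws InterruptedException {
        for (int attempt = 0; attempt < 10; attempt++) {   // roughly a 2-second budget
            try {
                assertion.run();
                return;                                    // condition reached
            } catch (Throwable notYet) {
                Thread.sleep(200);                         // back off and retry
            }
        }
        assertion.run();                                   // last attempt propagates the failure
    }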
*/ +public class RemoteShuffleInputGateTest { + + private static final int bufferSize = 1; + + private NetworkBufferPool networkBufferPool; + + @Before + public void setup() { + networkBufferPool = + new NetworkBufferPool( + RemoteShuffleInputGateFactory.MIN_BUFFERS_PER_GATE, bufferSize); + } + + @After + public void teatDown() { + networkBufferPool.destroy(); + } + + @Test + public void testOpenChannelsOneByOne() throws Exception { + RemoteShuffleInputGate gate = createRemoteShuffleInputGate(2, 1); + + gate.setup(); + assertTrue(getChannel(gate, 0).isConnected); + assertTrue(getChannel(gate, 1).isConnected); + ReadingThread readingThread = new ReadingThread(gate); + readingThread.start(); + assertEquals(2, gate.getShuffleReadClients().size()); + assertEquals(2, gate.getNumberOfInputChannels()); + check(() -> assertEquals(Thread.State.WAITING, readingThread.getState())); + assertEquals(0, readingThread.numRead); + assertTrue(getChannel(gate, 0).isOpened); + assertFalse(getChannel(gate, 1).isOpened); + + getChannel(gate, 0).triggerEndOfPartitionEvent(); + readingThread.kick(); + check(() -> assertEquals(Thread.State.WAITING, readingThread.getState())); + assertEquals(1, readingThread.numRead); + assertTrue(getChannel(gate, 0).isClosed); + assertTrue(getChannel(gate, 1).isOpened); + + getChannel(gate, 1).triggerEndOfPartitionEvent(); + readingThread.kick(); + check(() -> assertSame(readingThread.getState(), Thread.State.TERMINATED)); + assertEquals(2, readingThread.numRead); + assertTrue(getChannel(gate, 1).isClosed); + } + + @Test + public void testBasicRoutine() throws Exception { + RemoteShuffleInputGate gate = createRemoteShuffleInputGate(1, 1); + gate.setup(); + ReadingThread readingThread = new ReadingThread(gate); + readingThread.start(); + check(() -> assertSame(readingThread.getState(), Thread.State.WAITING)); + assertEquals(0, readingThread.numRead); + assertFalse(gate.isAvailable()); + + getChannel(gate, 0).triggerData(); + assertTrue(gate.isAvailable()); + readingThread.kick(); + check(() -> assertSame(readingThread.getState(), Thread.State.WAITING)); + assertEquals(1, readingThread.numRead); + assertFalse(gate.isAvailable()); + + getChannel(gate, 0).triggerEndOfPartitionEvent(); + readingThread.kick(); + check(() -> assertSame(readingThread.getState(), Thread.State.TERMINATED)); + assertEquals(2, readingThread.numRead); + + assertNull(readingThread.cause); + } + + @Test + public void testReadingFailure() throws Exception { + RemoteShuffleInputGate gate = createRemoteShuffleInputGate(1, 1); + gate.setup(); + ReadingThread readingThread = new ReadingThread(gate); + readingThread.start(); + check(() -> assertSame(readingThread.getState(), Thread.State.WAITING)); + assertEquals(0, readingThread.numRead); + + getChannel(gate, 0).triggerFailure(); + readingThread.kick(); + check(() -> assertSame(readingThread.getState(), Thread.State.TERMINATED)); + + assertNotNull(readingThread.cause); + } + + @Test + public void testClosing() throws Exception { + RemoteShuffleInputGate gate = createRemoteShuffleInputGate(1, 1); + gate.setup(); + gate.close(); + assertTrue(getChannel(gate, 0).isClosed); + } + + @Test + public void testFireHandshakeByPollNext() throws Exception { + RemoteShuffleInputGate gate = createRemoteShuffleInputGate(1, 1); + gate.setup(); + assertTrue(gate.isAvailable()); + assertTrue(getChannel(gate, 0).isConnected); + assertFalse(getChannel(gate, 0).isOpened); + + assertFalse(gate.pollNext().isPresent()); + assertFalse(gate.isAvailable()); + assertTrue(getChannel(gate, 
0).isOpened); + + getChannel(gate, 0).triggerData(); + assertTrue(gate.isAvailable()); + Optional polled = gate.pollNext(); + assertTrue(polled.isPresent()); + polled.get().getBuffer().recycleBuffer(); + assertFalse(gate.pollNext().isPresent()); + assertFalse(gate.isAvailable()); + + gate.close(); + } + + @Test + public void testPartitionException() throws Exception { + final RemoteShuffleInputGate gate0 = createRemoteShuffleInputGate(1, 1, true); + assertThrows(ShuffleException.class, gate0::setup); + gate0.close(); + + final RemoteShuffleInputGate gate1 = createRemoteShuffleInputGate(1, 1); + gate1.setup(); + ReadingThread readingThread = new ReadingThread(gate1); + readingThread.start(); + check(() -> assertEquals(readingThread.getState(), Thread.State.WAITING)); + assertEquals(0, readingThread.numRead); + + getChannel(gate1, 0).triggerPartitionNotFound(); + readingThread.kick(); + check(() -> assertEquals(readingThread.getState(), Thread.State.TERMINATED)); + + Class clazz = com.alibaba.flink.shuffle.plugin.transfer.PartitionNotFoundException.class; + assertEquals(clazz, readingThread.cause.getClass()); + gate1.close(); + } + + private RemoteShuffleInputGate createRemoteShuffleInputGate( + int numShuffleDescs, int numConcurrentReading) throws Exception { + return createRemoteShuffleInputGate(numShuffleDescs, numConcurrentReading, false); + } + + private RemoteShuffleInputGate createRemoteShuffleInputGate( + int numShuffleDescs, int numConcurrentReading, boolean throwsWhenConnect) + throws Exception { + return new TestingRemoteShuffleInputGate( + numConcurrentReading, + createGateDescriptor(numShuffleDescs), + () -> + networkBufferPool.createBufferPool( + RemoteShuffleInputGateFactory.MIN_BUFFERS_PER_GATE, + RemoteShuffleInputGateFactory.MIN_BUFFERS_PER_GATE), + throwsWhenConnect); + } + + private InputGateDeploymentDescriptor createGateDescriptor(int numShuffleDescs) + throws Exception { + int subIdx = 99; + return new InputGateDeploymentDescriptor( + new IntermediateDataSetID(), + ResultPartitionType.BLOCKING, + subIdx, + createShuffleDescriptors(numShuffleDescs)); + } + + private ShuffleDescriptor[] createShuffleDescriptors(int num) throws Exception { + JobID jID = new JobID(CommonUtils.randomBytes(8)); + RemoteShuffleDescriptor[] ret = new RemoteShuffleDescriptor[num]; + for (int i = 0; i < num; i++) { + ResultPartitionID rID = new ResultPartitionID(); + ShuffleResource resource = + new DefaultShuffleResource( + new ShuffleWorkerDescriptor[] { + new ShuffleWorkerDescriptor( + null, InetAddress.getLocalHost().getHostAddress(), 0) + }, + DataPartition.DataPartitionType.MAP_PARTITION); + ret[i] = new RemoteShuffleDescriptor(rID, jID, resource); + } + return ret; + } + + private FakedShuffleReadClient getChannel(RemoteShuffleInputGate gate, int idx) { + return (FakedShuffleReadClient) gate.getShuffleReadClients().get(idx); + } + + protected void check(Runnable runnable) throws InterruptedException { + for (int i = 0; i < 10; i++) { + try { + runnable.run(); + return; + } catch (Throwable t) { + Thread.sleep(200); + } + } + fail(); + } + + private class TestingRemoteShuffleInputGate extends RemoteShuffleInputGate { + + private final boolean throwsWhenConnect; + + public TestingRemoteShuffleInputGate( + int numConcurrentReading, + InputGateDeploymentDescriptor gateDescriptor, + SupplierWithException bufferPoolFactory, + boolean throwsWhenConnect) { + + super( + "RemoteShuffleInputGateTest", + true, + 0, + bufferSize, + gateDescriptor, + numConcurrentReading, + null, + bufferPoolFactory, + 
null); + this.throwsWhenConnect = throwsWhenConnect; + } + + @Override + ShuffleReadClient createShuffleReadClient( + ConnectionManager connectionManager, + InetSocketAddress address, + DataSetID dataSetID, + MapPartitionID mapID, + int startSubIdx, + int endSubIdx, + int bufferSize, + TransferBufferPool bufferPool, + Consumer dataListener, + Consumer failureListener) { + + List buffers = new ArrayList<>(); + for (int i = 0; i < 1024; i++) { + NetworkBuffer buffer = + new NetworkBuffer( + MemorySegmentFactory.allocateUnpooledSegment(1024), + FreeingBufferRecycler.INSTANCE); + buffers.add(buffer.asByteBuf()); + } + bufferPool.addBuffers(buffers); + + return new FakedShuffleReadClient(this, bufferPool, dataListener, failureListener); + } + } + + private class FakedShuffleReadClient extends ShuffleReadClient { + + private final TestingRemoteShuffleInputGate parent; + private final TransferBufferPool bufferPool; + private final Consumer dataListener; + private final Consumer failureListener; + + private boolean isConnected; + private boolean isOpened; + private boolean isClosed; + + public FakedShuffleReadClient( + TestingRemoteShuffleInputGate parent, + TransferBufferPool bufferPool, + Consumer dataListener, + Consumer failureListener) { + super( + new InetSocketAddress(1), + new DataSetID(CommonUtils.randomBytes(1)), + new MapPartitionID(CommonUtils.randomBytes(1)), + 0, + 0, + Integer.MAX_VALUE, + bufferPool, + new ConnectionManager(null, null, 3, Duration.ofMillis(1)), + dataListener, + failureListener); + this.bufferPool = bufferPool; + this.parent = parent; + this.dataListener = dataListener; + this.failureListener = failureListener; + } + + @Override + public void connect() { + if (parent.throwsWhenConnect) { + throw new ShuffleException("Connect failure."); + } + isConnected = true; + } + + @Override + public void open() { + isOpened = true; + } + + @Override + public boolean isOpened() { + return isOpened; + } + + @Override + public void close() { + isClosed = true; + } + + public void triggerData() throws Exception { + ByteBuf byteBuf = + new NetworkBuffer( + MemorySegmentFactory.allocateUnpooledSegment(1024), + FreeingBufferRecycler.INSTANCE); + BufferUtils.setBufferHeader(byteBuf, Buffer.DataType.DATA_BUFFER, false, 1); + byteBuf.writeByte(2); + dataListener.accept(byteBuf); + } + + public void triggerData(ByteBuffer data) throws Exception { + ByteBuf byteBuf = + new NetworkBuffer( + MemorySegmentFactory.allocateUnpooledSegment(1024), + FreeingBufferRecycler.INSTANCE); + BufferUtils.setBufferHeader( + byteBuf, Buffer.DataType.DATA_BUFFER, false, data.remaining()); + byteBuf.writeBytes(data); + data.position(0); + dataListener.accept(byteBuf); + } + + public void triggerEndOfPartitionEvent() throws IOException { + Buffer buffer = EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE, false); + ByteBuf byteBuf = + new NetworkBuffer( + MemorySegmentFactory.allocateUnpooledSegment(1024), + FreeingBufferRecycler.INSTANCE); + BufferUtils.setBufferHeader( + byteBuf, buffer.getDataType(), buffer.isCompressed(), buffer.readableBytes()); + dataListener.accept(byteBuf); + } + + public void triggerFailure() { + failureListener.accept(new Exception("")); + } + + public void triggerPartitionNotFound() { + failureListener.accept(new IOException(PartitionNotFoundException.class.getName())); + } + + @Override + public void notifyAvailableCredits(int numCredits) {} + } + + private static class ReadingThread extends Thread { + + private final Object lock = new Object(); + private final 
RemoteShuffleInputGate gate; + private int numRead; + private Throwable cause; + + private final List buffers = new ArrayList<>(); + + public ReadingThread(RemoteShuffleInputGate gate) { + this.gate = gate; + } + + @Override + public void run() { + while (true) { + try { + Optional got = gate.pollNext(); + if (got.isPresent()) { + numRead++; + if (got.get().isBuffer()) { + buffers.add(got.get().getBuffer().getNioBufferReadable()); + got.get().getBuffer().recycleBuffer(); + } else if (gate.isFinished() && got.get().moreAvailable()) { + throw new Exception("Got EOF but indicating more data available."); + } + } else { + if (gate.isFinished()) { + break; + } else { + synchronized (lock) { + lock.wait(); + } + } + } + } catch (Throwable t) { + cause = t; + break; + } + } + } + + public void kick() { + synchronized (lock) { + lock.notify(); + } + } + + public List getBuffers() { + return buffers; + } + } +} diff --git a/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/transfer/RemoteShuffleResultPartitionTest.java b/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/transfer/RemoteShuffleResultPartitionTest.java new file mode 100644 index 00000000..cf2ed6bb --- /dev/null +++ b/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/transfer/RemoteShuffleResultPartitionTest.java @@ -0,0 +1,532 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
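ReadingThread above parks on a bare monitor whenever pollNext comes back empty and relies on kick() from the test body to wake it; it can skip the usual condition flag because it simply re-polls the gate after waking. The general JDK pattern, including the flag and the wait loop that guards against spurious wakeups, as a sketch to be placed in a method that may throw InterruptedException:

    Object lock = new Object();
    boolean[] ready = {false};              // condition the consumer waits on

    Thread consumer = new Thread(() -> {
        synchronized (lock) {
            while (!ready[0]) {             // loop: wakeups may be spurious
                try {
                    lock.wait();
                } catch (InterruptedException e) {
                    return;
                }
            }
        }
    });
    consumer.start();

    synchronized (lock) {                   // producer side: the "kick"
        ready[0] = true;
        lock.notify();                      // must hold the same monitor as the waiter
    }
    consumer.join();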
+ */ + +package com.alibaba.flink.shuffle.plugin.transfer; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.common.utils.ProtocolUtils; +import com.alibaba.flink.shuffle.coordinator.manager.DefaultShuffleResource; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleWorkerDescriptor; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.storage.DataPartition; +import com.alibaba.flink.shuffle.plugin.RemoteShuffleDescriptor; +import com.alibaba.flink.shuffle.plugin.transfer.PartitionSortedBufferTest.DataAndType; +import com.alibaba.flink.shuffle.plugin.utils.BufferUtils; +import com.alibaba.flink.shuffle.transfer.ConnectionManager; + +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; +import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.Buffer.DataType; +import org.apache.flink.runtime.io.network.buffer.BufferCompressor; +import org.apache.flink.runtime.io.network.buffer.BufferDecompressor; +import org.apache.flink.runtime.io.network.buffer.BufferPool; +import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.io.network.partition.ResultPartitionManager; +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; +import org.apache.flink.util.function.SupplierWithException; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Random; +import java.util.Set; +import java.util.stream.IntStream; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** Test for {@link RemoteShuffleResultPartition}. 
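setup() below wires an LZ4 BufferCompressor/BufferDecompressor pair, and the write tests only feed compressible records (a single repeated byte) when compression is on, since random bytes may not shrink and a buffer is only flagged isCompressed() when compression actually reduced its size. A sketch of the round trip, assuming BufferCompressor exposes compressToIntermediateBuffer symmetrical to the decompressor call used later in this class (an assumption about the Flink API, not shown in this diff):

    BufferCompressor compressor = new BufferCompressor(1024, "LZ4");
    BufferDecompressor decompressor = new BufferDecompressor(1024, "LZ4");
    // 'buffer' stands for an uncompressed Flink Buffer of at most 1024 readable bytes.
    Buffer maybeCompressed = compressor.compressToIntermediateBuffer(buffer);
    if (maybeCompressed.isCompressed()) {
        Buffer restored = decompressor.decompressToIntermediateBuffer(maybeCompressed);
        // 'restored' carries the original bytes; recycle intermediate buffers after use.
    }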
*/ +public class RemoteShuffleResultPartitionTest { + + private static final int totalBuffers = 1000; + + private static final int bufferSize = 1024; + + private NetworkBufferPool globalBufferPool; + + private BufferPool sortBufferPool; + + private BufferPool nettyBufferPool; + + private RemoteShuffleResultPartition partitionWriter; + + private FakedRemoteShuffleOutputGate outputGate; + + private BufferCompressor bufferCompressor; + + private BufferDecompressor bufferDecompressor; + + @Before + public void setup() { + globalBufferPool = new NetworkBufferPool(totalBuffers, bufferSize); + bufferCompressor = new BufferCompressor(bufferSize, "LZ4"); + bufferDecompressor = new BufferDecompressor(bufferSize, "LZ4"); + } + + @After + public void tearDown() throws Exception { + if (outputGate != null) { + outputGate.release(); + } + + if (sortBufferPool != null) { + sortBufferPool.lazyDestroy(); + } + if (nettyBufferPool != null) { + nettyBufferPool.lazyDestroy(); + } + assertEquals(totalBuffers, globalBufferPool.getNumberOfAvailableMemorySegments()); + globalBufferPool.destroy(); + } + + @Test + public void testWriteNormalRecordWithCompressionEnabled() throws Exception { + testWriteNormalRecord(true); + } + + @Test + public void testWriteNormalRecordWithCompressionDisabled() throws Exception { + testWriteNormalRecord(false); + } + + @Test + public void testWriteLargeRecord() throws Exception { + int numSubpartitions = 2; + int numBuffers = 100; + initResultPartitionWriter(numSubpartitions, 10, 200, false); + + partitionWriter.setup(); + + byte[] dataWritten = new byte[bufferSize * numBuffers]; + Random random = new Random(); + random.nextBytes(dataWritten); + ByteBuffer recordWritten = ByteBuffer.wrap(dataWritten); + partitionWriter.emitRecord(recordWritten, 0); + assertEquals(0, sortBufferPool.bestEffortGetNumOfUsedBuffers()); + + partitionWriter.finish(); + partitionWriter.close(); + + List receivedBuffers = outputGate.getReceivedBuffers()[0]; + + ByteBuffer recordRead = ByteBuffer.allocate(bufferSize * numBuffers); + for (Buffer buffer : receivedBuffers) { + if (buffer.isBuffer()) { + recordRead.put( + buffer.getNioBuffer( + BufferUtils.HEADER_LENGTH, + buffer.readableBytes() - BufferUtils.HEADER_LENGTH)); + } + } + recordWritten.rewind(); + recordRead.flip(); + assertEquals(recordWritten, recordRead); + } + + @Test + public void testBroadcastLargeRecord() throws Exception { + int numSubpartitions = 2; + int numBuffers = 100; + initResultPartitionWriter(numSubpartitions, 10, 200, false); + + partitionWriter.setup(); + + byte[] dataWritten = new byte[bufferSize * numBuffers]; + Random random = new Random(); + random.nextBytes(dataWritten); + ByteBuffer recordWritten = ByteBuffer.wrap(dataWritten); + partitionWriter.broadcastRecord(recordWritten); + assertEquals(0, sortBufferPool.bestEffortGetNumOfUsedBuffers()); + + partitionWriter.finish(); + partitionWriter.close(); + + ByteBuffer recordRead0 = ByteBuffer.allocate(bufferSize * numBuffers); + for (Buffer buffer : outputGate.getReceivedBuffers()[0]) { + if (buffer.isBuffer()) { + recordRead0.put( + buffer.getNioBuffer( + BufferUtils.HEADER_LENGTH, + buffer.readableBytes() - BufferUtils.HEADER_LENGTH)); + } + } + recordWritten.rewind(); + recordRead0.flip(); + assertEquals(recordWritten, recordRead0); + + ByteBuffer recordRead1 = ByteBuffer.allocate(bufferSize * numBuffers); + for (Buffer buffer : outputGate.getReceivedBuffers()[1]) { + if (buffer.isBuffer()) { + recordRead1.put( + buffer.getNioBuffer( + BufferUtils.HEADER_LENGTH, + 
+
+    @Test
+    public void testBroadcastLargeRecord() throws Exception {
+        int numSubpartitions = 2;
+        int numBuffers = 100;
+        initResultPartitionWriter(numSubpartitions, 10, 200, false);
+
+        partitionWriter.setup();
+
+        byte[] dataWritten = new byte[bufferSize * numBuffers];
+        Random random = new Random();
+        random.nextBytes(dataWritten);
+        ByteBuffer recordWritten = ByteBuffer.wrap(dataWritten);
+        partitionWriter.broadcastRecord(recordWritten);
+        assertEquals(0, sortBufferPool.bestEffortGetNumOfUsedBuffers());
+
+        partitionWriter.finish();
+        partitionWriter.close();
+
+        ByteBuffer recordRead0 = ByteBuffer.allocate(bufferSize * numBuffers);
+        for (Buffer buffer : outputGate.getReceivedBuffers()[0]) {
+            if (buffer.isBuffer()) {
+                recordRead0.put(
+                        buffer.getNioBuffer(
+                                BufferUtils.HEADER_LENGTH,
+                                buffer.readableBytes() - BufferUtils.HEADER_LENGTH));
+            }
+        }
+        recordWritten.rewind();
+        recordRead0.flip();
+        assertEquals(recordWritten, recordRead0);
+
+        ByteBuffer recordRead1 = ByteBuffer.allocate(bufferSize * numBuffers);
+        for (Buffer buffer : outputGate.getReceivedBuffers()[1]) {
+            if (buffer.isBuffer()) {
+                recordRead1.put(
+                        buffer.getNioBuffer(
+                                BufferUtils.HEADER_LENGTH,
+                                buffer.readableBytes() - BufferUtils.HEADER_LENGTH));
+            }
+        }
+        recordWritten.rewind();
+        recordRead1.flip();
+        assertEquals(recordWritten, recordRead1);
+    }
+
+    @Test
+    public void testFlush() throws Exception {
+        int numSubpartitions = 10;
+
+        initResultPartitionWriter(numSubpartitions, 10, 20, false);
+        partitionWriter.setup();
+
+        partitionWriter.emitRecord(ByteBuffer.allocate(bufferSize), 0);
+        partitionWriter.emitRecord(ByteBuffer.allocate(bufferSize), 1);
+        assertEquals(3, sortBufferPool.bestEffortGetNumOfUsedBuffers());
+
+        partitionWriter.broadcastRecord(ByteBuffer.allocate(bufferSize));
+        assertEquals(2, sortBufferPool.bestEffortGetNumOfUsedBuffers());
+
+        partitionWriter.flush(0);
+        assertEquals(0, sortBufferPool.bestEffortGetNumOfUsedBuffers());
+
+        partitionWriter.emitRecord(ByteBuffer.allocate(bufferSize), 2);
+        partitionWriter.emitRecord(ByteBuffer.allocate(bufferSize), 3);
+        assertEquals(3, sortBufferPool.bestEffortGetNumOfUsedBuffers());
+
+        partitionWriter.flushAll();
+        assertEquals(0, sortBufferPool.bestEffortGetNumOfUsedBuffers());
+
+        partitionWriter.finish();
+        partitionWriter.close();
+    }
+
+    private void testWriteNormalRecord(boolean compressionEnabled) throws Exception {
+        int numSubpartitions = 4;
+        int numRecords = 100;
+        Random random = new Random();
+
+        initResultPartitionWriter(numSubpartitions, 100, 500, compressionEnabled);
+        partitionWriter.setup();
+        assertTrue(outputGate.isSetup());
+
+        Queue<DataAndType>[] dataWritten = new Queue[numSubpartitions];
+        IntStream.range(0, numSubpartitions).forEach(i -> dataWritten[i] = new ArrayDeque<>());
+        int[] numBytesWritten = new int[numSubpartitions];
+        Arrays.fill(numBytesWritten, 0);
+
+        for (int i = 0; i < numRecords; i++) {
+            byte[] data = new byte[random.nextInt(2 * bufferSize) + 1];
+            if (compressionEnabled) {
+                byte randomByte = (byte) random.nextInt();
+                Arrays.fill(data, randomByte);
+            } else {
+                random.nextBytes(data);
+            }
+            ByteBuffer record = ByteBuffer.wrap(data);
+            boolean isBroadCast = random.nextBoolean();
+
+            if (isBroadCast) {
+                partitionWriter.broadcastRecord(record);
+                IntStream.range(0, numSubpartitions)
+                        .forEach(
+                                subpartition ->
+                                        recordDataWritten(
+                                                record,
+                                                DataType.DATA_BUFFER,
+                                                subpartition,
+                                                dataWritten,
+                                                numBytesWritten));
+            } else {
+                int subpartition = random.nextInt(numSubpartitions);
+                partitionWriter.emitRecord(record, subpartition);
+                recordDataWritten(
+                        record, DataType.DATA_BUFFER, subpartition, dataWritten, numBytesWritten);
+            }
+        }
+
+        partitionWriter.finish();
+        assertTrue(outputGate.isFinished());
+        partitionWriter.close();
+        assertTrue(outputGate.isClosed());
+
+        for (int subpartition = 0; subpartition < numSubpartitions; ++subpartition) {
+            ByteBuffer record = EventSerializer.toSerializedEvent(EndOfPartitionEvent.INSTANCE);
+            recordDataWritten(
+                    record, DataType.EVENT_BUFFER, subpartition, dataWritten, numBytesWritten);
+        }
+
+        outputGate
+                .getFinishedRegions()
+                .forEach(
+                        regionIndex ->
+                                assertTrue(
+                                        outputGate
+                                                .getNumBuffersByRegion()
+                                                .containsKey(regionIndex)));
+
+        int[] numBytesRead = new int[numSubpartitions];
+        List<Buffer>[] receivedBuffers = outputGate.getReceivedBuffers();
+        List<Buffer>[] validateTarget = new List[numSubpartitions];
+        Arrays.fill(numBytesRead, 0);
+        for (int i = 0; i < numSubpartitions; i++) {
+            validateTarget[i] = new ArrayList<>();
+            for (Buffer buffer : receivedBuffers[i]) {
+                for (Buffer unpackedBuffer : BufferPacker.unpack(buffer.asByteBuf())) {
+                    if (compressionEnabled && unpackedBuffer.isCompressed()) {
+                        Buffer decompressedBuffer =
+                                bufferDecompressor.decompressToIntermediateBuffer(unpackedBuffer);
+                        ByteBuffer decompressed = decompressedBuffer.getNioBufferReadable();
+                        int numBytes = decompressed.remaining();
+                        MemorySegment segment =
+                                MemorySegmentFactory.allocateUnpooledSegment(numBytes);
+                        segment.put(0, decompressed, numBytes);
+                        decompressedBuffer.recycleBuffer();
+                        validateTarget[i].add(
+                                new NetworkBuffer(
+                                        segment,
+                                        buf -> {},
+                                        unpackedBuffer.getDataType(),
+                                        numBytes));
+                        numBytesRead[i] += numBytes;
+                    } else {
+                        numBytesRead[i] += buffer.readableBytes();
+                        validateTarget[i].add(buffer);
+                    }
+                }
+            }
+        }
+        PartitionSortedBufferTest.checkWriteReadResult(
+                numSubpartitions, numBytesWritten, numBytesRead, dataWritten, validateTarget);
+    }
+
+    private void initResultPartitionWriter(
+            int numSubpartitions,
+            int sortBufferPoolSize,
+            int nettyBufferPoolSize,
+            boolean compressionEnabled)
+            throws Exception {
+
+        sortBufferPool = globalBufferPool.createBufferPool(sortBufferPoolSize, sortBufferPoolSize);
+        nettyBufferPool =
+                globalBufferPool.createBufferPool(nettyBufferPoolSize, nettyBufferPoolSize);
+
+        outputGate =
+                new FakedRemoteShuffleOutputGate(
+                        getShuffleDescriptor(), numSubpartitions, () -> nettyBufferPool);
+        outputGate.setup();
+
+        // The compressed and uncompressed variants differ only in the compressor argument.
+        partitionWriter =
+                new RemoteShuffleResultPartition(
+                        "RemoteShuffleResultPartitionWriterTest",
+                        0,
+                        new ResultPartitionID(),
+                        ResultPartitionType.BLOCKING,
+                        numSubpartitions,
+                        numSubpartitions,
+                        bufferSize,
+                        new ResultPartitionManager(),
+                        compressionEnabled ? bufferCompressor : null,
+                        () -> sortBufferPool,
+                        outputGate);
+    }
+
+    private void recordDataWritten(
+            ByteBuffer record,
+            DataType dataType,
+            int subpartition,
+            Queue<DataAndType>[] dataWritten,
+            int[] numBytesWritten) {
+
+        record.rewind();
+        dataWritten[subpartition].add(new DataAndType(record, dataType));
+        numBytesWritten[subpartition] += record.remaining();
+    }
+
+    private static class FakedRemoteShuffleOutputGate extends RemoteShuffleOutputGate {
+
+        private boolean isSetup;
+        private boolean isFinished;
+        private boolean isClosed;
+        private final List<Buffer>[] receivedBuffers;
+        private final Map<Integer, Integer> numBuffersByRegion;
+        private final Set<Integer> finishedRegions;
+        private int currentRegionIndex;
+        private boolean currentIsBroadcast;
+
+        FakedRemoteShuffleOutputGate(
+                RemoteShuffleDescriptor shuffleDescriptor,
+                int numSubpartitions,
+                SupplierWithException<BufferPool, IOException> bufferPoolFactory) {
+
+            super(
+                    shuffleDescriptor,
+                    numSubpartitions,
+                    bufferSize,
+                    ProtocolUtils.emptyDataPartitionType(),
+                    bufferPoolFactory,
+                    new ConnectionManager(null, null, 3, Duration.ofMillis(1)));
+            isSetup = false;
+            isFinished = false;
+            isClosed = false;
+            numBuffersByRegion = new HashMap<>();
+            finishedRegions = new HashSet<>();
+            currentRegionIndex = -1;
+            receivedBuffers = new ArrayList[numSubpartitions];
+            IntStream.range(0, numSubpartitions)
+                    .forEach(i -> receivedBuffers[i] = new ArrayList<>());
+            currentIsBroadcast = false;
+        }
+
+        @Override
+        public void setup() throws IOException, InterruptedException {
+            bufferPool = bufferPoolFactory.get();
+            isSetup = true;
+        }
+
+        @Override
+        public void write(Buffer buffer, int subIdx) {
+            if (currentIsBroadcast) {
+                assertEquals(0, subIdx);
+                ByteBuffer byteBuffer = buffer.getNioBufferReadable();
+                for (int i = 0; i < numSubs; i++) {
+                    int numBytes = buffer.readableBytes();
+                    MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(numBytes);
+                    byteBuffer.rewind();
+                    segment.put(0, byteBuffer, numBytes);
+                    receivedBuffers[i].add(
+                            new NetworkBuffer(
+                                    segment,
+                                    buf -> {},
+                                    buffer.getDataType(),
+                                    buffer.isCompressed(),
+                                    numBytes));
+                }
+                buffer.recycleBuffer();
+            } else {
+                receivedBuffers[subIdx].add(buffer);
+            }
+            // Count one more buffer for the current region.
+            numBuffersByRegion.merge(currentRegionIndex, 1, Integer::sum);
+        }
+
+        @Override
+        public void regionStart(boolean isBroadcast) {
+            currentIsBroadcast = isBroadcast;
+            currentRegionIndex++;
+        }
+
+        @Override
+        public void regionFinish() {
+            if (!finishedRegions.add(currentRegionIndex)) {
+                throw new IllegalStateException("Unexpected region: " + currentRegionIndex);
+            }
+        }
+
+        @Override
+        public void finish() throws InterruptedException {
+            isFinished = true;
+        }
+
+        @Override
+        public void close() {
+            isClosed = true;
+        }
+
+        public List<Buffer>[] getReceivedBuffers() {
+            return receivedBuffers;
+        }
+
+        public Map<Integer, Integer> getNumBuffersByRegion() {
+            return numBuffersByRegion;
+        }
+
+        public Set<Integer> getFinishedRegions() {
+            return finishedRegions;
+        }
+
+        public boolean isSetup() {
+            return isSetup;
+        }
+
+        public boolean isFinished() {
+            return isFinished;
+        }
+
+        public boolean isClosed() {
+            return isClosed;
+        }
+
+        public void release() throws Exception {
+            IntStream.range(0, numSubs)
+                    .forEach(
+                            subpartitionIndex -> {
+                                receivedBuffers[subpartitionIndex].forEach(Buffer::recycleBuffer);
+                                receivedBuffers[subpartitionIndex].clear();
+                            });
+            numBuffersByRegion.clear();
+            finishedRegions.clear();
+            super.close();
+        }
+    }
+
+    private RemoteShuffleDescriptor getShuffleDescriptor() throws Exception {
+        return new RemoteShuffleDescriptor(
+                new ResultPartitionID(),
+                new JobID(CommonUtils.randomBytes(64)),
+                new DefaultShuffleResource(
+                        new ShuffleWorkerDescriptor[] {
+                            new ShuffleWorkerDescriptor(null, "localhost", 0)
+                        },
+                        DataPartition.DataPartitionType.MAP_PARTITION));
+    }
+}
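Editor's note: the tests above all drive the same writer lifecycle. As a minimal orientation sketch (not part of the patch), assuming a partition constructed the same way as in initResultPartitionWriter(...) and compiled alongside the test in the same package:

    import java.nio.ByteBuffer;

    final class WriterLifecycleSketch {

        /** Drives the setup -> write -> finish -> close sequence the tests assert on. */
        static void writeAndFinish(RemoteShuffleResultPartition partition, byte[] payload)
                throws Exception {
            partition.setup(); // binds the sort buffer pool and sets up the output gate
            partition.emitRecord(ByteBuffer.wrap(payload), 0); // unicast to subpartition 0
            partition.broadcastRecord(ByteBuffer.wrap(payload)); // copied to every subpartition
            partition.flushAll(); // forces buffered data out through the gate
            partition.finish(); // appends an EndOfPartitionEvent per subpartition
            partition.close(); // returns all buffers to their pools
        }
    }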
diff --git a/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/transfer/StreamProcessorTest.java b/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/transfer/StreamProcessorTest.java
new file mode 100644
index 00000000..1ec3a3df
--- /dev/null
+++ b/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/transfer/StreamProcessorTest.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.plugin.transfer;
+
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
+import org.apache.flink.streaming.api.operators.InputSelectable;
+import org.apache.flink.streaming.api.operators.InputSelection;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.io.DataInputStatus;
+import org.apache.flink.streaming.runtime.io.MultipleInputSelectionHandler;
+import org.apache.flink.streaming.runtime.io.PushingAsyncDataInput;
+import org.apache.flink.streaming.runtime.io.StreamInputProcessor;
+import org.apache.flink.streaming.runtime.io.StreamMultipleInputProcessor;
+import org.apache.flink.streaming.runtime.io.StreamOneInputProcessor;
+import org.apache.flink.streaming.runtime.io.StreamTaskInput;
+import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
+
+import org.apache.flink.shaded.curator4.com.google.common.collect.Lists;
+
+import org.junit.Test;
+
+import java.util.ArrayDeque;
+import java.util.Deque;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Asserts that Flink's {@link
+ * org.apache.flink.streaming.runtime.io.StreamMultipleInputProcessor} does not ingest from an
+ * input before the data is actually processed.
+ */
+public class StreamProcessorTest {
+
+    @Test
+    public void testMultipleInputProcessor() throws Exception {
+        boolean[] eofs = new boolean[2];
+
+        AtomicBoolean emitted0 = new AtomicBoolean(false);
+        Deque<DataInputStatus> statuses0 =
+                new ArrayDeque<>(
+                        Lists.newArrayList(
+                                DataInputStatus.MORE_AVAILABLE, DataInputStatus.END_OF_INPUT));
+        StreamOneInputProcessor<Object> processor0 =
+                getOneInputProcessor(0, statuses0, emitted0, eofs);
+
+        AtomicBoolean emitted1 = new AtomicBoolean(false);
+        Deque<DataInputStatus> statuses1 =
+                new ArrayDeque<>(Lists.newArrayList(DataInputStatus.END_OF_INPUT));
+        StreamOneInputProcessor<Object> processor1 =
+                getOneInputProcessor(1, statuses1, emitted1, eofs);
+
+        InputSelectable inputSelectable = getInputSelectableForTwoInputProcessor(eofs);
+        final StreamInputProcessor processor =
+                getMultipleInputProcessor(inputSelectable, processor0, processor1);
+
+        assertFalse(emitted0.get());
+        assertFalse(emitted1.get());
+        assertFalse(eofs[0]);
+
+        processor.processInput();
+        assertTrue(emitted0.get() && !emitted1.get());
+        assertFalse(eofs[0]);
+
+        processor.processInput();
+        assertTrue(emitted0.get() && !emitted1.get());
+        assertTrue(eofs[0]);
+
+        processor.processInput();
+        assertTrue(emitted0.get() && emitted1.get());
+        assertTrue(eofs[1]);
+    }
+
+    private StreamMultipleInputProcessor getMultipleInputProcessor(
+            InputSelectable inputSelectable, StreamOneInputProcessor<?>... inputProcessors) {
+        return new StreamMultipleInputProcessor(
+                new MultipleInputSelectionHandler(inputSelectable, inputProcessors.length),
+                inputProcessors);
+    }
+
+    private InputSelectable getInputSelectableForTwoInputProcessor(boolean[] eofs) {
+        return () -> {
+            if (!eofs[0]) {
+                return InputSelection.FIRST;
+            } else {
+                return InputSelection.SECOND;
+            }
+        };
+    }
+
+    private StreamOneInputProcessor<Object> getOneInputProcessor(
+            int inputIdx,
+            Deque<DataInputStatus> inputStatuses,
+            AtomicBoolean recordEmitted,
+            boolean[] endOfInputs) {
+
+        PushingAsyncDataInput.DataOutput<Object> output =
+                new PushingAsyncDataInput.DataOutput<Object>() {
+                    @Override
+                    public void emitRecord(StreamRecord<Object> streamRecord) {}
+
+                    @Override
+                    public void emitWatermark(Watermark watermark) {}
+
+                    @Override
+                    public void emitWatermarkStatus(WatermarkStatus watermarkStatus) {}
+
+                    @Override
+                    public void emitLatencyMarker(LatencyMarker latencyMarker) {}
+                };
+
+        StreamTaskInput<Object> taskInput =
+                new StreamTaskInput<Object>() {
+                    @Override
+                    public int getInputIndex() {
+                        return inputIdx;
+                    }
+
+                    @Override
+                    public CompletableFuture<Void> prepareSnapshot(
+                            ChannelStateWriter channelStateWriter, long l) {
+                        return null;
+                    }
+
+                    @Override
+                    public void close() {}
+
+                    @Override
+                    public DataInputStatus emitNext(DataOutput<Object> dataOutput) {
+                        recordEmitted.set(true);
+                        DataInputStatus inputStatus = inputStatuses.poll();
+                        if (inputStatus == DataInputStatus.END_OF_INPUT) {
+                            endOfInputs[inputIdx] = true;
+                        }
+                        return inputStatus;
+                    }
+
+                    @Override
+                    public CompletableFuture<?> getAvailableFuture() {
+                        return null;
+                    }
+                };
+        return new StreamOneInputProcessor<>(taskInput, output, ignored -> {});
+    }
+}
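Editor's note: the lambda returned by getInputSelectableForTwoInputProcessor(...) above is the crux of this test. Written out as a named class (a sketch against Flink's public InputSelectable/InputSelection API, not part of the patch), the selection policy reads as follows:

    import org.apache.flink.streaming.api.operators.InputSelectable;
    import org.apache.flink.streaming.api.operators.InputSelection;

    /** Drain input 0 until it reports END_OF_INPUT, then switch to input 1. */
    final class FirstThenSecondSelector implements InputSelectable {

        private final boolean[] endOfInputs; // shared with the faked task inputs, as in the test

        FirstThenSecondSelector(boolean[] endOfInputs) {
            this.endOfInputs = endOfInputs;
        }

        @Override
        public InputSelection nextSelection() {
            return endOfInputs[0] ? InputSelection.SECOND : InputSelection.FIRST;
        }
    }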
diff --git a/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/transfer/WritingIntegrationTest.java b/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/transfer/WritingIntegrationTest.java
new file mode 100644
index 00000000..cf89665d
--- /dev/null
+++ b/shuffle-plugin/src/test/java/com/alibaba/flink/shuffle/plugin/transfer/WritingIntegrationTest.java
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.plugin.transfer;
+
+import com.alibaba.flink.shuffle.common.config.Configuration;
+import com.alibaba.flink.shuffle.common.utils.CommonUtils;
+import com.alibaba.flink.shuffle.coordinator.manager.DefaultShuffleResource;
+import com.alibaba.flink.shuffle.coordinator.manager.ShuffleResource;
+import com.alibaba.flink.shuffle.coordinator.manager.ShuffleWorkerDescriptor;
+import com.alibaba.flink.shuffle.core.config.TransferOptions;
+import com.alibaba.flink.shuffle.core.ids.JobID;
+import com.alibaba.flink.shuffle.core.memory.Buffer;
+import com.alibaba.flink.shuffle.core.storage.DataPartition;
+import com.alibaba.flink.shuffle.core.storage.DataPartitionWritingView;
+import com.alibaba.flink.shuffle.core.storage.WritingViewContext;
+import com.alibaba.flink.shuffle.plugin.RemoteShuffleDescriptor;
+import com.alibaba.flink.shuffle.transfer.AbstractNettyTest;
+import com.alibaba.flink.shuffle.transfer.ConnectionManager;
+import com.alibaba.flink.shuffle.transfer.FakedDataPartitionWritingView;
+import com.alibaba.flink.shuffle.transfer.NettyConfig;
+import com.alibaba.flink.shuffle.transfer.NettyServer;
+import com.alibaba.flink.shuffle.transfer.TestTransferBufferPool;
+import com.alibaba.flink.shuffle.transfer.TransferBufferPool;
+import com.alibaba.flink.shuffle.transfer.utils.NoOpPartitionedDataStore;
+
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.net.InetAddress;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.CopyOnWriteArrayList;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Supplier;
+
+import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyBufferSize;
+import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyDataPartitionType;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/** Integration test for the writing process. */
+@RunWith(Parameterized.class)
+public class WritingIntegrationTest {
+
+    private static final Logger LOG = LoggerFactory.getLogger(WritingIntegrationTest.class);
+
+    private String localhost;
+
+    private TransferBufferPool serverBufferPool;
+
+    private NetworkBufferPool networkBufferPool;
+
+    private FakedPartitionDataStore dataStore;
+
+    private NettyConfig nettyConfig;
+
+    private NettyServer nettyServer;
+
+    private int numSubs;
+
+    private Random random;
+
+    private final int parallelism;
+
+    // Number of output gates sharing one ConnectionManager.
+    private final int groupSize;
+
+    // Number of buffers for data transmission per writing channel.
+    private final int nBuffersPerTask;
+
+    private final int dataScale;
+
+    public WritingIntegrationTest(
+            int parallelism, int groupSize, int nBuffersPerTask, int dataScale) {
+        this.parallelism = parallelism;
+        this.groupSize = groupSize;
+        this.nBuffersPerTask = nBuffersPerTask;
+        this.dataScale = dataScale;
+    }
+
+    @Parameterized.Parameters
+    public static Collection data() {
+        return Arrays.asList(
+                new Object[][] {
+                    {10, 1, 10, 200},
+                    {10, 5, 10, 200},
+                    {10, 10, 10, 200},
+                    {10, 10, 10, 0},
+                    {1200, 20, 10, 200},
+                    {500, 20, 40, 200},
+                });
+    }
+
+    @Before
+    public void setup() throws Exception {
+        nettyConfig = new NettyConfig(new Configuration());
+        int dataPort = AbstractNettyTest.getAvailablePort();
+        nettyConfig.getConfig().setInteger(TransferOptions.SERVER_DATA_PORT, dataPort);
+        if (localhost == null) {
+            localhost = InetAddress.getLocalHost().getHostAddress();
+        }
+
+        random = new Random();
+
+        // Set up the server side.
+        serverBufferPool = new TestTransferBufferPool(2000, 64);
+        dataStore = new FakedPartitionDataStore(() -> (Buffer) serverBufferPool.requestBuffer());
+        nettyServer = new NettyServer(dataStore, nettyConfig);
+        nettyServer.disableHeartbeat();
+        nettyServer.start();
+
+        // Set up the client side.
+        numSubs = 32;
+        networkBufferPool = new NetworkBufferPool(20000, 32);
+    }
+
+    @After
+    public void tearDown() {
+        assertEquals(2000, serverBufferPool.numBuffers());
+        assertEquals(20000, networkBufferPool.getNumberOfAvailableMemorySegments());
+        nettyServer.shutdown();
+        serverBufferPool.destroy();
+        networkBufferPool.destroy();
+    }
+
+    @Test(timeout = 300_000)
+    public void test() throws Exception {
+
+        List<BufferPool> transBufferPools = new ArrayList<>(parallelism);
+
+        for (int i = 0; i < parallelism; i++) {
+            transBufferPools.add(
+                    networkBufferPool.createBufferPool(nBuffersPerTask, nBuffersPerTask));
+        }
+        List<ConnectionManager> connMgrs = new ArrayList<>();
+        List<RemoteShuffleOutputGate> outputGates = new ArrayList<>();
+        for (int i = 0; i < parallelism; i += groupSize) {
+            Pair<ConnectionManager, List<RemoteShuffleOutputGate>> pair =
+                    createOutputGateGroup(transBufferPools.subList(i, i + groupSize));
+            connMgrs.add(pair.getLeft());
+            outputGates.addAll(pair.getRight());
+        }
+
+        List<Thread> threads = new ArrayList<>();
+        int totalNumBuffersToSend = 0;
+        AtomicReference<Throwable> cause = new AtomicReference<>(null);
+        for (int i = 0; i < parallelism; i++) {
+            RemoteShuffleOutputGate gate = outputGates.get(i);
+            gate.setup();
+            final int numBuffersToSend;
+            if (dataScale == 0) {
+                numBuffersToSend = 0;
+            } else {
+                numBuffersToSend = dataScale + random.nextInt(dataScale);
+            }
+            totalNumBuffersToSend += numBuffersToSend;
+            if (i % groupSize == 1) {
+                threads.add(new WritingThread(gate, numBuffersToSend, true, cause));
+            } else {
+                threads.add(new WritingThread(gate, numBuffersToSend, false, cause));
+            }
+        }
+        threads.forEach(Thread::start);
+        for (Thread t : threads) {
+            t.join();
+        }
+        assertNull(cause.get());
+
+        for (FakedDataPartitionWritingView writingView : dataStore.writingViews) {
+            checkUntil(
+                    () -> assertTrue(writingView.isFinished() || writingView.getError() != null));
+        }
+        assertEquals(totalNumBuffersToSend, dataStore.numReceivedBuffers.get());
+        for (ConnectionManager connMgr : connMgrs) {
+            assertEquals(0, connMgr.numPhysicalConnections());
+            connMgr.shutdown();
+        }
+        transBufferPools.forEach(BufferPool::lazyDestroy);
+    }
+
+    private class WritingThread extends Thread {
+
+        final RemoteShuffleOutputGate gate;
+
+        final int numBuffers;
+
+        final boolean throwsWhenWriting;
+
+        final AtomicReference<Throwable> cause;
+
+        WritingThread(
+                RemoteShuffleOutputGate gate,
+                int numBuffers,
+                boolean throwsWhenWriting,
+                AtomicReference<Throwable> cause) {
+            this.gate = gate;
+            this.numBuffers = numBuffers;
+            this.throwsWhenWriting = throwsWhenWriting;
+            this.cause = cause;
+        }
+
+        @Override
+        public void run() {
+            try {
+                int regionSize = 10;
+                for (int i = 0; i < numBuffers; i++) {
+                    MemorySegment mem = gate.getBufferPool().requestMemorySegmentBlocking();
+                    NetworkBuffer buffer = new NetworkBuffer(mem, gate.getBufferPool()::recycle);
+                    while (buffer.readableBytes() + 4 <= buffer.capacity()) {
+                        buffer.writeByte(random.nextInt());
+                    }
+                    if (i % regionSize == 0) {
+                        // No need to test broadcast, which has no effect on credit-based
+                        // transport.
+                        gate.regionStart(false);
+                    }
+                    gate.write(buffer, random.nextInt(numSubs));
+                    if (i % regionSize == regionSize - 1) {
+                        gate.regionFinish();
+                    }
+                }
+                if (numBuffers % regionSize != 0) {
+                    gate.regionFinish();
+                }
+                if (throwsWhenWriting) {
+                    throw new Exception("Manual exception.");
+                }
+                gate.finish();
+            } catch (Throwable t) {
+                // Writing failures are expected here: some threads throw deliberately to
+                // exercise the error path; close() below must still succeed.
+            } finally {
+                try {
+                    gate.close();
+                } catch (Throwable t) {
+                    cause.set(t);
+                }
+            }
+        }
+    }
+
+    private Pair<ConnectionManager, List<RemoteShuffleOutputGate>> createOutputGateGroup(
+            List<BufferPool> clientBufferPools) throws Exception {
+        ConnectionManager connMgr =
+                ConnectionManager.createWriteConnectionManager(nettyConfig, true);
+        connMgr.start();
+
+        List<RemoteShuffleOutputGate> gates = new ArrayList<>();
+        for (BufferPool bp : clientBufferPools) {
+            gates.add(createOutputGate(connMgr, bp));
+        }
+        return Pair.of(connMgr, gates);
+    }
+
+    private RemoteShuffleOutputGate createOutputGate(
+            ConnectionManager connManager, BufferPool bufferPool) throws Exception {
+        ResultPartitionID resultPartitionID = new ResultPartitionID();
+        JobID jobID = new JobID(CommonUtils.randomBytes(32));
+        ShuffleResource shuffleResource =
+                new DefaultShuffleResource(
+                        new ShuffleWorkerDescriptor[] {
+                            new ShuffleWorkerDescriptor(
+                                    null, localhost, nettyConfig.getServerPort())
+                        },
+                        DataPartition.DataPartitionType.MAP_PARTITION);
+        RemoteShuffleDescriptor shuffleDescriptor =
+                new RemoteShuffleDescriptor(resultPartitionID, jobID, shuffleResource);
+        return new RemoteShuffleOutputGate(
+                shuffleDescriptor,
+                numSubs,
+                emptyBufferSize(),
+                emptyDataPartitionType(),
+                () -> bufferPool,
+                connManager);
+    }
+
+    protected void checkUntil(Runnable runnable) throws InterruptedException {
+        Throwable lastThrowable = null;
+        for (int i = 0; i < 100; i++) {
+            try {
+                runnable.run();
+                return;
+            } catch (Throwable t) {
+                lastThrowable = t;
+                Thread.sleep(200);
+            }
+        }
+        LOG.info("Condition not met in time.", lastThrowable);
+        fail("Condition not met within the expected time.");
+    }
+
+    private static class FakedPartitionDataStore extends NoOpPartitionedDataStore {
+
+        private final Supplier<Buffer> bufferSupplier;
+        private final AtomicInteger numReceivedBuffers;
+        private final List<FakedDataPartitionWritingView> writingViews;
+
+        FakedPartitionDataStore(Supplier<Buffer> bufferSupplier) {
+            this.bufferSupplier = bufferSupplier;
+            this.numReceivedBuffers = new AtomicInteger(0);
+            this.writingViews = new CopyOnWriteArrayList<>();
+        }
+
+        @Override
+        public DataPartitionWritingView createDataPartitionWritingView(WritingViewContext context) {
+            List<Buffer> buffers = new ArrayList<>();
+            buffers.add(bufferSupplier.get());
+            FakedDataPartitionWritingView writingView =
+                    new FakedDataPartitionWritingView(
+                            context.getDataSetID(),
+                            context.getMapPartitionID(),
+                            ignore -> numReceivedBuffers.addAndGet(1),
+                            context.getDataRegionCreditListener(),
+                            context.getFailureListener(),
+                            buffers);
+            writingViews.add(writingView);
+            return writingView;
+        }
+    }
+}
diff --git a/shuffle-plugin/src/test/resources/log4j2-test.properties
b/shuffle-plugin/src/test/resources/log4j2-test.properties new file mode 100644 index 00000000..d7fcb327 --- /dev/null +++ b/shuffle-plugin/src/test/resources/log4j2-test.properties @@ -0,0 +1,26 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +# Set root logger level to OFF to not flood build logs +# set manually to INFO for debugging purposes +rootLogger.level=OFF +rootLogger.appenderRef.test.ref=TestLogger +appender.testlogger.name=TestLogger +appender.testlogger.type=CONSOLE +appender.testlogger.target=SYSTEM_ERR +appender.testlogger.layout.type=PatternLayout +appender.testlogger.layout.pattern=%-4r [%t] %-5p %c %x - %m%n diff --git a/shuffle-rpc/pom.xml b/shuffle-rpc/pom.xml new file mode 100644 index 00000000..a318fb03 --- /dev/null +++ b/shuffle-rpc/pom.xml @@ -0,0 +1,153 @@ + + + + + flink-shuffle-parent + com.alibaba.flink.shuffle + 1.0-SNAPSHOT + + 4.0.0 + + shuffle-rpc + + + 8 + 8 + + + + + com.alibaba.flink.shuffle + shuffle-common + ${project.version} + + + + com.alibaba.flink.shuffle + shuffle-core + ${project.version} + + + commons-cli + commons-cli + + + + + + org.apache.flink + flink-rpc-core + ${flink.version} + + + org.apache.flink + * + + + true + + + + org.apache.flink + flink-core + ${flink.version} + true + + + org.apache.flink + * + + + org.apache.commons + commons-lang3 + + + com.esotericsoftware.kryo + kryo + + + commons-collections + commons-collections + + + org.apache.commons + commons-compress + + + + + + org.apache.flink + flink-shaded-guava + 30.1.1-jre-14.0 + + + + + + + org.apache.maven.plugins + maven-shade-plugin + + + shade-remote-shuffle + package + + shade + + + false + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + flink-rpc-akka.jar + + + + + + + reference.conf + + + + + + org.apache.flink.shaded.** + + org.apache.flink + + com.alibaba.flink.shuffle.shaded.flink + + + + + + + + + + diff --git a/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/RemoteShuffleFencedRpcEndpoint.java b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/RemoteShuffleFencedRpcEndpoint.java new file mode 100644 index 00000000..1f0eade5 --- /dev/null +++ b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/RemoteShuffleFencedRpcEndpoint.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.rpc;
+
+import com.alibaba.flink.shuffle.rpc.executor.RpcMainThreadScheduledExecutor;
+import com.alibaba.flink.shuffle.rpc.executor.ScheduledExecutor;
+
+import org.apache.flink.runtime.rpc.FencedRpcEndpoint;
+
+import javax.annotation.Nullable;
+
+import java.io.Serializable;
+
+/** {@link FencedRpcEndpoint} for remote shuffle. */
+public abstract class RemoteShuffleFencedRpcEndpoint<F extends Serializable>
+        extends FencedRpcEndpoint<F> {
+
+    protected RemoteShuffleFencedRpcEndpoint(
+            RemoteShuffleRpcService rpcService, String endpointId, @Nullable F fencingToken) {
+        super(rpcService, endpointId, fencingToken);
+    }
+
+    @Override
+    public RemoteShuffleRpcService getRpcService() {
+        return (RemoteShuffleRpcService) super.getRpcService();
+    }
+
+    public ScheduledExecutor getRpcMainThreadScheduledExecutor() {
+        return new RpcMainThreadScheduledExecutor(getMainThreadExecutor());
+    }
+}
diff --git a/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/RemoteShuffleFencedRpcGateway.java b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/RemoteShuffleFencedRpcGateway.java
new file mode 100644
index 00000000..9903050a
--- /dev/null
+++ b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/RemoteShuffleFencedRpcGateway.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.rpc; + +import com.alibaba.flink.shuffle.rpc.executor.RpcMainThreadScheduledExecutor; +import com.alibaba.flink.shuffle.rpc.executor.ScheduledExecutor; + +import org.apache.flink.runtime.rpc.RpcEndpoint; + +/** {@link RpcEndpoint} for remote shuffle. */ +public abstract class RemoteShuffleRpcEndpoint extends RpcEndpoint { + + protected RemoteShuffleRpcEndpoint(RemoteShuffleRpcService rpcService, String endpointId) { + super(rpcService, endpointId); + } + + @Override + public RemoteShuffleRpcService getRpcService() { + return (RemoteShuffleRpcService) super.getRpcService(); + } + + public ScheduledExecutor getRpcMainThreadScheduledExecutor() { + return new RpcMainThreadScheduledExecutor(getMainThreadExecutor()); + } +} diff --git a/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/RemoteShuffleRpcGateway.java b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/RemoteShuffleRpcGateway.java new file mode 100644 index 00000000..2b92fae1 --- /dev/null +++ b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/RemoteShuffleRpcGateway.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.rpc; + +import org.apache.flink.runtime.rpc.RpcGateway; + +/** {@link RpcGateway} for remote shuffle. */ +public interface RemoteShuffleRpcGateway extends RpcGateway {} diff --git a/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/RemoteShuffleRpcService.java b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/RemoteShuffleRpcService.java new file mode 100644 index 00000000..82cc1f11 --- /dev/null +++ b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/RemoteShuffleRpcService.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.rpc;
+
+import org.apache.flink.runtime.rpc.RpcService;
+
+import java.io.Serializable;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executor;
+
+/** {@link RpcService} for remote shuffle service. */
+public interface RemoteShuffleRpcService extends RpcService {
+
+    Executor getExecutor();
+
+    <F extends Serializable, C extends RemoteShuffleFencedRpcGateway<F>>
+            CompletableFuture<C> connectTo(String address, F fencingToken, Class<C> clazz);
+
+    <C extends RemoteShuffleRpcGateway> CompletableFuture<C> connectTo(
+            String address, Class<C> clazz);
+}
diff --git a/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/RemoteShuffleRpcServiceImpl.java b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/RemoteShuffleRpcServiceImpl.java
new file mode 100644
index 00000000..794665bd
--- /dev/null
+++ b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/RemoteShuffleRpcServiceImpl.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.rpc;
+
+import com.alibaba.flink.shuffle.common.utils.CommonUtils;
+
+import org.apache.flink.runtime.rpc.FencedRpcGateway;
+import org.apache.flink.runtime.rpc.RpcEndpoint;
+import org.apache.flink.runtime.rpc.RpcGateway;
+import org.apache.flink.runtime.rpc.RpcServer;
+import org.apache.flink.runtime.rpc.RpcService;
+import org.apache.flink.util.concurrent.ScheduledExecutor;
+
+import java.io.Serializable;
+import java.util.concurrent.Callable;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executor;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * A {@link RemoteShuffleRpcService} implementation which simply delegates to another {@link
+ * RpcService} instance.
+ */
+public class RemoteShuffleRpcServiceImpl implements RemoteShuffleRpcService {
+
+    private final RpcService rpcService;
+
+    public RemoteShuffleRpcServiceImpl(RpcService rpcService) {
+        CommonUtils.checkArgument(rpcService != null, "Must be not null.");
+        this.rpcService = rpcService;
+    }
+
+    @Override
+    public String getAddress() {
+        return rpcService.getAddress();
+    }
+
+    @Override
+    public int getPort() {
+        return rpcService.getPort();
+    }
+
+    @Override
+    public <C extends RpcGateway> CompletableFuture<C> connect(String address, Class<C> clazz) {
+        return rpcService.connect(address, clazz);
+    }
+
+    @Override
+    public <F extends Serializable, C extends FencedRpcGateway<F>> CompletableFuture<C> connect(
+            String address, F fencingToken, Class<C> clazz) {
+        return rpcService.connect(address, fencingToken, clazz);
+    }
+
+    @Override
+    public <C extends RpcEndpoint & RpcGateway> RpcServer startServer(C rpcEndpoint) {
+        return rpcService.startServer(rpcEndpoint);
+    }
+
+    @Override
+    public <F extends Serializable> RpcServer fenceRpcServer(RpcServer rpcServer, F fencingToken) {
+        return rpcService.fenceRpcServer(rpcServer, fencingToken);
+    }
+
+    @Override
+    public void stopServer(RpcServer selfGateway) {
+        rpcService.stopServer(selfGateway);
+    }
+
+    @Override
+    public CompletableFuture<Void> stopService() {
+        return rpcService.stopService();
+    }
+
+    @Override
+    public CompletableFuture<Void> getTerminationFuture() {
+        return rpcService.getTerminationFuture();
+    }
+
+    @Override
+    public ScheduledExecutor getScheduledExecutor() {
+        return rpcService.getScheduledExecutor();
+    }
+
+    @Override
+    public ScheduledFuture<?> scheduleRunnable(Runnable runnable, long delay, TimeUnit unit) {
+        return rpcService.scheduleRunnable(runnable, delay, unit);
+    }
+
+    @Override
+    public void execute(Runnable runnable) {
+        rpcService.execute(runnable);
+    }
+
+    @Override
+    public <T> CompletableFuture<T> execute(Callable<T> callable) {
+        return rpcService.execute(callable);
+    }
+
+    @Override
+    public Executor getExecutor() {
+        return getScheduledExecutor();
+    }
+
+    @Override
+    public <F extends Serializable, C extends RemoteShuffleFencedRpcGateway<F>>
+            CompletableFuture<C> connectTo(String address, F fencingToken, Class<C> clazz) {
+        return connect(address, fencingToken, clazz);
+    }
+
+    @Override
+    public <C extends RemoteShuffleRpcGateway> CompletableFuture<C> connectTo(
+            String address, Class<C> clazz) {
+        return connect(address, clazz);
+    }
+}
diff --git a/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/RpcTargetAddress.java b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/RpcTargetAddress.java
new file mode 100644
index 00000000..df4a6fd6
--- /dev/null
+++ b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/RpcTargetAddress.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.rpc;
+
+import java.util.Objects;
+import java.util.UUID;
+
+/** The address of an RPC service with registration. */
+public class RpcTargetAddress {
+
+    /** The rpc address. */
+    private final String targetAddress;
+
+    /** The uuid for the target. */
+    private final UUID leaderUUID;
+
+    public RpcTargetAddress(String targetAddress, UUID leaderUUID) {
+        this.targetAddress = targetAddress;
+        this.leaderUUID = leaderUUID;
+    }
+
+    public String getTargetAddress() {
+        return targetAddress;
+    }
+
+    public UUID getLeaderUUID() {
+        return leaderUUID;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+
+        RpcTargetAddress that = (RpcTargetAddress) o;
+        return Objects.equals(targetAddress, that.targetAddress)
+                && Objects.equals(leaderUUID, that.leaderUUID);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(targetAddress, leaderUUID);
+    }
+
+    @Override
+    public String toString() {
+        return "RpcTargetAddress{"
+                + "targetAddress='"
+                + targetAddress
+                + "', leaderUUID="
+                + leaderUUID
+                + '}';
+    }
+}
diff --git a/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/executor/RpcMainThreadScheduledExecutor.java b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/executor/RpcMainThreadScheduledExecutor.java
new file mode 100644
index 00000000..a5ec3d63
--- /dev/null
+++ b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/executor/RpcMainThreadScheduledExecutor.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +public class RpcMainThreadScheduledExecutor implements ScheduledExecutor { + + private final org.apache.flink.util.concurrent.ScheduledExecutor scheduledExecutor; + + public RpcMainThreadScheduledExecutor( + org.apache.flink.util.concurrent.ScheduledExecutor scheduledExecutor) { + CommonUtils.checkArgument(scheduledExecutor != null, "Must be not null."); + this.scheduledExecutor = scheduledExecutor; + } + + @Override + public ScheduledFuture schedule(Runnable command, long delay, TimeUnit unit) { + return scheduledExecutor.schedule(command, delay, unit); + } + + @Override + public ScheduledFuture schedule(Callable callable, long delay, TimeUnit unit) { + return scheduledExecutor.schedule(callable, delay, unit); + } + + @Override + public ScheduledFuture scheduleAtFixedRate( + Runnable command, long initialDelay, long period, TimeUnit unit) { + return scheduledExecutor.scheduleAtFixedRate(command, initialDelay, period, unit); + } + + @Override + public ScheduledFuture scheduleWithFixedDelay( + Runnable command, long initialDelay, long delay, TimeUnit unit) { + return scheduledExecutor.scheduleWithFixedDelay(command, initialDelay, delay, unit); + } + + @Override + public void execute(@Nonnull Runnable command) { + scheduledExecutor.execute(command); + } +} diff --git a/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/executor/ScheduledExecutor.java b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/executor/ScheduledExecutor.java new file mode 100644 index 00000000..7a9cc572 --- /dev/null +++ b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/executor/ScheduledExecutor.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.rpc.executor; + +import java.util.concurrent.Callable; +import java.util.concurrent.Executor; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; + +/** + * Extension for the {@link Executor} interface which is enriched by method for scheduling tasks in + * the future. + * + *

This class is copied and modified from Apache Flink + * (org.apache.flink.util.concurrent.ScheduledExecutor). + */ +public interface ScheduledExecutor extends Executor { + + /** + * Executes the given command after the given delay. + * + * @param command the task to execute in the future + * @param delay the time from now to delay the execution + * @param unit the time unit of the delay parameter + * @return a ScheduledFuture representing the completion of the scheduled task + */ + ScheduledFuture schedule(Runnable command, long delay, TimeUnit unit); + + /** + * Executes the given callable after the given delay. The result of the callable is returned as + * a {@link ScheduledFuture}. + * + * @param callable the callable to execute + * @param delay the time from now to delay the execution + * @param unit the time unit of the delay parameter + * @param result type of the callable + * @return a ScheduledFuture which holds the future value of the given callable + */ + ScheduledFuture schedule(Callable callable, long delay, TimeUnit unit); + + /** + * Executes the given command periodically. The first execution is started after the {@code + * initialDelay}, the second execution is started after {@code initialDelay + period}, the third + * after {@code initialDelay + 2*period} and so on. The task is executed until either an + * execution fails, or the returned {@link ScheduledFuture} is cancelled. + * + * @param command the task to be executed periodically + * @param initialDelay the time from now until the first execution is triggered + * @param period the time after which the next execution is triggered + * @param unit the time unit of the delay and period parameter + * @return a ScheduledFuture representing the periodic task. This future never completes unless + * an execution of the given task fails or if the future is cancelled + */ + ScheduledFuture scheduleAtFixedRate( + Runnable command, long initialDelay, long period, TimeUnit unit); + + /** + * Executed the given command repeatedly with the given delay between the end of an execution + * and the start of the next execution. The task is executed repeatedly until either an + * exception occurs or if the returned {@link ScheduledFuture} is cancelled. + * + * @param command the task to execute repeatedly + * @param initialDelay the time from now until the first execution is triggered + * @param delay the time between the end of the current and the start of the next execution + * @param unit the time unit of the initial delay and the delay parameter + * @return a ScheduledFuture representing the repeatedly executed task. This future never + * completes unless the execution of the given task fails or if the future is cancelled + */ + ScheduledFuture scheduleWithFixedDelay( + Runnable command, long initialDelay, long delay, TimeUnit unit); +} diff --git a/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/executor/ScheduledExecutorServiceAdapter.java b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/executor/ScheduledExecutorServiceAdapter.java new file mode 100644 index 00000000..64551d5b --- /dev/null +++ b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/executor/ScheduledExecutorServiceAdapter.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.rpc.executor; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; + +import javax.annotation.Nonnull; + +import java.util.concurrent.Callable; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; + +/** + * Adapter class for a {@link ScheduledExecutorService} which shall be used as a {@link + * org.apache.flink.util.concurrent.ScheduledExecutor}. + */ +public class ScheduledExecutorServiceAdapter implements ScheduledExecutor { + + private final ScheduledExecutorService scheduledExecutorService; + + public ScheduledExecutorServiceAdapter(ScheduledExecutorService scheduledExecutorService) { + this.scheduledExecutorService = CommonUtils.checkNotNull(scheduledExecutorService); + } + + @Override + public ScheduledFuture schedule(Runnable command, long delay, TimeUnit unit) { + return scheduledExecutorService.schedule(command, delay, unit); + } + + @Override + public ScheduledFuture schedule(Callable callable, long delay, TimeUnit unit) { + return scheduledExecutorService.schedule(callable, delay, unit); + } + + @Override + public ScheduledFuture scheduleAtFixedRate( + Runnable command, long initialDelay, long period, TimeUnit unit) { + return scheduledExecutorService.scheduleAtFixedRate(command, initialDelay, period, unit); + } + + @Override + public ScheduledFuture scheduleWithFixedDelay( + Runnable command, long initialDelay, long delay, TimeUnit unit) { + return scheduledExecutorService.scheduleWithFixedDelay(command, initialDelay, delay, unit); + } + + @Override + public void execute(@Nonnull Runnable command) { + scheduledExecutorService.execute(command); + } +} diff --git a/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/message/Acknowledge.java b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/message/Acknowledge.java new file mode 100644 index 00000000..663c4055 --- /dev/null +++ b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/message/Acknowledge.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.rpc.message; + +import java.io.Serializable; + +/** A generic acknowledgement message. */ +public class Acknowledge implements Serializable { + + private static final long serialVersionUID = -6303026142298217270L; + + /** The singleton instance. */ + private static final Acknowledge INSTANCE = new Acknowledge(); + + /** + * Gets the singleton instance. + * + * @return The singleton instance. + */ + public static Acknowledge get() { + return INSTANCE; + } + + // ------------------------------------------------------------------------ + + /** Private constructor to prevent instantiation. */ + private Acknowledge() {} + + // ------------------------------------------------------------------------ + + @Override + public boolean equals(Object obj) { + return obj != null && obj.getClass() == Acknowledge.class; + } + + @Override + public int hashCode() { + return 41; + } + + @Override + public String toString() { + return getClass().getSimpleName(); + } + + /** + * Read resolve to preserve the singleton object property. (per best practices, this should have + * visibility 'protected') + */ + protected Object readResolve() throws java.io.ObjectStreamException { + return INSTANCE; + } +} diff --git a/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/test/TestingRpcService.java b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/test/TestingRpcService.java new file mode 100644 index 00000000..1b74d546 --- /dev/null +++ b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/test/TestingRpcService.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.rpc.test; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.utils.FutureUtils; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleFencedRpcGateway; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcGateway; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcService; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcServiceImpl; +import com.alibaba.flink.shuffle.rpc.utils.AkkaRpcServiceUtils; + +import java.io.Serializable; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executor; +import java.util.function.Function; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** + * An RPC Service implementation for testing. This RPC service acts as a replacement for the regular + * RPC service for cases where tests need to return prepared mock gateways instead of proper RPC + * gateways. + * + *

The TestingRpcService can be used for example in the following fashion, using Mockito + * for mocks and verification: + * + *

{@code
+ * TestingRpcService rpc = new TestingRpcService();
+ *
+ * ShuffleManagerGateway testGateway = mock(ShuffleManagerGateway.class);
+ * rpc.registerGateway("myAddress", testGateway);
+ *
+ * MyComponentToTest component = new MyComponentToTest();
+ * component.triggerSomethingThatCallsTheGateway();
+ *
+ * verify(testGateway, timeout(1000)).theTestMethod(any(UUID.class), anyString());
+ * }
+ */
+public class TestingRpcService extends RemoteShuffleRpcServiceImpl {
+
+    private static final Function<
+                    RemoteShuffleRpcGateway, CompletableFuture<RemoteShuffleRpcGateway>>
+            DEFAULT_RPC_GATEWAY_FUTURE_FUNCTION = CompletableFuture::completedFuture;
+
+    /** Map of pre-registered connections. */
+    private final ConcurrentHashMap<String, RemoteShuffleRpcGateway> registeredConnections;
+
+    private volatile Function<RemoteShuffleRpcGateway, CompletableFuture<RemoteShuffleRpcGateway>>
+            rpcGatewayFutureFunction = DEFAULT_RPC_GATEWAY_FUTURE_FUNCTION;
+
+    /** Creates a new {@code TestingRpcService} backed by a freshly started RPC service. */
+    public TestingRpcService() {
+        super(startRpcService());
+
+        this.registeredConnections = new ConcurrentHashMap<>();
+    }
+
+    private static RemoteShuffleRpcService startRpcService() {
+        try {
+            return AkkaRpcServiceUtils.createRemoteRpcService(
+                    new Configuration(), null, "0", null, Optional.empty());
+        } catch (Throwable throwable) {
+            throw new RuntimeException(throwable);
+        }
+    }
+
+    // ------------------------------------------------------------------------
+
+    @Override
+    public CompletableFuture<Void> stopService() {
+        final CompletableFuture<Void> terminationFuture = super.stopService();
+
+        terminationFuture.whenComplete(
+                (Void ignored, Throwable throwable) -> registeredConnections.clear());
+
+        return terminationFuture;
+    }
+
+    // ------------------------------------------------------------------------
+    // connections
+    // ------------------------------------------------------------------------
+
+    public void registerGateway(String address, RemoteShuffleRpcGateway gateway) {
+        checkNotNull(address);
+        checkNotNull(gateway);
+
+        if (registeredConnections.putIfAbsent(address, gateway) != null) {
+            throw new IllegalStateException("a gateway is already registered under " + address);
+        }
+    }
+
+    @SuppressWarnings("unchecked")
+    private <C extends RemoteShuffleRpcGateway> CompletableFuture<C> getRpcGatewayFuture(
+            C gateway) {
+        return (CompletableFuture<C>) rpcGatewayFutureFunction.apply(gateway);
+    }
+
+    public void clearGateways() {
+        registeredConnections.clear();
+    }
+
+    public void resetRpcGatewayFutureFunction() {
+        rpcGatewayFutureFunction = DEFAULT_RPC_GATEWAY_FUTURE_FUNCTION;
+    }
+
+    public void setRpcGatewayFutureFunction(
+            Function<RemoteShuffleRpcGateway, CompletableFuture<RemoteShuffleRpcGateway>>
+                    rpcGatewayFutureFunction) {
+        this.rpcGatewayFutureFunction = rpcGatewayFutureFunction;
+    }
+
+    @Override
+    public Executor getExecutor() {
+        return getScheduledExecutor();
+    }
+
+    @Override
+    public <F extends Serializable, C extends RemoteShuffleFencedRpcGateway<F>>
+            CompletableFuture<C> connectTo(String address, F fencingToken, Class<C> clazz) {
+        RemoteShuffleRpcGateway gateway = registeredConnections.get(address);
+
+        if (gateway != null) {
+            if (clazz.isAssignableFrom(gateway.getClass())) {
+                @SuppressWarnings("unchecked")
+                C typedGateway = (C) gateway;
+                return getRpcGatewayFuture(typedGateway);
+            } else {
+                return FutureUtils.completedExceptionally(
+                        new Exception(
+                                "Gateway registered under "
+                                        + address
+                                        + " is not of type "
+                                        + clazz));
+            }
+        } else {
+            return super.connectTo(address, fencingToken, clazz);
+        }
+    }
+
+    @Override
+    public <C extends RemoteShuffleRpcGateway> CompletableFuture<C> connectTo(
+            String address, Class<C> clazz) {
+        RemoteShuffleRpcGateway gateway = registeredConnections.get(address);
+
+        if (gateway != null) {
+            if (clazz.isAssignableFrom(gateway.getClass())) {
+                @SuppressWarnings("unchecked")
+                C typedGateway = (C) gateway;
+                return getRpcGatewayFuture(typedGateway);
+            } else {
+                return FutureUtils.completedExceptionally(
+                        new Exception(
+                                "Gateway registered under "
+                                        + address
+                                        + " is not of type "
+                                        + clazz));
+            }
+        } else {
+            return super.connectTo(address, clazz);
+        }
+    }
+}
diff --git
a/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/utils/AkkaRpcServiceUtils.java b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/utils/AkkaRpcServiceUtils.java new file mode 100644 index 00000000..41d90c31 --- /dev/null +++ b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/utils/AkkaRpcServiceUtils.java @@ -0,0 +1,432 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.rpc.utils; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.common.utils.FatalErrorExitUtils; +import com.alibaba.flink.shuffle.core.config.RpcOptions; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcService; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcServiceImpl; + +import org.apache.flink.configuration.AkkaOptions; +import org.apache.flink.configuration.ConfigurationUtils; +import org.apache.flink.configuration.CoreOptions; +import org.apache.flink.core.classloading.ComponentClassLoader; +import org.apache.flink.runtime.rpc.AddressResolution; +import org.apache.flink.runtime.rpc.RpcSystem; +import org.apache.flink.util.IOUtils; +import org.apache.flink.util.NetUtils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.io.InputStream; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.URL; +import java.net.UnknownHostException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Optional; +import java.util.ServiceLoader; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; + +/** + * These RPC utilities contain helper methods around RPC use, such as starting an RPC service, or + * constructing RPC addresses. + * + *
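+ * <p>RPC URLs built by {@link #getRpcUrl} have the form
+ * {@code akka.tcp://remote-shuffle@host:port/user/rpc/endpointName} (or {@code akka.ssl.tcp}
+ * when SSL/encryption is enabled).
+ *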
<p>
This class is partly copied from Apache Flink + * (org.apache.flink.runtime.rpc.akka.AkkaRpcServiceUtils). + */ +public class AkkaRpcServiceUtils { + + private static final AtomicReference rpcSystemRef = new AtomicReference<>(); + + private static final String SUPERVISOR_NAME = "rpc"; + + private static final String AKKA_TCP = "akka.tcp"; + + private static final String AKKA_SSL_TCP = "akka.ssl.tcp"; + + private static final String actorSystemName = "remote-shuffle"; + + private static final AtomicLong nextNameOffset = new AtomicLong(0L); + + public static RemoteShuffleRpcService createRemoteRpcService( + Configuration configuration, + @Nullable String externalAddress, + String externalPortRange, + @Nullable String bindAddress, + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") Optional bindPort) + throws Exception { + final AkkaRpcServiceBuilder akkaRpcServiceBuilder = + remoteServiceBuilder(configuration, externalAddress, externalPortRange); + if (bindAddress != null) { + akkaRpcServiceBuilder.withBindAddress(bindAddress); + } + bindPort.ifPresent(akkaRpcServiceBuilder::withBindPort); + return akkaRpcServiceBuilder.createAndStart(); + } + + public static AkkaRpcServiceBuilder remoteServiceBuilder( + Configuration configuration, + @Nullable String externalAddress, + String externalPortRange) { + return new AkkaRpcServiceBuilder(configuration, externalAddress, externalPortRange); + } + + /** + * @param hostname The hostname or address where the target RPC service is listening. + * @param port The port where the target RPC service is listening. + * @param endpointName The name of the RPC endpoint. + * @param akkaProtocol True, if security/encryption is enabled, false otherwise. + * @return The RPC URL of the specified RPC endpoint. + */ + public static String getRpcUrl( + String hostname, int port, String endpointName, AkkaProtocol akkaProtocol) { + + checkArgument(hostname != null, "Hostname is null."); + checkArgument(endpointName != null, "EndpointName is null."); + checkArgument(NetUtils.isValidClientPort(port), "Port must be in [1, 65535]"); + + String hostPort = NetUtils.unresolvedHostAndPortToNormalizedString(hostname, port); + return internalRpcUrl(endpointName, new RemoteAddressInformation(hostPort, akkaProtocol)); + } + + private static String internalRpcUrl( + String endpointName, RemoteAddressInformation remoteAddressInformation) { + String protocolPrefix = akkaProtocolToString(remoteAddressInformation.getAkkaProtocol()); + String hostPort = remoteAddressInformation.getHostnameAndPort(); + + // protocolPrefix://flink-remote-shuffle[@hostname:port]/user/rpc/endpointName + return String.format("%s://" + actorSystemName, protocolPrefix) + + "@" + + hostPort + + "/user/" + + SUPERVISOR_NAME + + "/" + + endpointName; + } + + private static String akkaProtocolToString(AkkaProtocol akkaProtocol) { + return akkaProtocol == AkkaProtocol.SSL_TCP ? AKKA_SSL_TCP : AKKA_TCP; + } + + public static InetSocketAddress getInetSocketAddressFromAkkaURL(String akkaURL) + throws Exception { + checkState(rpcSystemRef.get() != null, "Rpc system is not initialized."); + return rpcSystemRef.get().getInetSocketAddressFromRpcUrl(akkaURL); + } + + /** + * Creates a random name of the form prefix_X, where X is an increasing number. 
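+ * For example, successive calls with the prefix {@code "dispatcher"} return
+ * {@code dispatcher_0}, {@code dispatcher_1}, {@code dispatcher_2} and so on; the underlying
+ * counter is shared across all prefixes.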
+ * + * @param prefix Prefix string to prepend to the monotonically increasing name offset number + * @return A random name of the form prefix_X where X is an increasing number + */ + public static String createRandomName(String prefix) { + CommonUtils.checkArgument(prefix != null, "Prefix must not be null."); + + long nameOffset; + + // obtain the next name offset by incrementing it atomically + do { + nameOffset = nextNameOffset.get(); + } while (!nextNameOffset.compareAndSet(nameOffset, nameOffset + 1L)); + + return prefix + '_' + nameOffset; + } + + /** + * Creates a wildcard name symmetric to {@link #createRandomName(String)}. + * + * @param prefix prefix of the wildcard name + * @return wildcard name starting with the prefix + */ + public static String createWildcardName(String prefix) { + return prefix + "_*"; + } + + /** Whether to use TCP or encrypted TCP for Akka. */ + public enum AkkaProtocol { + TCP, + SSL_TCP + } + + private static final class RemoteAddressInformation { + private final String hostnameAndPort; + private final AkkaProtocol akkaProtocol; + + private RemoteAddressInformation(String hostnameAndPort, AkkaProtocol akkaProtocol) { + this.hostnameAndPort = hostnameAndPort; + this.akkaProtocol = akkaProtocol; + } + + private String getHostnameAndPort() { + return hostnameAndPort; + } + + private AkkaProtocol getAkkaProtocol() { + return akkaProtocol; + } + } + + /** Builder for {@link RemoteShuffleRpcService}. */ + public static class AkkaRpcServiceBuilder { + + private final org.apache.flink.configuration.Configuration flinkConf; + + @Nullable private final String externalAddress; + @Nullable private final String externalPortRange; + + private String actorSystemName = AkkaRpcServiceUtils.actorSystemName; + + private final RpcSystem.ForkJoinExecutorConfiguration forkJoinExecutorConfiguration; + + private String bindAddress = NetUtils.getWildcardIPAddress(); + @Nullable private Integer bindPort = null; + + /** Builder for creating a remote RPC service. */ + private AkkaRpcServiceBuilder( + Configuration configuration, + @Nullable String externalAddress, + @Nullable String externalPortRange) { + CommonUtils.checkArgument(configuration != null, "Must be not null."); + CommonUtils.checkArgument(externalPortRange != null, "Must be not null."); + + this.flinkConf = + org.apache.flink.configuration.Configuration.fromMap(configuration.toMap()); + // convert remote shuffle configuration to flink configuration + flinkConf.set( + AkkaOptions.ASK_TIMEOUT_DURATION, + configuration.getDuration(RpcOptions.RPC_TIMEOUT)); + flinkConf.set( + AkkaOptions.FRAMESIZE, configuration.getString(RpcOptions.AKKA_FRAME_SIZE)); + if (!FatalErrorExitUtils.isNeedStopProcess()) { + flinkConf.set(AkkaOptions.JVM_EXIT_ON_FATAL_ERROR, false); + } + + this.externalAddress = + externalAddress == null + ? 
InetAddress.getLoopbackAddress().getHostAddress() + : externalAddress; + this.externalPortRange = externalPortRange; + this.forkJoinExecutorConfiguration = getForkJoinExecutorConfiguration(flinkConf); + } + + public AkkaRpcServiceBuilder withActorSystemName(String actorSystemName) { + this.actorSystemName = CommonUtils.checkNotNull(actorSystemName); + return this; + } + + public AkkaRpcServiceBuilder withBindAddress(String bindAddress) { + this.bindAddress = CommonUtils.checkNotNull(bindAddress); + return this; + } + + public AkkaRpcServiceBuilder withBindPort(int bindPort) { + CommonUtils.checkArgument( + NetUtils.isValidHostPort(bindPort), "Invalid port number: " + bindPort); + this.bindPort = bindPort; + return this; + } + + public RemoteShuffleRpcService createAndStart() throws Exception { + if (rpcSystemRef.get() == null) { + loadRpcSystem(new Configuration()); + } + + RpcSystem.RpcServiceBuilder rpcServiceBuilder; + if (externalAddress == null) { + // create local actor system + rpcServiceBuilder = rpcSystemRef.get().localServiceBuilder(flinkConf); + } else { + // create remote actor system + rpcServiceBuilder = + rpcSystemRef + .get() + .remoteServiceBuilder( + flinkConf, externalAddress, externalPortRange); + } + + rpcServiceBuilder + .withComponentName(actorSystemName) + .withBindAddress(bindAddress) + .withExecutorConfiguration(forkJoinExecutorConfiguration); + if (bindPort != null) { + rpcServiceBuilder.withBindPort(bindPort); + } + return new RemoteShuffleRpcServiceImpl(rpcServiceBuilder.createAndStart()); + } + } + + public static RpcSystem.ForkJoinExecutorConfiguration getForkJoinExecutorConfiguration( + org.apache.flink.configuration.Configuration configuration) { + double parallelismFactor = + configuration.getDouble(AkkaOptions.FORK_JOIN_EXECUTOR_PARALLELISM_FACTOR); + int minParallelism = + configuration.getInteger(AkkaOptions.FORK_JOIN_EXECUTOR_PARALLELISM_MIN); + int maxParallelism = + configuration.getInteger(AkkaOptions.FORK_JOIN_EXECUTOR_PARALLELISM_MAX); + + return new RpcSystem.ForkJoinExecutorConfiguration( + parallelismFactor, minParallelism, maxParallelism); + } + + public static void loadRpcSystem(Configuration configuration) { + try { + if (rpcSystemRef.get() != null) { + return; + } + + org.apache.flink.configuration.Configuration flinkConf = + org.apache.flink.configuration.Configuration.fromMap(configuration.toMap()); + ClassLoader flinkClassLoader = RpcSystem.class.getClassLoader(); + + Path tmpDirectory = Paths.get(ConfigurationUtils.parseTempDirectories(flinkConf)[0]); + Files.createDirectories(tmpDirectory); + Path tempFile = + Files.createFile( + tmpDirectory.resolve("flink-rpc-akka_" + UUID.randomUUID() + ".jar")); + + boolean isShaded = RpcSystem.class.getName().startsWith("com.alibaba.flink.shuffle"); + String rpcJarName = isShaded ? "shaded-flink-rpc-akka.jar" : "flink-rpc-akka.jar"; + InputStream resourceStream = flinkClassLoader.getResourceAsStream(rpcJarName); + if (resourceStream == null) { + throw new RuntimeException("Akka RPC system could not be found."); + } + + IOUtils.copyBytes(resourceStream, Files.newOutputStream(tempFile)); + + ComponentClassLoader classLoader = + new ComponentClassLoader( + new URL[] {tempFile.toUri().toURL()}, + flinkClassLoader, + CoreOptions.parseParentFirstLoaderPatterns( + CoreOptions.PARENT_FIRST_LOGGING_PATTERNS, ""), + new String[] { + isShaded ? 
"org.apache.flink.shaded" : "org.apache.flink", + "com.alibaba.flink.shuffle" + }); + + RpcSystem newRpcSystem = + new CleanupOnCloseRpcSystem( + ServiceLoader.load(RpcSystem.class, classLoader).iterator().next(), + classLoader, + tempFile); + + if (!rpcSystemRef.compareAndSet(null, newRpcSystem)) { + newRpcSystem.close(); + } else { + Runtime.getRuntime() + .addShutdownHook(new Thread(AkkaRpcServiceUtils::closeRpcSystem)); + } + } catch (IOException e) { + throw new RuntimeException("Could not initialize RPC system.", e); + } + } + + public static void closeRpcSystem() { + RpcSystem rpcSystem = rpcSystemRef.get(); + if (rpcSystem != null && rpcSystemRef.compareAndSet(rpcSystem, null)) { + rpcSystem.close(); + } + } + + /** + * This is copied from Apache Flink (org.apache.flink.runtime.rpc.akka.CleanupOnCloseRpcSystem). + */ + private static final class CleanupOnCloseRpcSystem implements RpcSystem { + private static final Logger LOG = LoggerFactory.getLogger(CleanupOnCloseRpcSystem.class); + + private final RpcSystem rpcSystem; + private final ComponentClassLoader classLoader; + private final Path tempFile; + + public CleanupOnCloseRpcSystem( + RpcSystem rpcSystem, ComponentClassLoader classLoader, Path tempFile) { + this.rpcSystem = CommonUtils.checkNotNull(rpcSystem); + this.classLoader = CommonUtils.checkNotNull(classLoader); + this.tempFile = CommonUtils.checkNotNull(tempFile); + } + + @Override + public void close() { + rpcSystem.close(); + + try { + classLoader.close(); + } catch (Exception e) { + LOG.warn("Could not close RpcSystem classloader.", e); + } + try { + Files.delete(tempFile); + } catch (Exception e) { + LOG.warn("Could not delete temporary rpc system file {}.", tempFile, e); + } + } + + @Override + public RpcServiceBuilder localServiceBuilder( + org.apache.flink.configuration.Configuration config) { + return rpcSystem.localServiceBuilder(config); + } + + @Override + public RpcServiceBuilder remoteServiceBuilder( + org.apache.flink.configuration.Configuration configuration, + @Nullable String externalAddress, + String externalPortRange) { + return rpcSystem.remoteServiceBuilder( + configuration, externalAddress, externalPortRange); + } + + @Override + public String getRpcUrl( + String hostname, + int port, + String endpointName, + AddressResolution addressResolution, + org.apache.flink.configuration.Configuration config) + throws UnknownHostException { + return rpcSystem.getRpcUrl(hostname, port, endpointName, addressResolution, config); + } + + @Override + public InetSocketAddress getInetSocketAddressFromRpcUrl(String url) throws Exception { + return rpcSystem.getInetSocketAddressFromRpcUrl(url); + } + + @Override + public long getMaximumMessageSizeInBytes( + org.apache.flink.configuration.Configuration config) { + return rpcSystem.getMaximumMessageSizeInBytes(config); + } + } +} diff --git a/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/utils/RpcUtils.java b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/utils/RpcUtils.java new file mode 100644 index 00000000..dda8e5a4 --- /dev/null +++ b/shuffle-rpc/src/main/java/com/alibaba/flink/shuffle/rpc/utils/RpcUtils.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.rpc.utils; + +import com.alibaba.flink.shuffle.rpc.RemoteShuffleFencedRpcEndpoint; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcEndpoint; +import com.alibaba.flink.shuffle.rpc.RemoteShuffleRpcService; + +import org.apache.flink.runtime.rpc.RpcEndpoint; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +/** Util methods for remote shuffle rpc. */ +public class RpcUtils { + + /** + * Shuts the given {@link RpcEndpoint} down and awaits its termination. + * + * @param rpcEndpoint to terminate + * @param timeout for this operation + * @throws ExecutionException if a problem occurred + * @throws InterruptedException if the operation has been interrupted + * @throws TimeoutException if a timeout occurred + */ + public static void terminateRpcEndpoint(RemoteShuffleRpcEndpoint rpcEndpoint, long timeout) + throws ExecutionException, InterruptedException, TimeoutException { + rpcEndpoint.closeAsync().get(timeout, TimeUnit.MILLISECONDS); + } + + /** + * Shuts the given {@link RpcEndpoint} down and awaits its termination. + * + * @param rpcEndpoint to terminate + * @param timeout for this operation + * @throws ExecutionException if a problem occurred + * @throws InterruptedException if the operation has been interrupted + * @throws TimeoutException if a timeout occurred + */ + public static void terminateRpcEndpoint( + RemoteShuffleFencedRpcEndpoint rpcEndpoint, long timeout) + throws ExecutionException, InterruptedException, TimeoutException { + rpcEndpoint.closeAsync().get(timeout, TimeUnit.MILLISECONDS); + } + + /** + * Shuts the given rpc service down and waits for its termination. 
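+     *
+     * <p>Typical teardown at the end of a test, as a sketch (the timeout value is
+     * illustrative):
+     *
+     * <pre>{@code
+     * RpcUtils.terminateRpcService(rpcService, 10_000L);
+     * }</pre>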
+     *
+     * @param rpcService to shut down
+     * @param timeout for this operation
+     * @throws InterruptedException if the operation has been interrupted
+     * @throws ExecutionException if a problem occurred
+     * @throws TimeoutException if a timeout occurred
+     */
+    public static void terminateRpcService(RemoteShuffleRpcService rpcService, long timeout)
+            throws InterruptedException, ExecutionException, TimeoutException {
+        rpcService.stopService().get(timeout, TimeUnit.MILLISECONDS);
+    }
+}
diff --git a/shuffle-rpc/src/main/resources/flink-rpc-akka.jar b/shuffle-rpc/src/main/resources/flink-rpc-akka.jar
new file mode 100644
index 00000000..524f199d
Binary files /dev/null and b/shuffle-rpc/src/main/resources/flink-rpc-akka.jar differ
diff --git a/shuffle-rpc/src/main/resources/shaded-flink-rpc-akka.jar b/shuffle-rpc/src/main/resources/shaded-flink-rpc-akka.jar
new file mode 100644
index 00000000..edd69336
Binary files /dev/null and b/shuffle-rpc/src/main/resources/shaded-flink-rpc-akka.jar differ
diff --git a/shuffle-storage/pom.xml b/shuffle-storage/pom.xml
new file mode 100644
index 00000000..beb8875a
--- /dev/null
+++ b/shuffle-storage/pom.xml
@@ -0,0 +1,58 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+		 xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+		 xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+	<parent>
+		<groupId>com.alibaba.flink.shuffle</groupId>
+		<artifactId>flink-shuffle-parent</artifactId>
+		<version>1.0-SNAPSHOT</version>
+	</parent>
+	<modelVersion>4.0.0</modelVersion>
+
+	<artifactId>shuffle-storage</artifactId>
+
+	<dependencies>
+
+		<dependency>
+			<groupId>com.alibaba.flink.shuffle</groupId>
+			<artifactId>shuffle-common</artifactId>
+			<version>${project.version}</version>
+		</dependency>
+
+		<dependency>
+			<groupId>com.alibaba.flink.shuffle</groupId>
+			<artifactId>shuffle-core</artifactId>
+			<version>${project.version}</version>
+		</dependency>
+
+		<dependency>
+			<groupId>com.alibaba.flink.shuffle</groupId>
+			<artifactId>shuffle-metrics</artifactId>
+			<version>${project.version}</version>
+		</dependency>
+
+		<dependency>
+			<groupId>org.apache.flink</groupId>
+			<artifactId>flink-shaded-netty</artifactId>
+			<version>4.1.49.Final-${flink.shaded.version}</version>
+			<scope>provided</scope>
+		</dependency>
+
+	</dependencies>
+</project>
diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/StorageMetrics.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/StorageMetrics.java
new file mode 100644
index 00000000..8c257d28
--- /dev/null
+++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/StorageMetrics.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.storage;
+
+import com.alibaba.flink.shuffle.metrics.entry.MetricUtils;
+
+import com.alibaba.metrics.Counter;
+import com.alibaba.metrics.Gauge;
+
+import java.util.function.Supplier;
+
+/** Constants and util metrics of STORAGE metrics. */
+public class StorageMetrics {
+
+    // Group name
+    public static final String STORAGE = "remote-shuffle.storage";
+
+    // Available writing buffers in buffer pool.
+    public static final String NUM_AVAILABLE_WRITING_BUFFERS =
+            STORAGE + ".num_available_writing_buffers";
+
+    // Available reading buffers in buffer pool.
+    public static final String NUM_AVAILABLE_READING_BUFFERS =
+            STORAGE + ".num_available_reading_buffers";
+
+    // Number of data partitions stored.
+ public static final String NUM_DATA_PARTITIONS = STORAGE + ".num_data_partitions"; + + // Shuffle data size in bytes. + public static final String NUM_BYTES_DATA = STORAGE + ".data_size_bytes"; + + // Index data size in bytes. + public static final String NUM_BYTES_INDEX = STORAGE + ".index_size_bytes"; + + public static void registerGaugeForNumAvailableWritingBuffers(Supplier availableNum) { + MetricUtils.registerMetric( + STORAGE, + NUM_AVAILABLE_WRITING_BUFFERS, + new Gauge() { + @Override + public Integer getValue() { + return availableNum.get(); + } + + @Override + public long lastUpdateTime() { + return System.currentTimeMillis(); + } + }); + } + + public static void registerGaugeForNumAvailableReadingBuffers(Supplier availableNum) { + MetricUtils.registerMetric( + STORAGE, + NUM_AVAILABLE_READING_BUFFERS, + new Gauge() { + @Override + public Integer getValue() { + return availableNum.get(); + } + + @Override + public long lastUpdateTime() { + return System.currentTimeMillis(); + } + }); + } + + public static Counter registerCounterForNumDataPartitions() { + return MetricUtils.getCounter(STORAGE, NUM_DATA_PARTITIONS); + } + + public static void registerGaugeForNumBytesDataSize(Supplier dataSizeBytes) { + MetricUtils.registerMetric( + STORAGE, + NUM_BYTES_DATA, + new Gauge() { + @Override + public Long getValue() { + return dataSizeBytes.get(); + } + + @Override + public long lastUpdateTime() { + return System.currentTimeMillis(); + } + }); + } + + public static void registerGaugeForNumBytesIndexSize(Supplier indexSizeBytes) { + MetricUtils.registerMetric( + STORAGE, + NUM_BYTES_INDEX, + new Gauge() { + @Override + public Long getValue() { + return indexSizeBytes.get(); + } + + @Override + public long lastUpdateTime() { + return System.currentTimeMillis(); + } + }); + } +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/datastore/PartitionReadingViewImpl.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/datastore/PartitionReadingViewImpl.java new file mode 100644 index 00000000..b849dba7 --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/datastore/PartitionReadingViewImpl.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.datastore; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.storage.BufferWithBacklog; +import com.alibaba.flink.shuffle.core.storage.DataPartitionReader; +import com.alibaba.flink.shuffle.core.storage.DataPartitionReadingView; + +/** Implementation of {@link DataPartitionReadingView}. 
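+ *
+ * <p>A minimal consumption loop over the view, as a sketch (it assumes {@code nextBuffer()}
+ * may return null while no data is ready yet and that {@code process} is a caller-supplied
+ * callback):
+ *
+ * <pre>{@code
+ * while (!readingView.isFinished()) {
+ *     BufferWithBacklog buffer = readingView.nextBuffer();
+ *     if (buffer != null) {
+ *         process(buffer);
+ *     }
+ * }
+ * }</pre>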
*/ +public class PartitionReadingViewImpl implements DataPartitionReadingView { + + /** The corresponding {@link DataPartitionReader} to read data from. */ + private final DataPartitionReader reader; + + private boolean isError; + + public PartitionReadingViewImpl(DataPartitionReader reader) { + CommonUtils.checkArgument(reader != null, "Must be not null."); + this.reader = reader; + } + + @Override + public BufferWithBacklog nextBuffer() throws Exception { + checkNotInErrorState(); + + return reader.nextBuffer(); + } + + @Override + public void onError(Throwable throwable) { + checkNotInErrorState(); + + isError = true; + reader.onError(throwable); + } + + @Override + public boolean isFinished() { + checkNotInErrorState(); + + return reader.isFinished(); + } + + private void checkNotInErrorState() { + CommonUtils.checkState(!isError, "Reading view is in error state."); + } +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/datastore/PartitionWritingViewImpl.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/datastore/PartitionWritingViewImpl.java new file mode 100644 index 00000000..4aaba60f --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/datastore/PartitionWritingViewImpl.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.datastore; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.ids.ReducePartitionID; +import com.alibaba.flink.shuffle.core.listener.DataCommitListener; +import com.alibaba.flink.shuffle.core.memory.Buffer; +import com.alibaba.flink.shuffle.core.memory.BufferSupplier; +import com.alibaba.flink.shuffle.core.storage.DataPartitionWriter; +import com.alibaba.flink.shuffle.core.storage.DataPartitionWritingView; +import com.alibaba.flink.shuffle.core.utils.BufferUtils; + +import javax.annotation.Nullable; + +/** Implementation of {@link DataPartitionWritingView}. */ +public class PartitionWritingViewImpl implements DataPartitionWritingView { + + /** Target {@link DataPartitionWriter} to add data or event to. */ + private final DataPartitionWriter partitionWriter; + + /** Whether the {@link #onError} method has been called or not. */ + private boolean isError; + + /** Whether the {@link #finish} method has been called or not. */ + private boolean isInputFinished; + + /** + * Whether the {@link #regionStarted} method has been called and a new data region has started + * or not. 
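+ *
+ * <p>Buffers can only be added while this flag is set; {@link #regionFinished()} clears it,
+ * and {@link #finish(DataCommitListener)} requires the current region to be finished first.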
+ */ + private boolean isRegionStarted; + + public PartitionWritingViewImpl(DataPartitionWriter partitionWriter) { + CommonUtils.checkArgument(partitionWriter != null, "Must be not null."); + this.partitionWriter = partitionWriter; + } + + @Override + public void onBuffer(Buffer buffer, ReducePartitionID reducePartitionID) { + CommonUtils.checkArgument(buffer != null, "Must be not null."); + + try { + checkNotInErrorState(); + checkInputNotFinished(); + checkRegionStarted(); + CommonUtils.checkArgument(reducePartitionID != null, "Must be not null."); + } catch (Throwable throwable) { + BufferUtils.recycleBuffer(buffer); + throw throwable; + } + + partitionWriter.addBuffer(reducePartitionID, buffer); + } + + @Override + public void regionStarted(int dataRegionIndex, boolean isBroadcastRegion) { + CommonUtils.checkArgument(dataRegionIndex >= 0, "Must be non-negative."); + + checkNotInErrorState(); + checkInputNotFinished(); + checkRegionFinished(); + + isRegionStarted = true; + partitionWriter.startRegion(dataRegionIndex, isBroadcastRegion); + } + + @Override + public void regionFinished() { + checkNotInErrorState(); + checkInputNotFinished(); + checkRegionStarted(); + + isRegionStarted = false; + partitionWriter.finishRegion(); + } + + @Override + public void finish(DataCommitListener commitListener) { + CommonUtils.checkArgument(commitListener != null, "Must be not null."); + + checkNotInErrorState(); + checkInputNotFinished(); + checkRegionFinished(); + + isInputFinished = true; + partitionWriter.finishDataInput(commitListener); + } + + @Override + public void onError(@Nullable Throwable throwable) { + checkNotInErrorState(); + checkInputNotFinished(); + + isError = true; + partitionWriter.onError(throwable); + } + + @Override + public BufferSupplier getBufferSupplier() { + return partitionWriter; + } + + private void checkInputNotFinished() { + CommonUtils.checkState(!isInputFinished, "Writing view is already finished."); + } + + private void checkNotInErrorState() { + CommonUtils.checkState(!isError, "Writing view is in error state."); + } + + private void checkRegionStarted() { + CommonUtils.checkState(isRegionStarted, "Need to start a new data region first."); + } + + private void checkRegionFinished() { + CommonUtils.checkState(!isRegionStarted, "Need to finish the current data region first."); + } +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/datastore/PartitionedDataStoreImpl.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/datastore/PartitionedDataStoreImpl.java new file mode 100644 index 00000000..ba9886e1 --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/datastore/PartitionedDataStoreImpl.java @@ -0,0 +1,655 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.datastore; + +import com.alibaba.flink.shuffle.common.config.ConfigOption; +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.config.MemorySize; +import com.alibaba.flink.shuffle.common.exception.ConfigurationException; +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.config.MemoryOptions; +import com.alibaba.flink.shuffle.core.config.StorageOptions; +import com.alibaba.flink.shuffle.core.exception.DuplicatedPartitionException; +import com.alibaba.flink.shuffle.core.exception.PartitionNotFoundException; +import com.alibaba.flink.shuffle.core.executor.SimpleSingleThreadExecutorPool; +import com.alibaba.flink.shuffle.core.executor.SingleThreadExecutorPool; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.listener.PartitionStateListener; +import com.alibaba.flink.shuffle.core.memory.BufferDispatcher; +import com.alibaba.flink.shuffle.core.storage.DataPartition; +import com.alibaba.flink.shuffle.core.storage.DataPartitionFactory; +import com.alibaba.flink.shuffle.core.storage.DataPartitionMeta; +import com.alibaba.flink.shuffle.core.storage.DataPartitionReader; +import com.alibaba.flink.shuffle.core.storage.DataPartitionReadingView; +import com.alibaba.flink.shuffle.core.storage.DataPartitionWriter; +import com.alibaba.flink.shuffle.core.storage.DataPartitionWritingView; +import com.alibaba.flink.shuffle.core.storage.DataSet; +import com.alibaba.flink.shuffle.core.storage.PartitionedDataStore; +import com.alibaba.flink.shuffle.core.storage.ReadingViewContext; +import com.alibaba.flink.shuffle.core.storage.StorageMeta; +import com.alibaba.flink.shuffle.core.storage.WritingViewContext; +import com.alibaba.flink.shuffle.storage.StorageMetrics; +import com.alibaba.flink.shuffle.storage.utils.DataPartitionUtils; + +import com.alibaba.metrics.Counter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.ServiceLoader; +import java.util.Set; +import java.util.concurrent.CompletableFuture; + +/** Implementation of {@link PartitionedDataStore}. */ +public class PartitionedDataStoreImpl implements PartitionedDataStore { + + private static final Logger LOG = LoggerFactory.getLogger(PartitionedDataStoreImpl.class); + + /** Lock object used to protect unsafe structures. */ + private final Object lock = new Object(); + + /** Read-only configuration of the shuffle cluster. */ + private final Configuration configuration; + + /** + * Listener to be notified when the state (created, deleted) of {@link DataPartition} changes. + */ + private final PartitionStateListener partitionStateListener; + + /** All {@link DataSet}s in this {@link PartitionedDataStore} indexed by {@link JobID}. */ + @GuardedBy("lock") + private final Map> dataSetsByJob; + + /** All {@link DataSet}s in this {@link PartitionedDataStore} indexed by {@link DataSetID}. 
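+ * <p>Kept in sync with {@link #dataSetsByJob}; both maps are only accessed under {@link #lock}.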
*/ + @GuardedBy("lock") + private final Map dataSets; + + /** All available {@link DataPartitionFactory}s in the classpath. */ + private final Map partitionFactories = new HashMap<>(); + + /** Buffer pool from where to allocate buffers for data writing. */ + private final BufferDispatcher writingBufferDispatcher; + + /** Buffer pool from where to allocate buffers for data reading. */ + private final BufferDispatcher readingBufferDispatcher; + + /** + * Executor pool from where to allocate single thread executors for data partition event + * processing. + */ + private final HashMap executorPools = new HashMap<>(); + + /** Number of data partitions stored in the {@link PartitionedDataStore} currently. */ + private final Counter numDataPartitions = StorageMetrics.registerCounterForNumDataPartitions(); + + /** Whether this data store has been shut down or not. */ + @GuardedBy("lock") + private boolean isShutDown; + + public PartitionedDataStoreImpl( + Configuration configuration, PartitionStateListener partitionStateListener) { + CommonUtils.checkArgument(configuration != null, "Must be not null."); + CommonUtils.checkArgument(partitionStateListener != null, "Must be not null."); + + this.configuration = configuration; + this.partitionStateListener = partitionStateListener; + this.dataSets = new HashMap<>(); + this.dataSetsByJob = new HashMap<>(); + + this.writingBufferDispatcher = createWritingBufferManager(configuration); + this.readingBufferDispatcher = createReadingBufferManager(configuration); + + ServiceLoader serviceLoader = + ServiceLoader.load(DataPartitionFactory.class); + + synchronized (lock) { + for (DataPartitionFactory partitionFactory : serviceLoader) { + try { + partitionFactory.initialize(configuration); + partitionFactories.put(partitionFactory.getClass().getName(), partitionFactory); + } catch (Throwable throwable) { + String className = partitionFactory.getClass().getName(); + LOG.warn( + "Failed to initialize {} because '{}'.", + className, + throwable.getMessage()); + } + } + + CommonUtils.checkState( + !partitionFactories.isEmpty(), "No valid partition factory found."); + } + } + + private BufferDispatcher createWritingBufferManager(Configuration configuration) { + BufferDispatcher ret = + createBufferManager( + configuration, + MemoryOptions.MEMORY_SIZE_FOR_DATA_WRITING, + "WRITING BUFFER POOL"); + StorageMetrics.registerGaugeForNumAvailableWritingBuffers(ret::numAvailableBuffers); + return ret; + } + + private BufferDispatcher createReadingBufferManager(Configuration configuration) { + BufferDispatcher ret = + createBufferManager( + configuration, + MemoryOptions.MEMORY_SIZE_FOR_DATA_READING, + "READING BUFFER POOL"); + StorageMetrics.registerGaugeForNumAvailableReadingBuffers(ret::numAvailableBuffers); + return ret; + } + + private BufferDispatcher createBufferManager( + Configuration configuration, + ConfigOption memorySizeOption, + String bufferManagerName) { + CommonUtils.checkArgument(configuration != null, "Must be not null."); + CommonUtils.checkArgument(memorySizeOption != null, "Must be not null."); + + MemorySize bufferSize = + CommonUtils.checkNotNull( + configuration.getMemorySize(MemoryOptions.MEMORY_BUFFER_SIZE)); + if (bufferSize.getBytes() <= 0) { + throw new ConfigurationException( + String.format( + "Illegal buffer size configured by %s, must be positive.", + MemoryOptions.MEMORY_BUFFER_SIZE.key())); + } + + MemorySize memorySize = + CommonUtils.checkNotNull(configuration.getMemorySize(memorySizeOption)); + if (memorySize.getBytes() < 
MemoryOptions.MIN_VALID_MEMORY_SIZE.getBytes()) { + throw new ConfigurationException( + String.format( + "The configured value of %s must be larger than %s.", + memorySizeOption.key(), + MemoryOptions.MIN_VALID_MEMORY_SIZE.toHumanReadableString())); + } + + int numBuffers = CommonUtils.checkedDownCast(memorySize.getBytes() / bufferSize.getBytes()); + if (numBuffers <= 0) { + throw new ConfigurationException( + String.format( + "The configured value of %s must be no smaller than the configured " + + "value of %s.", + memorySizeOption.key(), MemoryOptions.MEMORY_BUFFER_SIZE.key())); + } + + return new BufferDispatcher( + bufferManagerName, numBuffers, CommonUtils.checkedDownCast(bufferSize.getBytes())); + } + + private SimpleSingleThreadExecutorPool createExecutorPool(StorageMeta storageMeta) { + CommonUtils.checkArgument(storageMeta != null, "Must be not null."); + + ConfigOption configOption; + switch (storageMeta.getStorageType()) { + case SSD: + configOption = StorageOptions.STORAGE_SSD_NUM_EXECUTOR_THREADS; + break; + case HDD: + configOption = StorageOptions.STORAGE_NUM_THREADS_PER_HDD; + break; + case MEMORY: + configOption = StorageOptions.STORAGE_MEMORY_NUM_EXECUTOR_THREADS; + break; + default: + throw new ShuffleException("Illegal storage type."); + } + + Integer numThreads = configuration.getInteger(configOption); + if (numThreads == null || numThreads <= 0) { + // if negative value is configured, configuration exception will be thrown + throw new ConfigurationException( + String.format( + "The configured value of %s must be positive.", configOption.key())); + } + + // the actual number of threads will be min[configured value, 4 * (number of processors)] + numThreads = Math.min(numThreads, 4 * Runtime.getRuntime().availableProcessors()); + return new SimpleSingleThreadExecutorPool(numThreads, "datastore-executor-thread"); + } + + @Override + public DataPartitionWritingView createDataPartitionWritingView(WritingViewContext context) + throws Exception { + DataPartition dataPartition = null; + boolean isNewDataPartition = false; + DataPartitionFactory factory = + partitionFactories.get(context.getPartitionFactoryClassName()); + if (factory == null) { + throw new ShuffleException( + "Can not find target partition factory in classpath or partition factory " + + "initialization failed."); + } + + synchronized (lock) { + CommonUtils.checkState(!isShutDown, "Data store has been shut down."); + + DataSet dataSet = dataSets.get(context.getDataSetID()); + if (dataSet != null && dataSet.containsDataPartition(context.getDataPartitionID())) { + dataPartition = dataSet.getDataPartition(context.getDataPartitionID()); + } + + if (dataPartition == null) { + isNewDataPartition = true; + dataPartition = + factory.createDataPartition( + this, + context.getJobID(), + context.getDataSetID(), + context.getDataPartitionID(), + context.getNumReducePartitions()); + try { + partitionStateListener.onPartitionCreated(dataPartition.getPartitionMeta()); + addDataPartition(dataPartition); + } catch (Throwable throwable) { + onPartitionAddFailure(dataPartition, throwable); + throw throwable; + } + } + } + + try { + DataPartitionWriter partitionWriter = + dataPartition.createPartitionWriter( + context.getMapPartitionID(), + context.getDataRegionCreditListener(), + context.getFailureListener()); + return new PartitionWritingViewImpl(partitionWriter); + } catch (Throwable throwable) { + if (isNewDataPartition) { + // the new data partition should not hold any resource + 
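+ // (without this rollback, a failed writer creation would leave behind an
+ // empty partition entry in the store that no writer could ever complete)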
removeDataPartition(dataPartition.getPartitionMeta()); + } + throw throwable; + } + } + + @Override + public DataPartitionReadingView createDataPartitionReadingView(ReadingViewContext context) + throws Exception { + DataPartition dataPartition = + getDataPartition(context.getDataSetID(), context.getPartitionID()); + if (dataPartition == null) { + // throw partition not found exception if it could not find the target data partition + throw new PartitionNotFoundException( + context.getDataSetID(), + context.getPartitionID(), + "can not be found in data store, possibly released"); + } + + if (!dataPartition.isConsumable()) { + PartitionNotFoundException exception = + new PartitionNotFoundException( + context.getDataSetID(), + context.getPartitionID(), + "released or not consumable"); + CommonUtils.runQuietly(() -> releaseDataPartition(dataPartition, exception, true)); + throw exception; + } + + DataPartitionReader partitionReader = + dataPartition.createPartitionReader( + context.getStartPartitionIndex(), + context.getEndPartitionIndex(), + context.getDataListener(), + context.getBacklogListener(), + context.getFailureListener()); + return new PartitionReadingViewImpl(partitionReader); + } + + @Override + public boolean isDataPartitionConsumable(DataPartitionMeta partitionMeta) { + CommonUtils.checkArgument(partitionMeta != null, "Must be not null."); + + synchronized (lock) { + DataSet dataSet = dataSets.get(partitionMeta.getDataSetID()); + if (dataSet == null) { + return false; + } + + DataPartition dataPartition = + dataSet.getDataPartition(partitionMeta.getDataPartitionID()); + if (dataPartition == null) { + return false; + } + + return dataPartition.isConsumable(); + } + } + + /** + * Failure handler in case of adding {@link DataPartition}s to this data store fails. It should + * never happen by design, we add this to catch potential bugs. 
+ */ + private void onPartitionAddFailure(DataPartition dataPartition, Throwable throwable) { + CommonUtils.checkArgument(dataPartition != null, "Must be not null"); + + boolean removePartition = !(throwable instanceof DuplicatedPartitionException); + CommonUtils.runQuietly( + () -> releaseDataPartition(dataPartition, throwable, removePartition)); + + DataPartitionMeta partitionMeta = dataPartition.getPartitionMeta(); + LOG.error("Fatal: failed to add data partition: {}.", partitionMeta, throwable); + } + + @Override + public void addDataPartition(DataPartitionMeta partitionMeta) throws Exception { + DataPartitionFactory factory = + CommonUtils.checkNotNull( + partitionFactories.get(partitionMeta.getPartitionFactoryClassName())); + + final DataPartition dataPartition; + try { + // DataPartitionFactory#createDataPartition method must release + // all data partition resources itself if any exception occurs + dataPartition = factory.createDataPartition(this, partitionMeta); + } catch (Throwable throwable) { + CommonUtils.runQuietly(() -> partitionStateListener.onPartitionRemoved(partitionMeta)); + LOG.error("Failed to reconstruct data partition from meta.", throwable); + throw throwable; + } + + try { + addDataPartition(dataPartition); + } catch (Throwable throwable) { + onPartitionAddFailure(dataPartition, throwable); + throw throwable; + } + } + + @Override + public void removeDataPartition(DataPartitionMeta partitionMeta) { + CommonUtils.checkArgument(partitionMeta != null, "Must be not null"); + + synchronized (lock) { + DataSetID dataSetID = partitionMeta.getDataSetID(); + DataSet dataSet = dataSets.get(dataSetID); + + DataPartition dataPartition = null; + if (dataSet != null) { + dataPartition = dataSet.removeDataPartition(partitionMeta.getDataPartitionID()); + } + + if (dataPartition != null) { + numDataPartitions.dec(); + } + + Set dataSetIDS = null; + if (dataSet != null && dataSet.getNumDataPartitions() == 0) { + dataSetIDS = dataSetsByJob.get(dataSet.getJobID()); + dataSetIDS.remove(dataSetID); + dataSets.remove(dataSetID); + } + + if (dataSetIDS != null && dataSetIDS.isEmpty()) { + dataSetsByJob.remove(dataSet.getJobID()); + } + } + CommonUtils.runQuietly(() -> partitionStateListener.onPartitionRemoved(partitionMeta)); + } + + /** + * Releases the target {@link DataPartition} and notifies the corresponding {@link + * PartitionStateListener} if the {@link DataPartition} is released successfully. 
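+ *
+ * <p>The release completes asynchronously: the partition is only removed from the store (or
+ * the state listener notified) after the internal release future has completed successfully.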
+ */ + private void releaseDataPartition( + DataPartition dataPartition, Throwable releaseCause, boolean removePartition) { + if (dataPartition == null) { + return; + } + + DataPartitionMeta partitionMeta = dataPartition.getPartitionMeta(); + CompletableFuture future = + DataPartitionUtils.releaseDataPartition(dataPartition, releaseCause); + future.whenComplete( + (ignored, throwable) -> { + if (throwable != null) { + LOG.error( + "Failed to release data partition: {}.", partitionMeta, throwable); + return; + } + + if (removePartition) { + removeDataPartition(dataPartition.getPartitionMeta()); + } else { + CommonUtils.runQuietly( + () -> partitionStateListener.onPartitionRemoved(partitionMeta)); + } + LOG.info("Successfully released data partition: {}.", partitionMeta); + }); + } + + private void addDataPartition(DataPartition dataPartition) { + CommonUtils.checkArgument(dataPartition != null, "Must be not null."); + + DataPartitionMeta partitionMeta = dataPartition.getPartitionMeta(); + synchronized (lock) { + CommonUtils.checkState(!isShutDown, "Data store has been shut down."); + + DataSet dataSet = + dataSets.computeIfAbsent( + partitionMeta.getDataSetID(), + (dataSetID) -> new DataSet(partitionMeta.getJobID(), dataSetID)); + dataSet.addDataPartition(dataPartition); + + Set dataSetIDS = + dataSetsByJob.computeIfAbsent( + partitionMeta.getJobID(), (ignored) -> new HashSet<>()); + dataSetIDS.add(partitionMeta.getDataSetID()); + numDataPartitions.inc(); + } + } + + private DataPartition getDataPartition(DataSetID dataSetID, DataPartitionID partitionID) { + CommonUtils.checkArgument(dataSetID != null, "Must be not null."); + CommonUtils.checkArgument(partitionID != null, "Must be not null."); + + synchronized (lock) { + CommonUtils.checkState(!isShutDown, "Data store has been shut down."); + + DataPartition dataPartition = null; + DataSet dataSet = dataSets.get(dataSetID); + if (dataSet != null) { + dataPartition = dataSet.getDataPartition(partitionID); + } + + return dataPartition; + } + } + + @Override + public void releaseDataPartition( + DataSetID dataSetID, DataPartitionID partitionID, @Nullable Throwable throwable) { + CommonUtils.checkArgument(dataSetID != null, "Must be not null."); + CommonUtils.checkArgument(partitionID != null, "Must be not null."); + + DataPartition dataPartition = getDataPartition(dataSetID, partitionID); + CommonUtils.runQuietly(() -> releaseDataPartition(dataPartition, throwable, true)); + } + + @Override + public void releaseDataSet(DataSetID dataSetID, @Nullable Throwable throwable) { + CommonUtils.checkArgument(dataSetID != null, "Must be not null."); + + List dataPartitions = new ArrayList<>(); + synchronized (lock) { + DataSet dataSet = dataSets.get(dataSetID); + if (dataSet != null) { + dataPartitions.addAll(dataSet.getDataPartitions()); + } + } + + for (DataPartition dataPartition : dataPartitions) { + CommonUtils.runQuietly(() -> releaseDataPartition(dataPartition, throwable, true)); + } + } + + @Override + public void releaseDataByJobID(JobID jobID, @Nullable Throwable throwable) { + CommonUtils.checkArgument(jobID != null, "Must be not null."); + + List dataPartitions = new ArrayList<>(); + synchronized (lock) { + Set dataSetIDS = dataSetsByJob.get(jobID); + + if (dataSetIDS != null) { + for (DataSetID dataSetID : dataSetIDS) { + dataPartitions.addAll(dataSets.get(dataSetID).getDataPartitions()); + } + } + } + + for (DataPartition dataPartition : dataPartitions) { + CommonUtils.runQuietly(() -> releaseDataPartition(dataPartition, throwable, true)); 
+ } + } + + @Override + public void shutDown(boolean releaseData) { + LOG.info(String.format("Shutting down the data store: releaseData=%s.", releaseData)); + + List dataPartitions = new ArrayList<>(); + synchronized (lock) { + if (isShutDown) { + return; + } + isShutDown = true; + + if (releaseData) { + List dataSetList = new ArrayList<>(dataSets.values()); + dataSets.clear(); + dataSetsByJob.clear(); + + for (DataSet dataSet : dataSetList) { + dataPartitions.addAll(dataSet.clearDataPartitions()); + } + } + } + + DataPartitionUtils.releaseDataPartitions( + dataPartitions, new ShuffleException("Shutting down."), partitionStateListener); + + destroyBufferManager(writingBufferDispatcher); + destroyBufferManager(readingBufferDispatcher); + + destroyExecutorPools(); + } + + @Override + public boolean isShutDown() { + synchronized (lock) { + return isShutDown; + } + } + + /** + * Destroys the target {@link BufferDispatcher} and logs the error if encountering any + * exception. + */ + private void destroyBufferManager(BufferDispatcher bufferDispatcher) { + try { + CommonUtils.checkArgument(bufferDispatcher != null, "Must be not null."); + + bufferDispatcher.destroy(); + } catch (Throwable throwable) { + LOG.error("Fatal: failed to destroy buffer manager.", throwable); + } + } + + /** + * Destroys all the {@link SingleThreadExecutorPool}s and logs the error if encountering any + * exception. + */ + private void destroyExecutorPools() { + synchronized (lock) { + for (SingleThreadExecutorPool executorPool : executorPools.values()) { + try { + executorPool.destroy(); + } catch (Throwable throwable) { + LOG.error("Fatal: failed to destroy executor pool.", throwable); + } + } + executorPools.clear(); + } + } + + @Override + public Configuration getConfiguration() { + return configuration; + } + + @Override + public BufferDispatcher getWritingBufferDispatcher() { + return writingBufferDispatcher; + } + + @Override + public BufferDispatcher getReadingBufferDispatcher() { + return readingBufferDispatcher; + } + + @Override + public SingleThreadExecutorPool getExecutorPool(StorageMeta storageMeta) { + synchronized (lock) { + if (isShutDown) { + throw new ShuffleException("Data store has been already shutdown."); + } + + return executorPools.computeIfAbsent(storageMeta, this::createExecutorPool); + } + } + + // --------------------------------------------------------------------------------------------- + // For test + // --------------------------------------------------------------------------------------------- + + Map>> getStoredData() { + Map>> dataSetsByJobMap = new HashMap<>(); + synchronized (lock) { + for (Map.Entry> entry : dataSetsByJob.entrySet()) { + JobID jobID = entry.getKey(); + Map> dataSetPartitions = new HashMap<>(); + dataSetsByJobMap.put(jobID, dataSetPartitions); + + for (Map.Entry dataSetEntry : dataSets.entrySet()) { + if (dataSetEntry.getValue().getJobID().equals(jobID)) { + dataSetPartitions.put( + dataSetEntry.getKey(), + dataSetEntry.getValue().getDataPartitionIDs()); + } + } + } + } + return dataSetsByJobMap; + } +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/exception/ConcurrentWriteException.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/exception/ConcurrentWriteException.java new file mode 100644 index 00000000..9d478723 --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/exception/ConcurrentWriteException.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.exception; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.core.storage.MapPartition; + +/** + * Exception to be thrown if more than one data partition writers try to write data to the same + * {@link MapPartition}. + */ +public class ConcurrentWriteException extends ShuffleException { + + private static final long serialVersionUID = 218906970670946303L; + + public ConcurrentWriteException(String message) { + super(message); + } +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/exception/FileCorruptedException.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/exception/FileCorruptedException.java new file mode 100644 index 00000000..e14973bb --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/exception/FileCorruptedException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.exception; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; + +/** Exception to be thrown if the partition file has corrupted. */ +public class FileCorruptedException extends ShuffleException { + + private static final long serialVersionUID = 7295806358577699892L; + + public FileCorruptedException() { + super("Target file has corrupted."); + } +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/BaseDataPartition.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/BaseDataPartition.java new file mode 100644 index 00000000..9696a08f --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/BaseDataPartition.java @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; +import com.alibaba.flink.shuffle.core.executor.SingleThreadExecutor; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.listener.BufferListener; +import com.alibaba.flink.shuffle.core.memory.BufferDispatcher; +import com.alibaba.flink.shuffle.core.storage.DataPartition; +import com.alibaba.flink.shuffle.core.storage.DataPartitionMeta; +import com.alibaba.flink.shuffle.core.storage.DataPartitionReader; +import com.alibaba.flink.shuffle.core.storage.DataPartitionWriter; +import com.alibaba.flink.shuffle.core.storage.PartitionedDataStore; +import com.alibaba.flink.shuffle.core.utils.BufferUtils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.nio.ByteBuffer; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; + +/** + * {@link BaseDataPartition} implements some basic logics of {@link DataPartition} which can be + * reused by subclasses and simplify the implementation of new data partitions. It adopts a single + * thread execution mode and all processing logics of a {@link DataPartition} will run in the main + * executor thread. As a result, no lock is needed and the data partition processing logics can be + * simplified. + */ +public abstract class BaseDataPartition implements DataPartition { + + private static final Logger LOG = LoggerFactory.getLogger(BaseDataPartition.class); + + /** + * A single thread executor to process data partition events including data writing and reading. + */ + private final SingleThreadExecutor mainExecutor; + + /** All pending {@link PartitionProcessingTask}s to be processed by the {@link #processor}. */ + private final LinkedList taskQueue = new LinkedList<>(); + + /** A {@link Runnable} task responsible for processing all {@link PartitionProcessingTask}s. */ + private final DataPartitionProcessor processor = new DataPartitionProcessor(); + + /** {@link PartitionedDataStore} storing this data partition. */ + protected final PartitionedDataStore dataStore; + + /** All {@link DataPartitionReader} reading this data partition. */ + protected final Set readers = new HashSet<>(); + + /** All {@link DataPartitionWriter} writing this data partition. */ + protected final Map writers = new HashMap<>(); + + /** Whether this data partition is released or not. 
*/
+    protected boolean isReleased;
+
+    /** The reason why this data partition is released. */
+    protected Throwable releaseCause;
+
+    public BaseDataPartition(PartitionedDataStore dataStore, SingleThreadExecutor mainExecutor) {
+        CommonUtils.checkArgument(dataStore != null, "Must be not null.");
+        CommonUtils.checkArgument(mainExecutor != null, "Must be not null.");
+
+        this.dataStore = dataStore;
+        this.mainExecutor = mainExecutor;
+    }
+
+    /** Returns true if the program is running in the main executor thread. */
+    protected boolean inExecutorThread() {
+        return mainExecutor.inExecutorThread();
+    }
+
+    /**
+     * Adds the target {@link PartitionProcessingTask} to the task queue. High-priority tasks will
+     * be inserted at the head of the task queue.
+     */
+    protected void addPartitionProcessingTask(
+            PartitionProcessingTask task, boolean isPrioritizedTask) {
+        try {
+            final boolean triggerProcessing;
+            synchronized (taskQueue) {
+                triggerProcessing = taskQueue.isEmpty();
+                if (isPrioritizedTask) {
+                    taskQueue.addFirst(task);
+                } else {
+                    taskQueue.addLast(task);
+                }
+            }
+
+            if (triggerProcessing) {
+                mainExecutor.execute(processor);
+            }
+        } catch (Throwable throwable) {
+            // an exception may happen if too many tasks have been submitted, which should be rare
+            // and whose consequence is not defined, so just trigger a fatal error and exit
+            LOG.error("Fatal: failed to add new partition processing task.", throwable);
+            if (!dataStore.isShutDown()) {
+                CommonUtils.exitOnFatalError(throwable);
+            }
+        }
+    }
+
+    /** Adds the target low-priority {@link PartitionProcessingTask} to the task queue. */
+    protected void addPartitionProcessingTask(PartitionProcessingTask task) {
+        addPartitionProcessingTask(task, false);
+    }
+
+    @Override
+    public CompletableFuture<Void> releasePartition(@Nullable Throwable releaseCause) {
+        CompletableFuture<Void> future = new CompletableFuture<>();
+        addPartitionProcessingTask(
+                () -> {
+                    try {
+                        releaseInternal(
+                                releaseCause != null
+                                        ? releaseCause
+                                        : new ShuffleException("Data partition released."));
+                        future.complete(null);
+                    } catch (Throwable throwable) {
+                        future.completeExceptionally(throwable);
+                    }
+                },
+                true);
+        return future;
+    }
+
+    protected void releaseInternal(Throwable releaseCause) throws Exception {
+        CommonUtils.checkArgument(releaseCause != null, "Must be not null.");
+        CommonUtils.checkState(inExecutorThread(), "Not in executor thread.");
+
+        LOG.info("Releasing data partition: {}.", getPartitionMeta());
+        isReleased = true;
+        if (this.releaseCause == null) {
+            this.releaseCause =
+                    new ShuffleException(
+                            String.format("Data partition %s released.", getPartitionMeta()),
+                            releaseCause);
+        }
+
+        Throwable exception = null;
+        for (DataPartitionWriter writer : writers.values()) {
+            try {
+                writer.release(this.releaseCause);
+            } catch (Throwable throwable) {
+                exception = exception != null ? exception : throwable;
+                LOG.error("Fatal: failed to release partition writer.", throwable);
+            }
+        }
+        writers.clear();
+
+        for (DataPartitionReader reader : readers) {
+            try {
+                reader.release(this.releaseCause);
+            } catch (Throwable throwable) {
+                exception = exception != null ? exception : throwable;
+                LOG.error("Fatal: failed to release partition reader.", throwable);
+            }
+        }
+        readers.clear();
+
+        if (exception != null) {
+            ExceptionUtils.rethrowException(exception);
+        }
+    }
+
+    /** Utility method to allocate buffers from the {@link BufferDispatcher}.
*/ + protected void allocateBuffers( + BufferDispatcher bufferDispatcher, + BufferListener bufferListener, + int minBuffers, + int maxBuffers) { + DataPartitionMeta partitionMeta = getPartitionMeta(); + JobID jobID = partitionMeta.getJobID(); + DataSetID dataSetID = partitionMeta.getDataSetID(); + DataPartitionID partitionID = partitionMeta.getDataPartitionID(); + + bufferDispatcher.requestBuffer( + jobID, dataSetID, partitionID, minBuffers, maxBuffers, bufferListener); + } + + /** Utility method to release this data partition when internal error occurs. */ + protected void releaseOnInternalError(Throwable throwable) throws Exception { + releaseInternal(throwable); + + dataStore.removeDataPartition(getPartitionMeta()); + } + + /** Utility method to recycle buffers to target {@link BufferDispatcher}. */ + protected void recycleBuffers( + Collection buffers, BufferDispatcher bufferDispatcher) { + DataPartitionMeta partitionMeta = getPartitionMeta(); + BufferUtils.recycleBuffers( + buffers, + bufferDispatcher, + partitionMeta.getJobID(), + partitionMeta.getDataSetID(), + partitionMeta.getDataPartitionID()); + } + + /** + * Returns the {@link DataPartitionWritingTask} of this data partition. Different {@link + * DataPartition} implementations can implement different {@link DataPartitionWritingTask}. + */ + protected abstract DataPartitionWritingTask getPartitionWritingTask(); + + /** + * Returns the {@link DataPartitionReadingTask} of this data partition. Different {@link + * DataPartition} implementations can implement different {@link DataPartitionReadingTask}. + */ + protected abstract DataPartitionReadingTask getPartitionReadingTask(); + + /** + * {@link DataPartitionProcessor} is responsible for processing all the pending {@link + * PartitionProcessingTask}s of this {@link DataPartition}. + */ + protected class DataPartitionProcessor implements Runnable { + + @Override + public void run() { + boolean continueProcessing = true; + while (continueProcessing) { + PartitionProcessingTask task; + synchronized (taskQueue) { + task = taskQueue.pollFirst(); + continueProcessing = !taskQueue.isEmpty(); + } + + try { + CommonUtils.checkNotNull(task).process(); + } catch (Throwable throwable) { + // tasks need to handle exceptions themselves so this should + // never happen, we add this to catch potential bugs + LOG.error("Fatal: failed to run partition processing task.", throwable); + } + } + } + } +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/BaseDataPartitionReader.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/BaseDataPartitionReader.java new file mode 100644 index 00000000..ba8419be --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/BaseDataPartitionReader.java @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
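One scheduling detail of BaseDataPartition is worth spelling out: addPartitionProcessingTask hands the shared processor to the single-thread executor only when the task queue transitions from empty to non-empty, so at most one drain pass is ever outstanding. A minimal self-contained sketch of that pattern (class and member names here are illustrative, not from this patch):

    // Sketch of the enqueue-and-trigger pattern used by BaseDataPartition.
    import java.util.ArrayDeque;
    import java.util.Deque;
    import java.util.concurrent.Executor;

    class TaskLoop {
        private final Deque<Runnable> queue = new ArrayDeque<>();
        private final Executor singleThreadExecutor; // must be single-threaded

        TaskLoop(Executor executor) {
            this.singleThreadExecutor = executor;
        }

        void submit(Runnable task) {
            final boolean trigger;
            synchronized (queue) {
                trigger = queue.isEmpty(); // only the first producer schedules a drain
                queue.addLast(task);
            }
            if (trigger) {
                singleThreadExecutor.execute(this::drain);
            }
        }

        private void drain() {
            boolean more = true;
            while (more) {
                Runnable task;
                synchronized (queue) {
                    task = queue.pollFirst();
                    more = !queue.isEmpty();
                }
                task.run(); // the real code catches Throwable per task
            }
        }
    }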
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.storage.partition;
+
+import com.alibaba.flink.shuffle.common.exception.ShuffleException;
+import com.alibaba.flink.shuffle.common.utils.CommonUtils;
+import com.alibaba.flink.shuffle.core.listener.BacklogListener;
+import com.alibaba.flink.shuffle.core.listener.DataListener;
+import com.alibaba.flink.shuffle.core.listener.FailureListener;
+import com.alibaba.flink.shuffle.core.memory.Buffer;
+import com.alibaba.flink.shuffle.core.storage.BufferWithBacklog;
+import com.alibaba.flink.shuffle.core.storage.DataPartitionReader;
+import com.alibaba.flink.shuffle.core.utils.BufferUtils;
+import com.alibaba.flink.shuffle.core.utils.ListenerUtils;
+
+import javax.annotation.concurrent.GuardedBy;
+
+import java.util.ArrayDeque;
+import java.util.Queue;
+
+/**
+ * {@link BaseDataPartitionReader} implements some basic logic of {@link DataPartitionReader}
+ * that can be reused by subclasses to simplify the implementation of new {@link
+ * DataPartitionReader}s.
+ */
+public abstract class BaseDataPartitionReader implements DataPartitionReader {
+
+    /**
+     * Lock used for synchronization and to avoid potential race conditions between the data
+     * reading thread and the data partition executor thread.
+     */
+    protected final Object lock = new Object();
+
+    /** Listener to be notified when there is data available for consumption. */
+    protected final DataListener dataListener;
+
+    /** Listener to be notified when there is backlog available in this reader. */
+    protected final BacklogListener backlogListener;
+
+    /** Listener to be notified if any failure occurs while reading the data. */
+    protected final FailureListener failureListener;
+
+    /** All {@link Buffer}s that have been read and can be polled from this partition reader. */
+    @GuardedBy("lock")
+    protected final Queue<Buffer> buffersRead = new ArrayDeque<>();
+
+    /** Whether all the data has been successfully read or not. */
+    @GuardedBy("lock")
+    protected boolean isFinished;
+
+    /** Whether this partition reader has been released or not. */
+    @GuardedBy("lock")
+    protected boolean isReleased;
+
+    /** Exception causing the release of this partition reader. */
+    @GuardedBy("lock")
+    protected Throwable releaseCause;
+
+    /** Whether there is any error at the consumer side or not. */
+    @GuardedBy("lock")
+    protected boolean isError;
+
+    public BaseDataPartitionReader(
+            DataListener dataListener,
+            BacklogListener backlogListener,
+            FailureListener failureListener) {
+        CommonUtils.checkArgument(dataListener != null, "Must be not null.");
+        CommonUtils.checkArgument(backlogListener != null, "Must be not null.");
+        CommonUtils.checkArgument(failureListener != null, "Must be not null.");
+
+        this.dataListener = dataListener;
+        this.backlogListener = backlogListener;
+        this.failureListener = failureListener;
+    }
+
+    @Override
+    public BufferWithBacklog nextBuffer() {
+        synchronized (lock) {
+            Buffer buffer = buffersRead.poll();
+            if (buffer == null) {
+                return null;
+            }
+
+            return new BufferWithBacklog(buffer, buffersRead.size());
+        }
+    }
+
+    /**
+     * Adds a buffer that has been read to this {@link DataPartitionReader} for consumption and
+     * notifies the target {@link DataListener} of the available data if needed.
+ */ + protected void addBuffer(Buffer buffer, boolean hasRemaining) { + if (buffer == null) { + return; + } + + final boolean recycleBuffer; + boolean notifyDataAvailable = false; + final Throwable throwable; + + synchronized (lock) { + recycleBuffer = isReleased || isFinished || isError; + throwable = releaseCause; + isFinished = !hasRemaining; + + if (!recycleBuffer) { + notifyDataAvailable = buffersRead.isEmpty(); + buffersRead.add(buffer); + } + } + + if (recycleBuffer) { + BufferUtils.recycleBuffer(buffer); + throw new ShuffleException("Partition reader has been failed or finished.", throwable); + } + + if (notifyDataAvailable) { + ListenerUtils.notifyAvailableData(dataListener); + } + } + + protected void notifyBacklog(int backlog) { + ListenerUtils.notifyBacklog(backlogListener, backlog); + } + + @Override + public void release(Throwable throwable) throws Exception { + boolean notifyFailure; + Queue buffers; + + synchronized (lock) { + if (isReleased) { + return; + } + + isReleased = true; + releaseCause = throwable; + + notifyFailure = !isError; + buffers = new ArrayDeque<>(buffersRead); + buffersRead.clear(); + } + + BufferUtils.recycleBuffers(buffers); + if (notifyFailure) { + ListenerUtils.notifyFailure(failureListener, throwable); + } + } + + @Override + public void onError(Throwable throwable) { + Queue buffers; + + synchronized (lock) { + isError = true; + releaseCause = throwable; + + buffers = new ArrayDeque<>(buffersRead); + buffersRead.clear(); + } + + BufferUtils.recycleBuffers(buffers); + } + + /** + * Closes this data partition reader which means all data has been read successfully. It should + * notify failure and throw exception if encountering any failure. + */ + protected void closeReader() throws Exception { + synchronized (lock) { + isFinished = true; + } + + ListenerUtils.notifyAvailableData(dataListener); + } + + @Override + public boolean isFinished() { + synchronized (lock) { + return isFinished && buffersRead.isEmpty(); + } + } +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/BaseDataPartitionWriter.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/BaseDataPartitionWriter.java new file mode 100644 index 00000000..383a6dc3 --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/BaseDataPartitionWriter.java @@ -0,0 +1,336 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
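To make the consumption contract of BaseDataPartitionReader concrete: a consumer waits for the DataListener callback, then drains nextBuffer() until it returns null, and can use the returned backlog as a hint for how much data is still queued. A hedged sketch follows; the accessor names on BufferWithBacklog and the helper methods are assumptions:

    // Hypothetical consumer loop for a DataPartitionReader (sketch only).
    void onDataAvailable(DataPartitionReader reader) { // invoked via DataListener
        BufferWithBacklog next;
        while ((next = reader.nextBuffer()) != null) {
            ship(next.getBuffer());           // assumed helper: hand buffer to the network stack
            long backlog = next.getBacklog(); // assumed accessor: buffers still queued
            maybePrefetchCredits(backlog);    // assumed helper
        }
    }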
+ */
+
+package com.alibaba.flink.shuffle.storage.partition;
+
+import com.alibaba.flink.shuffle.common.exception.ShuffleException;
+import com.alibaba.flink.shuffle.common.utils.CommonUtils;
+import com.alibaba.flink.shuffle.common.utils.ExceptionUtils;
+import com.alibaba.flink.shuffle.core.ids.MapPartitionID;
+import com.alibaba.flink.shuffle.core.ids.ReducePartitionID;
+import com.alibaba.flink.shuffle.core.listener.DataCommitListener;
+import com.alibaba.flink.shuffle.core.listener.DataRegionCreditListener;
+import com.alibaba.flink.shuffle.core.listener.FailureListener;
+import com.alibaba.flink.shuffle.core.memory.Buffer;
+import com.alibaba.flink.shuffle.core.memory.BufferRecycler;
+import com.alibaba.flink.shuffle.core.storage.BufferQueue;
+import com.alibaba.flink.shuffle.core.storage.DataPartition;
+import com.alibaba.flink.shuffle.core.storage.DataPartitionWriter;
+import com.alibaba.flink.shuffle.core.utils.BufferUtils;
+import com.alibaba.flink.shuffle.core.utils.ListenerUtils;
+
+import javax.annotation.concurrent.GuardedBy;
+
+import java.util.ArrayDeque;
+import java.util.Queue;
+
+/**
+ * {@link BaseDataPartitionWriter} implements some basic logic of {@link DataPartitionWriter}
+ * that can be reused by subclasses to simplify the implementation of new {@link
+ * DataPartitionWriter}s.
+ */
+public abstract class BaseDataPartitionWriter implements DataPartitionWriter {
+
+    /**
+     * Minimum number of credits required before notifying the credit listener of new credits.
+     * Notifying in bulk reduces the number of small network packets.
+     */
+    public static final int MIN_CREDITS_TO_NOTIFY = 10;
+
+    /** Target {@link DataPartition} to write data to. */
+    protected final BaseDataPartition dataPartition;
+
+    /** {@link MapPartitionID} of all the data written. */
+    protected final MapPartitionID mapPartitionID;
+
+    /**
+     * {@link DataRegionCreditListener} to be notified when new credits are available for the
+     * corresponding data producer.
+     */
+    protected final DataRegionCreditListener dataRegionCreditListener;
+
+    /**
+     * {@link FailureListener} to be notified if any exception occurs when processing the pending
+     * {@link BufferOrMarker}s.
+     */
+    protected final FailureListener failureListener;
+
+    /**
+     * Lock used for synchronization and to avoid potential race conditions between the data
+     * writing thread and the data partition executor thread.
+     */
+    protected final Object lock = new Object();
+
+    /** All available credits that can be used by the corresponding data producer. */
+    @GuardedBy("lock")
+    protected final Queue<Buffer> availableCredits = new ArrayDeque<>();
+
+    /**
+     * All pending {@link BufferOrMarker}s already added to this partition writer and waiting to
+     * be processed.
+     */
+    @GuardedBy("lock")
+    protected final Queue<BufferOrMarker> bufferOrMarkers = new ArrayDeque<>();
+
+    /** Whether this partition writer has been released or not. */
+    @GuardedBy("lock")
+    protected boolean isReleased;
+
+    /** Whether there is any error at the producer side or not. */
+    @GuardedBy("lock")
+    protected boolean isError;
+
+    /**
+     * Whether this {@link DataPartitionWriter} needs more credits to receive and cache data or
+     * not.
+     */
+    protected boolean needMoreCredits;
+
+    /** Index number of the current data region being written.
*/ + protected int currentDataRegionIndex; + + protected BaseDataPartitionWriter( + BaseDataPartition dataPartition, + MapPartitionID mapPartitionID, + DataRegionCreditListener dataRegionCreditListener, + FailureListener failureListener) { + CommonUtils.checkArgument(dataPartition != null, "Must be not null."); + CommonUtils.checkArgument(mapPartitionID != null, "Must be not null."); + CommonUtils.checkArgument(dataRegionCreditListener != null, "Must be not null."); + CommonUtils.checkArgument(failureListener != null, "Must be not null."); + + this.dataPartition = dataPartition; + this.mapPartitionID = mapPartitionID; + this.dataRegionCreditListener = dataRegionCreditListener; + this.failureListener = failureListener; + } + + @Override + public MapPartitionID getMapPartitionID() { + return mapPartitionID; + } + + @Override + public void addBuffer(ReducePartitionID reducePartitionID, Buffer buffer) { + addBufferOrMarker(new BufferOrMarker.DataBuffer(mapPartitionID, reducePartitionID, buffer)); + } + + @Override + public void startRegion(int dataRegionIndex, boolean isBroadcastRegion) { + addBufferOrMarker( + new BufferOrMarker.RegionStartedMarker( + mapPartitionID, dataRegionIndex, isBroadcastRegion)); + } + + @Override + public void finishRegion() { + addBufferOrMarker(new BufferOrMarker.RegionFinishedMarker(mapPartitionID)); + } + + @Override + public void finishDataInput(DataCommitListener commitListener) { + addBufferOrMarker(new BufferOrMarker.InputFinishedMarker(mapPartitionID, commitListener)); + } + + /** Adds a new {@link BufferOrMarker} to this partition writer to be processed. */ + protected abstract void addBufferOrMarker(BufferOrMarker bufferOrMarker); + + @Override + public boolean writeData() throws Exception { + Queue pendingBufferOrMarkers = getPendingBufferOrMarkers(); + if (pendingBufferOrMarkers == null) { + return false; + } + + BufferOrMarker bufferOrMarker; + try { + while ((bufferOrMarker = pendingBufferOrMarkers.poll()) != null) { + if (processBufferOrMarker(bufferOrMarker)) { + return true; + } + } + } finally { + BufferOrMarker.releaseBuffers(pendingBufferOrMarkers); + } + return false; + } + + protected boolean processBufferOrMarker(BufferOrMarker bufferOrMarker) throws Exception { + switch (bufferOrMarker.getType()) { + case ERROR_MARKER: + processErrorMarker(bufferOrMarker.asErrorMarker()); + return true; + case INPUT_FINISHED_MARKER: + processInputFinishedMarker(bufferOrMarker.asInputFinishedMarker()); + return true; + case REGION_STARTED_MARKER: + processRegionStartedMarker(bufferOrMarker.asRegionStartedMarker()); + return false; + case REGION_FINISHED_MARKER: + processRegionFinishedMarker(bufferOrMarker.asRegionFinishedMarker()); + return false; + case DATA_BUFFER: + processDataBuffer(bufferOrMarker.asDataBuffer()); + return false; + default: + throw new ShuffleException( + String.format("Illegal type: %s.", bufferOrMarker.getType())); + } + } + + protected void processErrorMarker(BufferOrMarker.ErrorMarker marker) throws Exception { + needMoreCredits = false; + releaseUnusedCredits(); + ExceptionUtils.rethrowException(marker.getFailure()); + } + + protected void processRegionStartedMarker(BufferOrMarker.RegionStartedMarker marker) + throws Exception { + needMoreCredits = true; + currentDataRegionIndex = marker.getDataRegionIndex(); + } + + protected abstract void processDataBuffer(BufferOrMarker.DataBuffer buffer) throws Exception; + + protected void processRegionFinishedMarker(BufferOrMarker.RegionFinishedMarker marker) + throws Exception { + 
needMoreCredits = false;
+        releaseUnusedCredits();
+    }
+
+    protected void processInputFinishedMarker(BufferOrMarker.InputFinishedMarker marker)
+            throws Exception {
+        CommonUtils.checkState(!needMoreCredits, "Must finish region before finishing input.");
+        CommonUtils.checkState(availableCredits.isEmpty(), "Buffers (credits) leaking.");
+        ListenerUtils.notifyDataCommitted(marker.getCommitListener());
+    }
+
+    @Override
+    public void onError(Throwable throwable) {
+        synchronized (lock) {
+            if (isReleased || isError) {
+                return;
+            }
+
+            isError = true;
+        }
+
+        Queue<BufferOrMarker> pendingBufferOrMarkers = getPendingBufferOrMarkers();
+        BufferOrMarker.releaseBuffers(pendingBufferOrMarkers);
+
+        Throwable exception = new ShuffleException("Writing view failed.", throwable);
+        addBufferOrMarker(new BufferOrMarker.ErrorMarker(mapPartitionID, exception));
+    }
+
+    @Override
+    public boolean assignCredits(BufferQueue credits, BufferRecycler recycler) {
+        CommonUtils.checkArgument(credits != null, "Must be not null.");
+        CommonUtils.checkArgument(recycler != null, "Must be not null.");
+
+        if (isReleased || !needMoreCredits) {
+            return false;
+        }
+
+        if (credits.size() < MIN_CREDITS_TO_NOTIFY) {
+            return needMoreCredits;
+        }
+
+        int numBuffers = 0;
+        synchronized (lock) {
+            if (isError) {
+                return false;
+            }
+
+            while (credits.size() > 0) {
+                ++numBuffers;
+                availableCredits.add(new Buffer(credits.poll(), recycler, 0));
+            }
+        }
+
+        ListenerUtils.notifyAvailableCredits(
+                numBuffers, currentDataRegionIndex, dataRegionCreditListener);
+        return needMoreCredits;
+    }
+
+    @Override
+    public Buffer pollBuffer() {
+        synchronized (lock) {
+            if (isReleased || isError) {
+                throw new ShuffleException("Partition writer has been released or failed.");
+            }
+
+            return availableCredits.poll();
+        }
+    }
+
+    @Override
+    public void release(Throwable throwable) throws Exception {
+        Queue<Buffer> buffers;
+        boolean notifyFailure;
+
+        synchronized (lock) {
+            if (isReleased) {
+                return;
+            }
+
+            notifyFailure = !isError;
+            isReleased = true;
+            buffers = new ArrayDeque<>(availableCredits);
+            availableCredits.clear();
+        }
+
+        if (notifyFailure) {
+            ListenerUtils.notifyFailure(
+                    failureListener,
+                    new ShuffleException(
+                            "Error encountered while writing data partition.", throwable));
+        }
+
+        BufferUtils.recycleBuffers(buffers);
+        BufferOrMarker.releaseBuffers(getPendingBufferOrMarkers());
+        releaseUnusedCredits();
+    }
+
+    private void releaseUnusedCredits() {
+        Queue<Buffer> unusedCredits;
+        synchronized (lock) {
+            unusedCredits = new ArrayDeque<>(availableCredits);
+            availableCredits.clear();
+        }
+
+        BufferUtils.recycleBuffers(unusedCredits);
+    }
+
+    private Queue<BufferOrMarker> getPendingBufferOrMarkers() {
+        synchronized (lock) {
+            if (bufferOrMarkers.isEmpty()) {
+                return null;
+            }
+
+            Queue<BufferOrMarker> pendingBufferOrMarkers = new ArrayDeque<>(bufferOrMarkers);
+            bufferOrMarkers.clear();
+            return pendingBufferOrMarkers;
+        }
+    }
+
+    // ---------------------------------------------------------------------------------------------
+    // For test
+    // ---------------------------------------------------------------------------------------------
+
+    int getNumPendingBuffers() {
+        return bufferOrMarkers.size();
+    }
+}
diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/BaseMapPartition.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/BaseMapPartition.java
new file mode 100644
index 00000000..52c485c9
--- /dev/null
+++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/BaseMapPartition.java
@@ -0,0
+1,676 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; +import com.alibaba.flink.shuffle.core.config.MemoryOptions; +import com.alibaba.flink.shuffle.core.config.StorageOptions; +import com.alibaba.flink.shuffle.core.executor.SingleThreadExecutor; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.listener.BacklogListener; +import com.alibaba.flink.shuffle.core.listener.BufferListener; +import com.alibaba.flink.shuffle.core.listener.DataListener; +import com.alibaba.flink.shuffle.core.listener.DataRegionCreditListener; +import com.alibaba.flink.shuffle.core.listener.FailureListener; +import com.alibaba.flink.shuffle.core.memory.BufferRecycler; +import com.alibaba.flink.shuffle.core.storage.BufferQueue; +import com.alibaba.flink.shuffle.core.storage.DataPartitionReader; +import com.alibaba.flink.shuffle.core.storage.DataPartitionWriter; +import com.alibaba.flink.shuffle.core.storage.MapPartition; +import com.alibaba.flink.shuffle.core.storage.PartitionedDataStore; +import com.alibaba.flink.shuffle.storage.exception.ConcurrentWriteException; +import com.alibaba.flink.shuffle.storage.utils.DataPartitionUtils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.PriorityQueue; + +/** + * Base {@link MapPartition} implementation which takes care of allocating resources and io + * scheduling. It can be used by different subclasses and simplify the new {@link MapPartition} + * implementation. + */ +public abstract class BaseMapPartition extends BaseDataPartition implements MapPartition { + + private static final Logger LOG = LoggerFactory.getLogger(BaseMapPartition.class); + + /** Task responsible for writing data to this {@link MapPartition}. */ + private final MapPartitionWritingTask writingTask; + + /** Task responsible for reading data from this {@link MapPartition}. */ + private final MapPartitionReadingTask readingTask; + + /** Whether this {@link MapPartition} has finished writing all data. */ + protected boolean isFinished; + + /** + * Whether a {@link DataPartitionWriter} has been created for this {@link MapPartition} or not. 
+ */ + protected boolean partitionWriterCreated; + + public BaseMapPartition(PartitionedDataStore dataStore, SingleThreadExecutor mainExecutor) { + super(dataStore, mainExecutor); + + Configuration configuration = dataStore.getConfiguration(); + this.readingTask = new MapPartitionReadingTask(configuration); + this.writingTask = new MapPartitionWritingTask(configuration); + } + + @Override + public DataPartitionWriter createPartitionWriter( + MapPartitionID mapPartitionID, + DataRegionCreditListener dataRegionCreditListener, + FailureListener failureListener) + throws Exception { + CommonUtils.checkArgument(mapPartitionID != null, "Must be not null."); + CommonUtils.checkArgument(dataRegionCreditListener != null, "Must be not null."); + CommonUtils.checkArgument(failureListener != null, "Must be not null."); + + final DataPartitionWriter writer = + getDataPartitionWriter(mapPartitionID, dataRegionCreditListener, failureListener); + + addPartitionProcessingTask( + () -> { + try { + CommonUtils.checkArgument( + mapPartitionID.equals(getPartitionMeta().getDataPartitionID()), + "Inconsistent partition ID for the target map partition."); + + final Exception exception; + if (isReleased) { + exception = new ShuffleException("Data partition has been released."); + } else if (!writers.isEmpty() || partitionWriterCreated) { + exception = + new ConcurrentWriteException( + "Trying to write an existing map partition."); + } else { + exception = null; + } + + if (exception != null) { + DataPartitionUtils.releaseDataPartitionWriter(writer, exception); + return; + } + + partitionWriterCreated = true; + writers.put(mapPartitionID, writer); + } catch (Throwable throwable) { + CommonUtils.runQuietly(() -> releaseOnInternalError(throwable)); + LOG.error("Failed to create data partition writer.", throwable); + } + }, + true); + return writer; + } + + @Override + public DataPartitionReader createPartitionReader( + int startPartitionIndex, + int endPartitionIndex, + DataListener dataListener, + BacklogListener backlogListener, + FailureListener failureListener) + throws Exception { + CommonUtils.checkArgument(dataListener != null, "Must be not null."); + CommonUtils.checkArgument(backlogListener != null, "Must be not null."); + CommonUtils.checkArgument(failureListener != null, "Must be not null."); + + final DataPartitionReader reader = + getDataPartitionReader( + startPartitionIndex, + endPartitionIndex, + dataListener, + backlogListener, + failureListener); + + addPartitionProcessingTask( + () -> { + try { + CommonUtils.checkState(!isReleased, "Data partition has been released."); + + // allocate resources when the first reader is registered + boolean allocateResources = readers.isEmpty(); + readers.add(reader); + + if (allocateResources) { + DataPartitionReadingTask readingTask = + CommonUtils.checkNotNull(getPartitionReadingTask()); + readingTask.allocateResources(); + } + } catch (Throwable throwable) { + DataPartitionUtils.releaseDataPartitionReader(reader, throwable); + LOG.error("Failed to create data partition reader.", throwable); + } + }, + true); + return reader; + } + + /** + * Returns the corresponding {@link DataPartitionReader} of the target reduce partitions. The + * implementation is responsible for closing its allocated resources if any when encountering + * any exception. 
+ */ + protected abstract DataPartitionReader getDataPartitionReader( + int startPartitionIndex, + int endPartitionIndex, + DataListener dataListener, + BacklogListener backlogListener, + FailureListener failureListener) + throws Exception; + + /** + * Returns the corresponding {@link DataPartitionWriter} of the target map partition. The + * implementation is responsible for closing its allocated resources if any when encountering + * any exception. + */ + protected abstract DataPartitionWriter getDataPartitionWriter( + MapPartitionID mapPartitionID, + DataRegionCreditListener dataRegionCreditListener, + FailureListener failureListener) + throws Exception; + + @Override + public MapPartitionWritingTask getPartitionWritingTask() { + return writingTask; + } + + @Override + public MapPartitionReadingTask getPartitionReadingTask() { + return readingTask; + } + + @Override + protected void releaseInternal(Throwable releaseCause) throws Exception { + Throwable exception = null; + + try { + super.releaseInternal(releaseCause); + } catch (Throwable throwable) { + exception = throwable; + LOG.error("Fatal: failed to release base data partition.", throwable); + } + + try { + writingTask.release(this.releaseCause); + } catch (Throwable throwable) { + exception = exception != null ? exception : throwable; + LOG.error("Fatal: failed to release data writing task.", throwable); + } + + try { + readingTask.release(this.releaseCause); + } catch (Throwable throwable) { + exception = exception != null ? exception : throwable; + LOG.error("Fatal: failed to release data reading task.", throwable); + } + + if (exception != null) { + ExceptionUtils.rethrowException(exception); + } + } + + /** + * {@link MapPartitionWritingTask} implements the basic resource allocation and data processing + * logics which can be reused by subclasses. + */ + protected class MapPartitionWritingTask implements DataPartitionWritingTask, BufferListener { + + /** + * Minimum size of memory in bytes to trigger writing of {@link DataPartitionReader}s. Use a + * portion of the guaranteed memory for data bulk writing and keep the left memory for data + * transmission over the network, better writing pipeline can be achieved (1/2 is just an + * empirical value). + */ + public final int minMemoryToWrite = + CommonUtils.checkedDownCast( + StorageOptions.MIN_WRITING_READING_MEMORY_SIZE.divide(2).getBytes()); + + /** + * Minimum number of buffers (calculated from {@link #minMemoryToWrite} and buffer size) to + * trigger writing of {@link DataPartitionWriter}s. + */ + public final int minBuffersToWrite; + + /** + * Minimum number of buffers (calculated from buffer size and {@link + * StorageOptions#MIN_WRITING_READING_MEMORY_SIZE}) to be used for data partition writing. + */ + protected final int minWritingBuffers; + + /** + * Maximum number of buffers (calculated from buffer size and the configured value for + * {@link StorageOptions#STORAGE_MAX_PARTITION_WRITING_MEMORY}) to be used for data + * partition writing. + */ + protected final int maxWritingBuffers; + + /** Available buffers can be used for data writing of the target partition. */ + protected final BufferQueue buffers = new BufferQueue(new ArrayList<>()); + + /** {@link DataPartitionWriter} instance used to write data to this {@link MapPartition}. 
*/ + protected DataPartitionWriter writer; + + protected MapPartitionWritingTask(Configuration configuration) { + int minWritingMemory = + CommonUtils.checkedDownCast( + StorageOptions.MIN_WRITING_READING_MEMORY_SIZE.getBytes()); + int maxWritingMemory = + CommonUtils.checkedDownCast( + configuration + .getMemorySize( + StorageOptions.STORAGE_MAX_PARTITION_READING_MEMORY) + .getBytes()); + int bufferSize = + CommonUtils.checkedDownCast( + configuration + .getMemorySize(MemoryOptions.MEMORY_BUFFER_SIZE) + .getBytes()); + this.minBuffersToWrite = Math.max(1, minMemoryToWrite / bufferSize); + this.minWritingBuffers = Math.max(1, minWritingMemory / bufferSize); + this.maxWritingBuffers = Math.max(minWritingBuffers, maxWritingMemory / bufferSize); + } + + @Override + public void process() { + try { + CommonUtils.checkState(inExecutorThread(), "Not in main thread."); + + if (isReleased) { + return; + } + CommonUtils.checkState(!isFinished, "Data partition has been finished."); + CommonUtils.checkState(!buffers.isReleased(), "Buffers has been released."); + + if (writer == null) { + CommonUtils.checkState(writers.size() == 1, "Too many partition writers."); + MapPartitionID partitionID = getPartitionMeta().getDataPartitionID(); + writer = CommonUtils.checkNotNull(writers.get(partitionID)); + } + + if (!writer.writeData()) { + dispatchBuffers(); + return; + } + + writer = null; + writers.clear(); + recycleBuffers(buffers.release(), dataStore.getWritingBufferDispatcher()); + isFinished = true; + LOG.info("Successfully write data partition: {}.", getPartitionMeta()); + } catch (Throwable throwable) { + LOG.error("Failed to write partition data.", throwable); + CommonUtils.runQuietly(() -> releaseOnInternalError(throwable)); + } + } + + private void checkInProcessState() { + CommonUtils.checkState(writer != null, "No registered writer."); + CommonUtils.checkState(!isReleased, "Partition has been released."); + CommonUtils.checkState(!isFinished, "Data writing has finished."); + } + + @Override + public void allocateResources() throws Exception { + CommonUtils.checkState(inExecutorThread(), "Not in main thread."); + checkInProcessState(); + + allocateBuffers( + dataStore.getWritingBufferDispatcher(), + this, + minWritingBuffers, + maxWritingBuffers); + } + + @Override + public void release(@Nullable Throwable releaseCause) throws Exception { + try { + CommonUtils.checkState(inExecutorThread(), "Not in main thread."); + + writer = null; + recycleBuffers(buffers.release(), dataStore.getWritingBufferDispatcher()); + } catch (Throwable throwable) { + LOG.error("Fatal: failed to release the data writing task.", throwable); + ExceptionUtils.rethrowException(throwable); + } + } + + @Override + public void triggerWriting() { + addPartitionProcessingTask(this); + } + + private void dispatchBuffers() { + CommonUtils.checkState(inExecutorThread(), "Not in main thread."); + checkInProcessState(); + + if (!writer.assignCredits(buffers, buffer -> recycle(buffer, buffers)) + && buffers.size() > 0) { + List toRelease = new ArrayList<>(buffers.size()); + while (buffers.size() > 0) { + toRelease.add(buffers.poll()); + } + recycleBuffers(toRelease, dataStore.getWritingBufferDispatcher()); + } + } + + /** Notifies the allocated writing buffers to this data writing task. 
*/ + @Override + public void notifyBuffers(List allocatedBuffers, Throwable exception) { + addPartitionProcessingTask( + () -> { + try { + if (exception != null) { + recycleBuffers( + allocatedBuffers, dataStore.getWritingBufferDispatcher()); + throw exception; + } + + CommonUtils.checkArgument( + allocatedBuffers != null && !allocatedBuffers.isEmpty(), + "Fatal: empty buffer was allocated."); + + if (isReleased || isFinished) { + recycleBuffers( + allocatedBuffers, dataStore.getWritingBufferDispatcher()); + return; + } + + if (buffers.isReleased()) { + recycleBuffers( + allocatedBuffers, dataStore.getWritingBufferDispatcher()); + throw new ShuffleException("Buffers has been released."); + } + + buffers.add(allocatedBuffers); + allocatedBuffers.clear(); + dispatchBuffers(); + } catch (Throwable throwable) { + CommonUtils.runQuietly(() -> releaseOnInternalError(throwable)); + LOG.error("Fatal: resource allocation error.", throwable); + } + }); + } + + private void handleRecycledBuffer(ByteBuffer buffer, BufferQueue buffers) { + try { + CommonUtils.checkArgument(buffer != null, "Must be not null."); + + if (isReleased || isFinished) { + recycleBuffers( + Collections.singletonList(buffer), + dataStore.getWritingBufferDispatcher()); + return; + } + + if (buffers.isReleased()) { + recycleBuffers( + Collections.singletonList(buffer), + dataStore.getWritingBufferDispatcher()); + throw new ShuffleException("Buffers has been released."); + } + + buffers.add(buffer); + dispatchBuffers(); + } catch (Throwable throwable) { + CommonUtils.runQuietly(() -> releaseOnInternalError(throwable)); + LOG.error("Resource recycling error.", throwable); + } + } + + /** + * Recycles a writing buffer to this data writing task. If no more buffer is needed, the + * recycled buffer will be returned to the buffer manager directly and if any unexpected + * exception occurs, the corresponding data partition will be released. + */ + public void recycle(ByteBuffer buffer, BufferQueue buffers) { + if (!inExecutorThread()) { + addPartitionProcessingTask(() -> handleRecycledBuffer(buffer, buffers)); + return; + } + + handleRecycledBuffer(buffer, buffers); + } + } + + /** + * {@link MapPartitionReadingTask} implements the basic resource allocation and data reading + * logics (including IO scheduling) which can be reused by subclasses. + */ + protected class MapPartitionReadingTask implements DataPartitionReadingTask, BufferListener { + + /** + * Minimum size of memory in bytes to trigger reading of {@link DataPartitionReader}s. Use a + * portion of the guaranteed memory for data bulk reading and keep the left memory as data + * cache in the reading view, better reading pipeline can be achieved (1/2 is just an + * empirical value). + */ + public final int minMemoryToRead = + CommonUtils.checkedDownCast( + StorageOptions.MIN_WRITING_READING_MEMORY_SIZE.divide(2).getBytes()); + + /** + * Minimum number of buffers (calculated from {@link #minMemoryToRead} and buffer size) to + * trigger reading of {@link DataPartitionReader}s. + */ + public final int minBuffersToRead; + + /** + * Minimum number of buffers (calculated from buffer size and {@link + * StorageOptions#MIN_WRITING_READING_MEMORY_SIZE}) to be used for data partition reading. + */ + protected final int minReadingBuffers; + + /** + * Maximum number of buffers (calculated from buffer size and the configured value for + * {@link StorageOptions#STORAGE_MAX_PARTITION_READING_MEMORY}) to be used for data + * partition reading. 
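The buffer-count arithmetic shared by the writing and reading task constructors is easier to follow with numbers plugged in. Assuming, purely for illustration, an 8 MB MIN_WRITING_READING_MEMORY_SIZE and a 32 KB MEMORY_BUFFER_SIZE (the real values come from configuration):

    // Worked example of the constructor arithmetic, with assumed config values.
    int bufferSize = 32 * 1024;                 // MEMORY_BUFFER_SIZE (assumed)
    int minReadingMemory = 8 * 1024 * 1024;     // MIN_WRITING_READING_MEMORY_SIZE (assumed)
    int minMemoryToRead = minReadingMemory / 2; // half for bulk IO, half for pipelining

    int minBuffersToRead = Math.max(1, minMemoryToRead / bufferSize);   // 4 MB / 32 KB = 128
    int minReadingBuffers = Math.max(1, minReadingMemory / bufferSize); // 8 MB / 32 KB = 256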
+ */ + protected final int maxReadingBuffers; + + /** All available buffers can be used by the partition readers for reading. */ + protected BufferQueue buffers = BufferQueue.RELEASED_EMPTY_BUFFER_QUEUE; + + protected MapPartitionReadingTask(Configuration configuration) { + int minReadingMemory = + CommonUtils.checkedDownCast( + StorageOptions.MIN_WRITING_READING_MEMORY_SIZE.getBytes()); + int maxReadingMemory = + CommonUtils.checkedDownCast( + configuration + .getMemorySize( + StorageOptions.STORAGE_MAX_PARTITION_READING_MEMORY) + .getBytes()); + int bufferSize = + CommonUtils.checkedDownCast( + configuration + .getMemorySize(MemoryOptions.MEMORY_BUFFER_SIZE) + .getBytes()); + this.minBuffersToRead = Math.max(1, minMemoryToRead / bufferSize); + this.minReadingBuffers = Math.max(1, minReadingMemory / bufferSize); + this.maxReadingBuffers = Math.max(minReadingBuffers, maxReadingMemory / bufferSize); + } + + @Override + public void process() { + try { + CommonUtils.checkState(inExecutorThread(), "Not in main thread."); + + if (isReleased) { + return; + } + CommonUtils.checkState(!readers.isEmpty(), "No reader registered."); + CommonUtils.checkState(!buffers.isReleased(), "Buffers has been released."); + BufferRecycler recycler = (buffer) -> recycle(buffer, buffers); + + for (DataPartitionReader reader : readers) { + if (!reader.isOpened()) { + reader.open(); + } + } + PriorityQueue sortedReaders = new PriorityQueue<>(readers); + + while (buffers.size() > 0 && !sortedReaders.isEmpty()) { + DataPartitionReader reader = sortedReaders.poll(); + try { + if (!reader.readData(buffers, recycler)) { + removePartitionReader(reader); + LOG.debug("Successfully read partition data: {}.", reader); + } + } catch (Throwable throwable) { + removePartitionReader(reader); + DataPartitionUtils.releaseDataPartitionReader(reader, throwable); + LOG.debug("Failed to read partition data: {}.", reader, throwable); + } + } + } catch (Throwable throwable) { + DataPartitionUtils.releaseDataPartitionReaders(readers, throwable); + recycleBuffers(buffers.release(), dataStore.getReadingBufferDispatcher()); + LOG.error("Fatal: failed to read partition data.", throwable); + } + } + + private void removePartitionReader(DataPartitionReader reader) { + readers.remove(reader); + if (readers.isEmpty()) { + recycleBuffers(buffers.release(), dataStore.getReadingBufferDispatcher()); + } + } + + /** + * Recycles a reading buffer to this data reading task. If no more buffer is needed, the + * recycled buffer will be returned to the buffer manager directly and if any unexpected + * exception occurs, all registered readers will be released. 
+ */ + private void recycle(ByteBuffer buffer, BufferQueue bufferQueue) { + addPartitionProcessingTask( + () -> { + try { + CommonUtils.checkArgument(buffer != null, "Must be not null."); + + if (bufferQueue == null || bufferQueue.isReleased()) { + recycleBuffers( + Collections.singletonList(buffer), + dataStore.getReadingBufferDispatcher()); + CommonUtils.checkState(bufferQueue != null, "Must be not null."); + return; + } + + bufferQueue.add(buffer); + if (bufferQueue.size() >= minBuffersToRead) { + triggerReading(); + } + } catch (Throwable throwable) { + DataPartitionUtils.releaseDataPartitionReaders(readers, throwable); + recycleBuffers( + buffers.release(), dataStore.getReadingBufferDispatcher()); + LOG.error("Resource recycling error.", throwable); + } + }); + } + + @Override + public void allocateResources() throws Exception { + CommonUtils.checkState(inExecutorThread(), "Not in main thread."); + CommonUtils.checkState(!readers.isEmpty(), "No reader registered."); + CommonUtils.checkState(!isReleased, "Partition has been released."); + + allocateBuffers( + dataStore.getReadingBufferDispatcher(), + this, + minReadingBuffers, + maxReadingBuffers); + } + + @Override + public void triggerReading() { + if (inExecutorThread()) { + process(); + return; + } + + addPartitionProcessingTask(this); + } + + /** Notifies the allocated reading buffers to this data reading task. */ + @Override + public void notifyBuffers(List allocatedBuffers, Throwable exception) { + addPartitionProcessingTask( + () -> { + try { + if (exception != null) { + recycleBuffers( + allocatedBuffers, dataStore.getReadingBufferDispatcher()); + throw exception; + } + + CommonUtils.checkArgument( + allocatedBuffers != null && !allocatedBuffers.isEmpty(), + "Fatal: empty buffer was allocated."); + + if (!buffers.isReleased()) { + recycleBuffers( + buffers.release(), dataStore.getReadingBufferDispatcher()); + LOG.error("Fatal: the allocated data reading buffers are leaking."); + } + + if (isReleased || readers.isEmpty()) { + recycleBuffers( + allocatedBuffers, dataStore.getReadingBufferDispatcher()); + return; + } + + buffers = new BufferQueue(allocatedBuffers); + triggerReading(); + } catch (Throwable throwable) { + DataPartitionUtils.releaseDataPartitionReaders(readers, throwable); + recycleBuffers( + buffers.release(), dataStore.getReadingBufferDispatcher()); + LOG.error("Resource allocation error.", throwable); + } + }); + } + + /** + * Releases all reading buffers and partition readers if the corresponding data partition + * has been released. + */ + @Override + public void release(Throwable releaseCause) throws Exception { + try { + CommonUtils.checkState(inExecutorThread(), "Not in main thread."); + + recycleBuffers(buffers.release(), dataStore.getReadingBufferDispatcher()); + } catch (Throwable throwable) { + LOG.error("Fatal: failed to release the data reading task.", throwable); + ExceptionUtils.rethrowException(throwable); + } + } + } +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/BaseMapPartitionWriter.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/BaseMapPartitionWriter.java new file mode 100644 index 00000000..ed3a1fbd --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/BaseMapPartitionWriter.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
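Putting the marker types and BaseMapPartition together, the producer-side protocol a DataPartitionWriter expects is: startRegion for each data region, any number of addBuffer calls, finishRegion, then finishDataInput with a commit listener. A hedged happy-path sketch, where buffer production and the commit callback body are assumptions:

    // Hypothetical happy-path usage of a DataPartitionWriter (sketch only).
    writer.startRegion(0, false);                    // region 0, not a broadcast region
    for (Buffer buffer : producedBuffers) {          // producedBuffers: assumed input
        writer.addBuffer(reducePartitionID, buffer); // queued as a DataBuffer
    }
    writer.finishRegion();                           // queues a RegionFinishedMarker
    writer.finishDataInput(() -> onCommitted());     // DataCommitListener fires after commit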
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.listener.DataRegionCreditListener; +import com.alibaba.flink.shuffle.core.listener.FailureListener; +import com.alibaba.flink.shuffle.core.storage.MapPartition; + +/** + * {@link BaseMapPartitionWriter} implements some basic logics for data writing of {@link + * MapPartition}s. + */ +public abstract class BaseMapPartitionWriter extends BaseDataPartitionWriter { + + public BaseMapPartitionWriter( + MapPartitionID mapPartitionID, + BaseDataPartition dataPartition, + DataRegionCreditListener dataRegionCreditListener, + FailureListener failureListener) { + super(dataPartition, mapPartitionID, dataRegionCreditListener, failureListener); + } + + @Override + protected void processRegionStartedMarker(BufferOrMarker.RegionStartedMarker marker) + throws Exception { + super.processRegionStartedMarker(marker); + + DataPartitionWritingTask writingTask = + CommonUtils.checkNotNull(dataPartition.getPartitionWritingTask()); + if (needMoreCredits) { + writingTask.allocateResources(); + } + } + + @Override + protected void addBufferOrMarker(BufferOrMarker bufferOrMarker) { + boolean recycleBuffer; + boolean triggerWriting = false; + + synchronized (lock) { + if (!(recycleBuffer = isReleased)) { + // trigger data writing when the first buffer is added + triggerWriting = bufferOrMarkers.isEmpty(); + bufferOrMarkers.add(bufferOrMarker); + } + } + + if (recycleBuffer) { + BufferOrMarker.releaseBuffer(bufferOrMarker); + throw new ShuffleException("Partition writer has been released."); + } + + if (triggerWriting) { + DataPartitionWritingTask writingTask = + CommonUtils.checkNotNull(dataPartition.getPartitionWritingTask()); + writingTask.triggerWriting(); + } + } +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/BufferOrMarker.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/BufferOrMarker.java new file mode 100644 index 00000000..da4a4f83 --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/BufferOrMarker.java @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.storage.partition;
+
+import com.alibaba.flink.shuffle.common.utils.CommonUtils;
+import com.alibaba.flink.shuffle.core.ids.MapPartitionID;
+import com.alibaba.flink.shuffle.core.ids.ReducePartitionID;
+import com.alibaba.flink.shuffle.core.listener.DataCommitListener;
+import com.alibaba.flink.shuffle.core.memory.Buffer;
+import com.alibaba.flink.shuffle.core.storage.DataPartition;
+import com.alibaba.flink.shuffle.core.utils.BufferUtils;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nullable;
+
+import java.util.Collection;
+
+/**
+ * Data buffer or event marker. It is just a helper structure that can be used by {@link
+ * DataPartition} implementations.
+ */
+public abstract class BufferOrMarker {
+
+    private static final Logger LOG = LoggerFactory.getLogger(BufferOrMarker.class);
+
+    /** Marks the producer partition of this buffer or marker. */
+    private final MapPartitionID mapPartitionID;
+
+    public BufferOrMarker(MapPartitionID mapPartitionID) {
+        CommonUtils.checkArgument(mapPartitionID != null, "Must be not null.");
+        this.mapPartitionID = mapPartitionID;
+    }
+
+    public abstract Type getType();
+
+    public MapPartitionID getMapPartitionID() {
+        return mapPartitionID;
+    }
+
+    public DataBuffer asDataBuffer() {
+        return (DataBuffer) this;
+    }
+
+    public RegionStartedMarker asRegionStartedMarker() {
+        return (RegionStartedMarker) this;
+    }
+
+    public RegionFinishedMarker asRegionFinishedMarker() {
+        return (RegionFinishedMarker) this;
+    }
+
+    public InputFinishedMarker asInputFinishedMarker() {
+        return (InputFinishedMarker) this;
+    }
+
+    public ErrorMarker asErrorMarker() {
+        return (ErrorMarker) this;
+    }
+
+    /** Releases all the given {@link BufferOrMarker}s. Logs an error if any exception occurs. */
+    public static void releaseBuffers(
+            @Nullable Collection<BufferOrMarker> bufferOrMarkers) {
+        if (bufferOrMarkers == null) {
+            return;
+        }
+
+        for (BufferOrMarker bufferOrMarker : bufferOrMarkers) {
+            releaseBuffer(bufferOrMarker);
+        }
+        // the clear method is not supported by all collections
+        CommonUtils.runQuietly(bufferOrMarkers::clear);
+    }
+
+    /** Releases the target {@link BufferOrMarker}. Logs an error if any exception occurs. */
+    public static void releaseBuffer(@Nullable BufferOrMarker bufferOrMarker) {
+        if (bufferOrMarker == null) {
+            return;
+        }
+
+        try {
+            if (bufferOrMarker.getType() == Type.DATA_BUFFER) {
+                BufferUtils.recycleBuffer(bufferOrMarker.asDataBuffer().getBuffer());
+            }
+        } catch (Throwable throwable) {
+            LOG.error("Fatal: failed to release the target buffer.", throwable);
+        }
+    }
+
+    /** Types of {@link BufferOrMarker}. */
+    enum Type {
+
+        /** DATA_BUFFER represents a {@link Buffer} containing data to be written. */
+        DATA_BUFFER,
+
+        /**
+         * REGION_STARTED_MARKER marks the start of a new data region in the writing
+         * stream.
+         */
+        REGION_STARTED_MARKER,
+
+        /**
+         * REGION_FINISHED_MARKER marks the end of the current data region in the writing
+         * stream.
+         */
+        REGION_FINISHED_MARKER,
+
+        /** INPUT_FINISHED_MARKER marks the end of the current writing stream.
*/ + INPUT_FINISHED_MARKER, + + /** ERROR_MARKER marks the failure of the corresponding data writing view. */ + ERROR_MARKER + } + + /** Definition of {@link Type#DATA_BUFFER}. */ + public static class DataBuffer extends BufferOrMarker { + + /** Target reduce partition of the data. */ + private final ReducePartitionID reducePartitionID; + + /** Buffer containing data to be written. */ + private final Buffer buffer; + + public DataBuffer( + MapPartitionID mapPartitionID, ReducePartitionID reducePartitionID, Buffer buffer) { + super(mapPartitionID); + + CommonUtils.checkArgument(reducePartitionID != null, "Must be not null."); + CommonUtils.checkArgument(buffer != null, "Must be not null."); + + this.reducePartitionID = reducePartitionID; + this.buffer = buffer; + } + + @Override + public Type getType() { + return Type.DATA_BUFFER; + } + + public void release() { + buffer.release(); + } + + public ReducePartitionID getReducePartitionID() { + return reducePartitionID; + } + + public Buffer getBuffer() { + return buffer; + } + } + + /** Definition of {@link Type#REGION_STARTED_MARKER}. */ + public static class RegionStartedMarker extends BufferOrMarker { + + /** Data region index (starting from 0) of the new region. */ + private final int dataRegionIndex; + + /** + * Whether the new data region is a broadcast region. In a broadcast region, each piece of + * data will be written to all reduce partitions. + */ + private final boolean isBroadcastRegion; + + public RegionStartedMarker( + MapPartitionID mapPartitionID, int dataRegionIndex, boolean isBroadcastRegion) { + super(mapPartitionID); + + CommonUtils.checkArgument(dataRegionIndex >= 0, "Must be non-negative."); + + this.dataRegionIndex = dataRegionIndex; + this.isBroadcastRegion = isBroadcastRegion; + } + + @Override + public Type getType() { + return Type.REGION_STARTED_MARKER; + } + + public int getDataRegionIndex() { + return dataRegionIndex; + } + + public boolean isBroadcastRegion() { + return isBroadcastRegion; + } + } + + /** Definition of {@link Type#REGION_FINISHED_MARKER}. */ + public static class RegionFinishedMarker extends BufferOrMarker { + + public RegionFinishedMarker(MapPartitionID mapPartitionID) { + super(mapPartitionID); + } + + @Override + public Type getType() { + return Type.REGION_FINISHED_MARKER; + } + } + + /** Definition of {@link Type#INPUT_FINISHED_MARKER}. */ + public static class InputFinishedMarker extends BufferOrMarker { + + /** Listener to be notified when all the data written is committed. */ + private final DataCommitListener commitListener; + + public InputFinishedMarker( + MapPartitionID mapPartitionID, DataCommitListener commitListener) { + super(mapPartitionID); + + CommonUtils.checkArgument(commitListener != null, "Must be not null."); + this.commitListener = commitListener; + } + + public DataCommitListener getCommitListener() { + return commitListener; + } + + @Override + public Type getType() { + return Type.INPUT_FINISHED_MARKER; + } + } + + /** Definition of {@link Type#ERROR_MARKER}. */ + public static class ErrorMarker extends BufferOrMarker { + + /** Failure encountered in the corresponding writing view.
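[Editor's note: the marker types above imply an ordering contract for each writer's stream: every data region is framed by REGION_STARTED_MARKER and REGION_FINISHED_MARKER, data buffers only appear inside a region, and INPUT_FINISHED_MARKER ends the stream. A hedged, illustrative state machine (not part of the shuffle code) that validates this ordering:]

// Sketch only: validates the marker ordering implied by BufferOrMarker.Type.
enum MarkerEvent { REGION_STARTED, DATA_BUFFER, REGION_FINISHED, INPUT_FINISHED }

class WritingStreamValidator {
    private boolean inRegion;
    private boolean finished;

    void onEvent(MarkerEvent event) {
        if (finished) {
            throw new IllegalStateException("Stream already finished.");
        }
        switch (event) {
            case REGION_STARTED:
                if (inRegion) throw new IllegalStateException("Nested region.");
                inRegion = true;
                break;
            case DATA_BUFFER:
                if (!inRegion) throw new IllegalStateException("Data outside of a region.");
                break;
            case REGION_FINISHED:
                if (!inRegion) throw new IllegalStateException("No open region.");
                inRegion = false;
                break;
            case INPUT_FINISHED:
                if (inRegion) throw new IllegalStateException("Unfinished region.");
                finished = true;
                break;
        }
    }
}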
*/ + private final Throwable throwable; + + public ErrorMarker(MapPartitionID mapPartitionID, Throwable throwable) { + super(mapPartitionID); + + CommonUtils.checkArgument(throwable != null, "Must be not null."); + this.throwable = throwable; + } + + public Throwable getFailure() { + return throwable; + } + + @Override + public Type getType() { + return Type.ERROR_MARKER; + } + } +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/DataPartitionReadingTask.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/DataPartitionReadingTask.java new file mode 100644 index 00000000..1a0900dd --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/DataPartitionReadingTask.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.partition; + +import javax.annotation.Nullable; + +/** + * {@link DataPartitionReadingTask} encapsulates the logic of data partition reading and will be + * triggered when there is a data reading request. + */ +public interface DataPartitionReadingTask extends PartitionProcessingTask { + + /** Allocates resources for data reading; will be called on the first data reading request. */ + void allocateResources() throws Exception; + + /** Triggers running of this reading task, which will read data from the data partition. */ + void triggerReading(); + + /** Releases this {@link DataPartitionReadingTask}, releasing all allocated resources. */ + void release(@Nullable Throwable throwable) throws Exception; +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/DataPartitionWritingTask.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/DataPartitionWritingTask.java new file mode 100644 index 00000000..2ae6970b --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/DataPartitionWritingTask.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.partition; + +import javax.annotation.Nullable; + +/** + * {@link DataPartitionWritingTask} encapsulates the logic of data partition writing and will be + * triggered when there is a data writing request. + */ +public interface DataPartitionWritingTask extends PartitionProcessingTask { + + /** Allocates resources for data writing; will be called on the first data writing request. */ + void allocateResources() throws Exception; + + /** Triggers running of this writing task, which will write data to the data partition. */ + void triggerWriting(); + + /** Releases this {@link DataPartitionWritingTask}, releasing all allocated resources. */ + void release(@Nullable Throwable throwable) throws Exception; +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/HDDOnlyLocalFileMapPartitionFactory.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/HDDOnlyLocalFileMapPartitionFactory.java new file mode 100644 index 00000000..6f99c04c --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/HDDOnlyLocalFileMapPartitionFactory.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.exception.ConfigurationException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.config.StorageOptions; +import com.alibaba.flink.shuffle.core.storage.StorageMeta; + +/** + * A {@link LocalFileMapPartitionFactory} variant which only uses HDD to store data partition data.
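[Editor's note: the two task interfaces above share the same lazy lifecycle: resources are allocated on the first request, work is (re-)triggered as requests arrive, and release tears everything down. A hedged sketch of that lifecycle under assumed types (the Resource/executor wiring here is illustrative, not the patch's implementation):]

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Sketch only: lazy allocation on first request, idempotent trigger, explicit release.
class SketchWritingTask {
    private final ExecutorService executor = Executors.newSingleThreadExecutor();
    private AutoCloseable resources; // e.g. buffers, file channels

    synchronized void allocateResources() throws Exception {
        if (resources == null) {
            resources = () -> {}; // acquire the real resources here
        }
    }

    void triggerWriting() {
        executor.execute(this::processPendingBuffers);
    }

    private void processPendingBuffers() { /* drain queued buffers and write them */ }

    synchronized void release(Throwable cause) throws Exception {
        if (resources != null) {
            resources.close();
            resources = null;
        }
        executor.shutdownNow();
    }
}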
+ */ +public class HDDOnlyLocalFileMapPartitionFactory extends LocalFileMapPartitionFactory { + + @Override + public void initialize(Configuration configuration) { + super.initialize(configuration); + + if (hddStorageMetas.isEmpty()) { + throw new ConfigurationException( + String.format( + "No valid data dir of HDD storage type is configured for %s.", + StorageOptions.STORAGE_LOCAL_DATA_DIRS.key())); + } + } + + @Override + protected StorageMeta getNextDataStorageMeta() { + StorageMeta storageMeta = CommonUtils.checkNotNull(hddStorageMetas.poll()); + hddStorageMetas.add(storageMeta); + return storageMeta; + } +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartition.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartition.java new file mode 100644 index 00000000..0b72580d --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartition.java @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.config.ConfigOption; +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; +import com.alibaba.flink.shuffle.core.config.StorageOptions; +import com.alibaba.flink.shuffle.core.executor.SingleThreadExecutor; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.listener.BacklogListener; +import com.alibaba.flink.shuffle.core.listener.DataListener; +import com.alibaba.flink.shuffle.core.listener.DataRegionCreditListener; +import com.alibaba.flink.shuffle.core.listener.FailureListener; +import com.alibaba.flink.shuffle.core.storage.DataPartition; +import com.alibaba.flink.shuffle.core.storage.DataPartitionMeta; +import com.alibaba.flink.shuffle.core.storage.DataPartitionReader; +import com.alibaba.flink.shuffle.core.storage.DataPartitionWriter; +import com.alibaba.flink.shuffle.core.storage.PartitionedDataStore; +import com.alibaba.flink.shuffle.core.storage.StorageMeta; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; + +/** A {@link DataPartition} implementation which writes data to and reads data from local files.
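[Editor's note: getNextDataStorageMeta in the HDD-only factory above rotates a queue by polling the head and re-appending it, so repeated calls cycle through all configured directories. A minimal sketch of that poll-and-re-append round robin, with StorageMeta stubbed as String for self-containedness:]

import java.util.ArrayDeque;
import java.util.Queue;

// Sketch only: round-robin selection over configured storage directories.
class RoundRobin {
    private final Queue<String> paths = new ArrayDeque<>();

    RoundRobin(Iterable<String> storagePaths) {
        storagePaths.forEach(paths::add);
    }

    String next() {
        String path = paths.poll(); // take the head...
        if (path == null) {
            throw new IllegalStateException("No storage path configured.");
        }
        paths.add(path); // ...and rotate it to the tail
        return path;
    }
}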
*/ +public class LocalFileMapPartition extends BaseMapPartition { + + private static final Logger LOG = LoggerFactory.getLogger(LocalFileMapPartition.class); + + /** {@link DataPartitionMeta} of this data partition. */ + private final LocalFileMapPartitionMeta partitionMeta; + + /** Local file storing all data of this data partition. */ + private final LocalMapPartitionFile partitionFile; + + public LocalFileMapPartition( + StorageMeta storageMeta, + PartitionedDataStore dataStore, + JobID jobID, + DataSetID dataSetID, + MapPartitionID partitionID, + int numReducePartitions) { + super(dataStore, getSingleThreadExecutor(dataStore, storageMeta)); + + String storagePath = storageMeta.getStoragePath(); + File storageDir = new File(storagePath); + CommonUtils.checkArgument(storagePath.endsWith("/"), "Illegal storage path."); + CommonUtils.checkArgument(storageDir.exists(), "Storage path does not exist."); + CommonUtils.checkArgument(storageDir.isDirectory(), "Storage path is not a directory."); + + Configuration configuration = dataStore.getConfiguration(); + ConfigOption<Integer> configOption = StorageOptions.STORAGE_FILE_TOLERABLE_FAILURES; + int tolerableFailures = CommonUtils.checkNotNull(configuration.getInteger(configOption)); + + String fileName = CommonUtils.randomHexString(32); + LocalMapPartitionFileMeta fileMeta = + new LocalMapPartitionFileMeta( + storagePath + fileName, + numReducePartitions, + LocalMapPartitionFile.LATEST_STORAGE_VERSION); + this.partitionFile = new LocalMapPartitionFile(fileMeta, tolerableFailures, true); + this.partitionMeta = + new LocalFileMapPartitionMeta(jobID, dataSetID, partitionID, fileMeta, storageMeta); + } + + /** + * Used to construct data partition instances when adding a finished external data partition or + * recovering after failure.
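[Editor's note: the first constructor above validates the storage directory and derives the partition file's base path from a random hex name. A self-contained sketch of those two steps; randomHexString below is a stand-in for CommonUtils.randomHexString, which is not shown in this diff:]

import java.io.File;
import java.util.concurrent.ThreadLocalRandom;

// Sketch only: validate a storage dir and generate a random base file path in it.
class StoragePathSketch {
    static String newPartitionFileBasePath(String storagePath) {
        File storageDir = new File(storagePath);
        if (!storagePath.endsWith("/")) throw new IllegalArgumentException("Illegal storage path.");
        if (!storageDir.isDirectory()) throw new IllegalArgumentException("Not a directory.");
        return storagePath + randomHexString(32);
    }

    private static String randomHexString(int length) {
        StringBuilder builder = new StringBuilder(length);
        for (int i = 0; i < length; ++i) {
            builder.append(Character.forDigit(ThreadLocalRandom.current().nextInt(16), 16));
        }
        return builder.toString();
    }
}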
+ */ + public LocalFileMapPartition( + PartitionedDataStore dataStore, LocalFileMapPartitionMeta partitionMeta) { + super(dataStore, getSingleThreadExecutor(dataStore, partitionMeta.getStorageMeta())); + + this.partitionMeta = partitionMeta; + LocalMapPartitionFileMeta fileMeta = partitionMeta.getPartitionFileMeta(); + this.partitionFile = fileMeta.createPersistentFile(dataStore.getConfiguration()); + + if (!partitionFile.isConsumable()) { + partitionFile.setConsumable(false); + throw new ShuffleException("Partition data is not consumable."); + } + } + + private static SingleThreadExecutor getSingleThreadExecutor( + PartitionedDataStore dataStore, StorageMeta storageMeta) { + CommonUtils.checkArgument(dataStore != null, "Must be not null."); + CommonUtils.checkArgument(storageMeta != null, "Must be not null."); + + return dataStore.getExecutorPool(storageMeta).getSingleThreadExecutor(); + } + + @Override + public boolean isConsumable() { + return partitionFile.isConsumable(); + } + + @Override + protected DataPartitionReader getDataPartitionReader( + int startPartitionIndex, + int endPartitionIndex, + DataListener dataListener, + BacklogListener backlogListener, + FailureListener failureListener) { + // different file reader implementations may be needed for different storage versions + // and formats; backward compatibility must be kept when upgrading + int storageVersion = partitionFile.getFileMeta().getStorageVersion(); + if (storageVersion <= 1) { + boolean dataChecksumEnabled = + dataStore + .getConfiguration() + .getBoolean(StorageOptions.STORAGE_ENABLE_DATA_CHECKSUM); + LocalMapPartitionFileReader fileReader = + new LocalMapPartitionFileReader( + dataChecksumEnabled, + startPartitionIndex, + endPartitionIndex, + partitionFile); + return new LocalFileMapPartitionReader( + fileReader, dataListener, backlogListener, failureListener); + } + + throw new ShuffleException( + String.format( + "Illegal storage version, current: %d, supported: %d.", + storageVersion, partitionFile.getLatestStorageVersion())); + } + + @Override + protected DataPartitionWriter getDataPartitionWriter( + MapPartitionID mapPartitionID, + DataRegionCreditListener dataRegionCreditListener, + FailureListener failureListener) { + boolean dataChecksumEnabled = + dataStore + .getConfiguration() + .getBoolean(StorageOptions.STORAGE_ENABLE_DATA_CHECKSUM); + return new LocalFileMapPartitionWriter( + dataChecksumEnabled, + mapPartitionID, + this, + dataRegionCreditListener, + failureListener, + partitionFile); + } + + @Override + protected void releaseInternal(Throwable releaseCause) throws Exception { + Throwable exception = null; + + try { + super.releaseInternal(releaseCause); + } catch (Throwable throwable) { + exception = throwable; + LOG.error("Fatal: failed to release base map partition.", throwable); + } + + try { + partitionFile.deleteFile(); + } catch (Throwable throwable) { + // keep the first encountered failure + exception = exception != null ? exception : throwable; + LOG.error("Fatal: failed to delete the partition file.", throwable); + } + + if (exception != null) { + ExceptionUtils.rethrowException(exception); + } + } + + @Override + public LocalFileMapPartitionMeta getPartitionMeta() { + return partitionMeta; + } +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartitionFactory.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartitionFactory.java new file mode 100644 index 00000000..331e16d5 --- /dev/null +++
b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartitionFactory.java @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.exception.ConfigurationException; +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.config.StorageOptions; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.storage.DataPartition; +import com.alibaba.flink.shuffle.core.storage.DataPartitionFactory; +import com.alibaba.flink.shuffle.core.storage.DataPartitionMeta; +import com.alibaba.flink.shuffle.core.storage.PartitionedDataStore; +import com.alibaba.flink.shuffle.core.storage.StorageMeta; +import com.alibaba.flink.shuffle.core.storage.StorageType; +import com.alibaba.flink.shuffle.storage.utils.StorageConfigParseUtils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.concurrent.NotThreadSafe; + +import java.io.DataInput; +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.List; +import java.util.Queue; +import java.util.stream.Collectors; + +/** {@link DataPartitionFactory} of {@link LocalFileMapPartition}. */ +@NotThreadSafe +public class LocalFileMapPartitionFactory implements DataPartitionFactory { + + private static final Logger LOG = LoggerFactory.getLogger(LocalFileMapPartitionFactory.class); + + protected final Queue<StorageMeta> ssdStorageMetas = new ArrayDeque<>(); + + protected final Queue<StorageMeta> hddStorageMetas = new ArrayDeque<>(); + + protected StorageType preferredStorageType; + + @Override + public void initialize(Configuration configuration) { + String directories = configuration.getString(StorageOptions.STORAGE_LOCAL_DATA_DIRS); + if (directories == null) { + throw new ConfigurationException( + StorageOptions.STORAGE_LOCAL_DATA_DIRS.key() + " is not configured."); + } + + String diskTypeString = configuration.getString(StorageOptions.STORAGE_PREFERRED_TYPE); + try { + preferredStorageType = + StorageType.valueOf(CommonUtils.checkNotNull(diskTypeString).trim()); + } catch (Exception exception) { + throw new ConfigurationException( + String.format( + "Illegal configured value %s for %s.
Must be SSD, HDD or UNKNOWN.", + diskTypeString, StorageOptions.STORAGE_PREFERRED_TYPE.key())); + } + + StorageConfigParseUtils.ParsedPathLists parsedPathLists = + StorageConfigParseUtils.parseStoragePaths(directories); + if (parsedPathLists.getAllPaths().isEmpty()) { + throw new ConfigurationException( + String.format( + "No valid data dir is configured for %s.", + StorageOptions.STORAGE_LOCAL_DATA_DIRS.key())); + } + + this.ssdStorageMetas.addAll( + parsedPathLists.getSsdPaths().stream() + .map(storagePath -> new StorageMeta(storagePath, StorageType.SSD)) + .collect(Collectors.toList())); + this.hddStorageMetas.addAll( + parsedPathLists.getHddPaths().stream() + .map(storagePath -> new StorageMeta(storagePath, StorageType.HDD)) + .collect(Collectors.toList())); + + if (ssdStorageMetas.isEmpty() && preferredStorageType == StorageType.SSD) { + LOG.warn( + "No valid data dir of SSD type is configured for {}.", + StorageOptions.STORAGE_LOCAL_DATA_DIRS.key()); + } + + if (hddStorageMetas.isEmpty() && preferredStorageType == StorageType.HDD) { + LOG.warn( + "No valid data dir of HDD type is configured for {}.", + StorageOptions.STORAGE_LOCAL_DATA_DIRS.key()); + } + } + + /** + * Returns the next data path to use for data storage. It serves data paths in a simple round- + * robin way. More complicated strategies can be implemented in the future. + */ + protected StorageMeta getNextDataStorageMeta() { + switch (preferredStorageType) { + case SSD: + { + StorageMeta storageMeta = getStorageMeta(ssdStorageMetas); + if (storageMeta == null) { + storageMeta = getStorageMeta(hddStorageMetas); + } + return CommonUtils.checkNotNull(storageMeta); + } + case HDD: + { + StorageMeta storageMeta = getStorageMeta(hddStorageMetas); + if (storageMeta == null) { + storageMeta = getStorageMeta(ssdStorageMetas); + } + return CommonUtils.checkNotNull(storageMeta); + } + default: + throw new ShuffleException("Illegal preferred storage type."); + } + } + + private StorageMeta getStorageMeta(Queue<StorageMeta> storageMetas) { + StorageMeta storageMeta = storageMetas.poll(); + if (storageMeta != null) { + storageMetas.add(storageMeta); + } + return storageMeta; + } + + @Override + public LocalFileMapPartition createDataPartition( + PartitionedDataStore dataStore, + JobID jobID, + DataSetID dataSetID, + DataPartitionID dataPartitionID, + int numReducePartitions) { + CommonUtils.checkArgument(dataPartitionID != null, "Must be not null."); + CommonUtils.checkArgument(dataPartitionID instanceof MapPartitionID, "Illegal type."); + + MapPartitionID mapPartitionID = (MapPartitionID) dataPartitionID; + return new LocalFileMapPartition( + getNextDataStorageMeta(), + dataStore, + jobID, + dataSetID, + mapPartitionID, + numReducePartitions); + } + + @Override + public LocalFileMapPartition createDataPartition( + PartitionedDataStore dataStore, DataPartitionMeta partitionMeta) { + CommonUtils.checkArgument( + partitionMeta instanceof LocalFileMapPartitionMeta, "Illegal data partition type."); + + return new LocalFileMapPartition(dataStore, (LocalFileMapPartitionMeta) partitionMeta); + } + + @Override + public LocalFileMapPartitionMeta recoverDataPartitionMeta(DataInput dataInput) + throws IOException { + return LocalFileMapPartitionMeta.readFrom(dataInput); + } + + /** At present, only MAP_PARTITION is supported.
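[Editor's note: getNextDataStorageMeta above combines the round robin with a preference-and-fallback rule: the preferred tier's queue is tried first and the other tier is used only when the preferred one is empty, so a misconfigured preferred tier degrades gracefully instead of failing the write. A generic sketch of that selection logic, with the queues passed in rather than taken from configuration:]

import java.util.Queue;

// Sketch only: rotate the preferred queue, falling back to the other queue when empty.
class PreferredStorageSelector<T> {
    private final Queue<T> preferred;
    private final Queue<T> fallback;

    PreferredStorageSelector(Queue<T> preferred, Queue<T> fallback) {
        this.preferred = preferred;
        this.fallback = fallback;
    }

    T next() {
        T meta = rotate(preferred);
        if (meta == null) {
            meta = rotate(fallback);
        }
        if (meta == null) {
            throw new IllegalStateException("No valid storage configured.");
        }
        return meta;
    }

    private T rotate(Queue<T> queue) {
        T head = queue.poll();
        if (head != null) {
            queue.add(head); // round robin within one tier
        }
        return head;
    }
}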
*/ + @Override + public DataPartition.DataPartitionType getDataPartitionType() { + return DataPartition.DataPartitionType.MAP_PARTITION; + } + + // --------------------------------------------------------------------------------------------- + // For test + // --------------------------------------------------------------------------------------------- + + StorageType getPreferredStorageType() { + return preferredStorageType; + } + + List<StorageMeta> getSsdStorageMetas() { + return new ArrayList<>(ssdStorageMetas); + } + + List<StorageMeta> getHddStorageMetas() { + return new ArrayList<>(hddStorageMetas); + } +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartitionMeta.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartitionMeta.java new file mode 100644 index 00000000..db127428 --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartitionMeta.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.storage.DataPartitionMeta; +import com.alibaba.flink.shuffle.core.storage.MapPartitionMeta; +import com.alibaba.flink.shuffle.core.storage.StorageMeta; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.Objects; + +/** {@link DataPartitionMeta} of {@link LocalFileMapPartition}. */ +public class LocalFileMapPartitionMeta extends MapPartitionMeta { + + private static final long serialVersionUID = -7947536484763210922L; + + /** Meta of the corresponding {@link LocalMapPartitionFile}. */ + private final LocalMapPartitionFileMeta fileMeta; + + public LocalFileMapPartitionMeta( + JobID jobID, + DataSetID dataSetID, + MapPartitionID partitionID, + LocalMapPartitionFileMeta fileMeta, + StorageMeta storageMeta) { + super(jobID, dataSetID, partitionID, storageMeta); + + CommonUtils.checkArgument(fileMeta != null, "Must be not null."); + this.fileMeta = fileMeta; + } + + /** + * Reconstructs the {@link DataPartitionMeta} from {@link DataInput} when recovering from + * failure.
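[Editor's note: recovery relies on the writeTo/readFrom pair reading fields back in exactly the order they were written. A self-contained sketch of that round trip; the record below is illustrative, but its field order mirrors LocalMapPartitionFileMeta.writeTo as shown later in this patch (storage version first, so readers can dispatch on it):]

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

// Sketch only: symmetric serialization used for meta recovery.
class MetaRoundTrip {
    static byte[] write(String filePath, int numReducePartitions, int storageVersion)
            throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        DataOutputStream out = new DataOutputStream(bytes);
        out.writeInt(storageVersion); // version first so readers can dispatch on it
        out.writeInt(numReducePartitions);
        out.writeUTF(filePath);
        return bytes.toByteArray();
    }

    static String read(byte[] serialized) throws IOException {
        DataInputStream in = new DataInputStream(new ByteArrayInputStream(serialized));
        int storageVersion = in.readInt(); // must match the write order exactly
        int numReducePartitions = in.readInt();
        String filePath = in.readUTF();
        return filePath + "#" + numReducePartitions + "@v" + storageVersion;
    }
}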
+ */ + public static LocalFileMapPartitionMeta readFrom(DataInput dataInput) throws IOException { + JobID jobID = JobID.readFrom(dataInput); + DataSetID dataSetID = DataSetID.readFrom(dataInput); + MapPartitionID partitionID = MapPartitionID.readFrom(dataInput); + + StorageMeta storageMeta = StorageMeta.readFrom(dataInput); + LocalMapPartitionFileMeta fileMeta = LocalMapPartitionFileMeta.readFrom(dataInput); + + return new LocalFileMapPartitionMeta(jobID, dataSetID, partitionID, fileMeta, storageMeta); + } + + @Override + public String getPartitionFactoryClassName() { + return LocalFileMapPartitionFactory.class.getName(); + } + + public LocalMapPartitionFileMeta getPartitionFileMeta() { + return fileMeta; + } + + @Override + public void writeTo(DataOutput dataOutput) throws Exception { + jobID.writeTo(dataOutput); + dataSetID.writeTo(dataOutput); + partitionID.writeTo(dataOutput); + storageMeta.writeTo(dataOutput); + fileMeta.writeTo(dataOutput); + } + + @Override + public boolean equals(Object that) { + if (this == that) { + return true; + } + + if (!(that instanceof LocalFileMapPartitionMeta)) { + return false; + } + + LocalFileMapPartitionMeta partitionMeta = (LocalFileMapPartitionMeta) that; + return Objects.equals(jobID, partitionMeta.jobID) + && Objects.equals(dataSetID, partitionMeta.dataSetID) + && Objects.equals(partitionID, partitionMeta.partitionID) + && Objects.equals(storageMeta, partitionMeta.storageMeta) + && Objects.equals(fileMeta, partitionMeta.fileMeta); + } + + @Override + public int hashCode() { + return Objects.hash(jobID, dataSetID, partitionID, storageMeta, fileMeta); + } + + @Override + public String toString() { + return "LocalFileMapPartitionMeta{" + + "JobID=" + + jobID + + ", DataSetID=" + + dataSetID + + ", PartitionID=" + + partitionID + + ", StorageMeta=" + + storageMeta + + ", FileMeta=" + + fileMeta + + '}'; + } +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartitionReader.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartitionReader.java new file mode 100644 index 00000000..d73e6607 --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartitionReader.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; +import com.alibaba.flink.shuffle.core.listener.BacklogListener; +import com.alibaba.flink.shuffle.core.listener.DataListener; +import com.alibaba.flink.shuffle.core.listener.FailureListener; +import com.alibaba.flink.shuffle.core.memory.Buffer; +import com.alibaba.flink.shuffle.core.memory.BufferRecycler; +import com.alibaba.flink.shuffle.core.storage.BufferQueue; +import com.alibaba.flink.shuffle.core.storage.DataPartitionReader; +import com.alibaba.flink.shuffle.core.utils.BufferUtils; +import com.alibaba.flink.shuffle.core.utils.ListenerUtils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; + +/** {@link DataPartitionReader} for {@link LocalFileMapPartition}. */ +public class LocalFileMapPartitionReader extends BaseDataPartitionReader { + + private static final Logger LOG = LoggerFactory.getLogger(LocalFileMapPartitionReader.class); + + /** File reader responsible for reading data from the partition file. */ + private final LocalMapPartitionFileReader fileReader; + + /** Whether this {@link DataPartitionReader} has been opened or not. */ + private boolean isOpened; + + public LocalFileMapPartitionReader( + LocalMapPartitionFileReader fileReader, + DataListener dataListener, + BacklogListener backlogListener, + FailureListener failureListener) { + super(dataListener, backlogListener, failureListener); + + CommonUtils.checkArgument(fileReader != null, "Must be not null."); + this.fileReader = fileReader; + } + + @Override + public void open() throws Exception { + CommonUtils.checkState(!isOpened, "Partition reader has been opened."); + + isOpened = true; + fileReader.open(); + } + + /** + * Reads data through the corresponding {@link LocalMapPartitionFileReader} and returns true if + * there is remaining data to be read with this data partition reader.
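[Editor's note: the readData implementation that follows fills free buffers from the file reader until either the data or the buffers run out, recycling the in-flight buffer if a read fails. A simplified, self-contained sketch of that loop contract; FileSource and the Consumer stand in for the real file reader and downstream notification:]

import java.nio.ByteBuffer;
import java.util.Queue;
import java.util.function.Consumer;

interface FileSource {
    boolean fill(ByteBuffer buffer) throws Exception; // false when the file is drained

    boolean hasRemaining() throws Exception;
}

// Sketch only: read until out of data or out of free buffers; recycle on failure.
class ReadLoopSketch {
    static boolean readData(Queue<ByteBuffer> freeBuffers, FileSource source,
            Consumer<ByteBuffer> downstream) throws Exception {
        boolean continueReading = source.hasRemaining();
        while (continueReading) {
            ByteBuffer buffer = freeBuffers.poll();
            if (buffer == null) {
                break; // no credit left, wait for more free buffers
            }
            try {
                continueReading = source.fill(buffer);
            } catch (Throwable throwable) {
                freeBuffers.add(buffer); // recycle the buffer before failing
                throw throwable;
            }
            downstream.accept(buffer);
        }
        return source.hasRemaining();
    }
}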
+ */ + @Override + public boolean readData(BufferQueue buffers, BufferRecycler recycler) throws Exception { + CommonUtils.checkArgument(buffers != null, "Must be not null."); + CommonUtils.checkArgument(recycler != null, "Must be not null."); + + CommonUtils.checkState(isOpened, "Partition reader is not opened."); + + boolean hasRemaining = fileReader.hasRemaining(); + boolean continueReading = hasRemaining; + int numDataBuffers = 0; + while (continueReading) { + ByteBuffer buffer = buffers.poll(); + if (buffer == null) { + break; + } + + try { + continueReading = fileReader.readBuffer(buffer); + } catch (Throwable throwable) { + BufferUtils.recycleBuffer(buffer, recycler); + throw throwable; + } + + hasRemaining = fileReader.hasRemaining(); + addBuffer(new Buffer(buffer, recycler, buffer.remaining()), hasRemaining); + ++numDataBuffers; + } + + if (numDataBuffers > 0) { + notifyBacklog(numDataBuffers); + } + + if (!hasRemaining) { + closeReader(); + } + + return hasRemaining; + } + + @Override + public long getPriority() { + CommonUtils.checkState(isOpened, "Partition reader is not opened."); + + // a smaller file offset means a higher reading priority: readers consume the partition + // file in file offset order, which reduces random IO and leads to more sequential + // reading, thus improving IO performance + return fileReader.geConsumingOffset(); + } + + @Override + public boolean isOpened() { + return isOpened; + } + + @Override + public void release(Throwable releaseCause) throws Exception { + Throwable exception = null; + + try { + super.release(releaseCause); + } catch (Throwable throwable) { + exception = throwable; + LOG.error("Failed to release base partition reader.", throwable); + } + + try { + fileReader.close(); + } catch (Throwable throwable) { + exception = exception != null ? exception : throwable; + LOG.error("Failed to release file reader.", throwable); + } + + if (exception != null) { + ExceptionUtils.rethrowException(exception); + } + } + + @Override + protected void closeReader() throws Exception { + try { + fileReader.finishReading(); + + // mark the reader as finished after closing succeeded + super.closeReader(); + } catch (Throwable throwable) { + ListenerUtils.notifyFailure(failureListener, throwable); + ExceptionUtils.rethrowException(throwable); + } + } +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartitionWriter.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartitionWriter.java new file mode 100644 index 00000000..7de2b772 --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartitionWriter.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
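[Editor's note: getPriority above exposes the reader's current file offset so that a scheduler can favor sequential disk access. A hedged sketch of how such a priority can be consumed; the OffsetReader interface is an assumption for illustration, not the patch's API:]

import java.util.Comparator;
import java.util.PriorityQueue;

interface OffsetReader {
    long currentFileOffset();

    boolean readSome() throws Exception; // returns true while data remains
}

// Sketch only: always serve the reader with the smallest file offset next, which
// approximates one sequential pass over the partition file.
class SequentialScheduler {
    private final PriorityQueue<OffsetReader> readers =
            new PriorityQueue<>(Comparator.comparingLong(OffsetReader::currentFileOffset));

    void add(OffsetReader reader) {
        readers.add(reader);
    }

    void runOnce() throws Exception {
        OffsetReader reader = readers.poll(); // smallest offset first
        if (reader != null && reader.readSome()) {
            readers.add(reader); // re-queue at its new offset
        }
    }
}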
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.listener.DataRegionCreditListener; +import com.alibaba.flink.shuffle.core.listener.FailureListener; +import com.alibaba.flink.shuffle.core.storage.DataPartitionWriter; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** {@link DataPartitionWriter} for {@link LocalFileMapPartition}. */ +public class LocalFileMapPartitionWriter extends BaseMapPartitionWriter { + + private static final Logger LOG = LoggerFactory.getLogger(LocalFileMapPartitionWriter.class); + + /** File writer used to write data to the local file. */ + private final LocalMapPartitionFileWriter fileWriter; + + public LocalFileMapPartitionWriter( + boolean dataChecksumEnabled, + MapPartitionID mapPartitionID, + BaseMapPartition dataPartition, + DataRegionCreditListener dataRegionCreditListener, + FailureListener failureListener, + LocalMapPartitionFile partitionFile) { + super(mapPartitionID, dataPartition, dataRegionCreditListener, failureListener); + + this.fileWriter = + new LocalMapPartitionFileWriter( + partitionFile, + dataPartition.getPartitionWritingTask().minBuffersToWrite, + dataChecksumEnabled); + } + + @Override + protected void processRegionStartedMarker(BufferOrMarker.RegionStartedMarker marker) + throws Exception { + super.processRegionStartedMarker(marker); + + fileWriter.startRegion(marker.isBroadcastRegion()); + } + + @Override + protected void processDataBuffer(BufferOrMarker.DataBuffer buffer) throws Exception { + if (!fileWriter.isOpened()) { + fileWriter.open(); + } + + // the file writer is responsible for releasing the target buffer + fileWriter.writeBuffer(buffer); + } + + @Override + protected void processRegionFinishedMarker(BufferOrMarker.RegionFinishedMarker marker) + throws Exception { + super.processRegionFinishedMarker(marker); + + fileWriter.finishRegion(); + } + + @Override + protected void processInputFinishedMarker(BufferOrMarker.InputFinishedMarker marker) + throws Exception { + fileWriter.finishWriting(); + + super.processInputFinishedMarker(marker); + } + + @Override + public void release(Throwable throwable) throws Exception { + Throwable error = null; + + try { + super.release(throwable); + } catch (Throwable exception) { + error = exception; + LOG.error("Failed to release base partition writer.", exception); + } + + try { + fileWriter.close(); + } catch (Throwable exception) { + error = error == null ? exception : error; + LOG.error("Failed to release file writer.", exception); + } + + if (error != null) { + ExceptionUtils.rethrowException(error); + } + } +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalMapPartitionFile.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalMapPartitionFile.java new file mode 100644 index 00000000..ac6d0bfb --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalMapPartitionFile.java @@ -0,0 +1,378 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; +import com.alibaba.flink.shuffle.storage.exception.FileCorruptedException; +import com.alibaba.flink.shuffle.storage.utils.IOUtils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; +import javax.annotation.concurrent.NotThreadSafe; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.FileChannel; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.HashSet; +import java.util.Set; +import java.util.zip.CRC32; +import java.util.zip.Checksum; + +/** + * Local file based {@link PersistentFile} implementation which is used to store data partition data + * to and read data partition data from local files. It consists of two files: the data file is used + * to store data and the index file is used to store index information of the data file for reading. + * + *
<p>The data file contains one or multiple sequential data regions. In each data region, data + * of different reduce partitions will be stored in reduce partition index order. In the index + * file, for each data region in the data file, there will be an index region. Each index region + * contains the same number of index entries as the number of reduce partitions, and each index + * entry points to the data of the corresponding reduce partition in the data file. + */ +@NotThreadSafe +public class LocalMapPartitionFile implements PersistentFile { + + private static final Logger LOG = LoggerFactory.getLogger(LocalMapPartitionFile.class); + + /** + * Latest storage version used. One needs to increase this version after changing the storage + * format. + */ + public static final int LATEST_STORAGE_VERSION = 1; + + /** + * Size of each index entry in the index file: 8 bytes for the file offset, 8 bytes for the + * number of bytes and 4 bytes for the magic number. + */ + public static final int INDEX_ENTRY_SIZE = 8 + 8 + 4; + + /** + * Size of the index data checksum: 8 bytes for the number of data regions and 8 bytes for the + * magic number. The index data checksum sits at the tail of the index file. + */ + public static final int INDEX_DATA_CHECKSUM_SIZE = 8 + 8; + + /** Name suffix of the data file. */ + public static final String DATA_FILE_SUFFIX = ".data"; + + /** Name suffix of the index file. */ + public static final String INDEX_FILE_SUFFIX = ".index"; + + /** Name suffix of the in-progress partial data file. */ + public static final String PARTIAL_DATA_FILE_SUFFIX = ".data.partial"; + + /** Name suffix of the in-progress partial index file. */ + public static final String PARTIAL_INDEX_FILE_SUFFIX = ".index.partial"; + + /** Meta information of this data partition file. */ + private final LocalMapPartitionFileMeta fileMeta; + + /** + * Maximum number of {@link IOException}s that can be tolerated before marking this partition + * file as corrupted. + */ + private final int tolerableFailures; + + /** All data readers that are currently open and reading data from the file. */ + private final Set<Object> readers = new HashSet<>(); + + /** Opened data file channel shared by all data readers for data reading. */ + @Nullable private FileChannel dataReadingChannel; + + /** Opened index file channel shared by all data readers for data reading. */ + @Nullable private FileChannel indexReadingChannel; + + /** Counter recording all data reading failures. */ + private int failureCounter; + + /** Whether this persistent file is consumable and ready for data reading. */ + private volatile boolean isConsumable; + + /** Whether the checksum of the corresponding index data is verified or not. */ + private volatile boolean indexDataChecksumVerified; + + public LocalMapPartitionFile( + LocalMapPartitionFileMeta fileMeta, + int tolerableFailures, + boolean indexDataChecksumVerified) { + CommonUtils.checkArgument(fileMeta != null, "Must be not null."); + CommonUtils.checkArgument(tolerableFailures >= 0, "Must be non-negative."); + + this.fileMeta = fileMeta; + this.tolerableFailures = tolerableFailures; + this.indexDataChecksumVerified = indexDataChecksumVerified; + } + + @Override + public int getLatestStorageVersion() { + return LATEST_STORAGE_VERSION; + } + + @Override + public boolean isConsumable() { + return isConsumable + && Files.isReadable(fileMeta.getDataFilePath()) + && Files.isReadable(fileMeta.getIndexFilePath()); + } + + @Override + public LocalMapPartitionFileMeta getFileMeta() { + return fileMeta; + } + + /** + * Opens this file for reading.
This method maintains the reader set and will only open the new + * file channels when the first reader is registered. It guarantees that all opened resources + * will be released if any exception occurs. + */ + public void openFile(Object reader) throws Exception { + CommonUtils.checkArgument(reader != null, "Must be not null."); + CommonUtils.checkState(fileMeta.getStorageVersion() <= 1, "Illegal storage version."); + + if (!readers.isEmpty()) { + readers.add(reader); + return; + } + + if (!indexDataChecksumVerified) { + verifyIndexDataChecksum(); + indexDataChecksumVerified = true; + } + + try { + Path dataFilePath = fileMeta.getDataFilePath(); + dataReadingChannel = IOUtils.openReadableFileChannel(dataFilePath); + Path indexFilePath = fileMeta.getIndexFilePath(); + indexReadingChannel = IOUtils.openReadableFileChannel(indexFilePath); + + readers.add(reader); + } catch (Throwable throwable) { + CommonUtils.runQuietly(this::closeFileChannels); + LOG.error("Failed to open the partition for reading: {}.", fileMeta.getFilePath()); + throw throwable; + } + } + + /** + * Marks the given file reader as closed. This method maintains the reader set and will only + * close the opened file channels when all file readers have been closed. + */ + public void closeFile(Object reader) throws Exception { + CommonUtils.checkArgument(reader != null, "Must be not null."); + + readers.remove(reader); + if (readers.isEmpty()) { + closeFileChannels(); + } + } + + @Override + public void deleteFile() throws Exception { + LOG.info("Deleting the partition file: {}.", fileMeta.getFilePath()); + Throwable exception = null; + + try { + // close file channel first if opened + closeFileChannels(); + } catch (Throwable throwable) { + exception = throwable; + } + + Path filePath = fileMeta.getDataFilePath(); + try { + CommonUtils.deleteFileWithRetry(filePath); + } catch (Throwable throwable) { + exception = exception != null ? exception : throwable; + LOG.error("Failed to delete data file: {}.", filePath, throwable); + } + + filePath = fileMeta.getIndexFilePath(); + try { + CommonUtils.deleteFileWithRetry(filePath); + } catch (Throwable throwable) { + exception = exception != null ? exception : throwable; + LOG.error("Failed to delete index file: {}.", filePath, throwable); + } + + filePath = fileMeta.getPartialDataFilePath(); + try { + CommonUtils.deleteFileWithRetry(filePath); + } catch (Throwable throwable) { + exception = exception != null ? exception : throwable; + LOG.error("Failed to delete data file: {}.", filePath, throwable); + } + + filePath = fileMeta.getPartialIndexFilePath(); + try { + CommonUtils.deleteFileWithRetry(filePath); + } catch (Throwable throwable) { + exception = exception != null ? 
exception : throwable; + LOG.error("Failed to delete index file: {}.", filePath, throwable); + } + + if (exception != null) { + ExceptionUtils.rethrowException(exception); + } + } + + @Override + public void onError(Throwable throwable) { + if (throwable instanceof ClosedChannelException) { + return; + } + + if (throwable instanceof FileCorruptedException + || (throwable instanceof IOException && ++failureCounter > tolerableFailures)) { + // to trigger partition not found exception + setConsumable(false); + LOG.error("Corrupted partition file: {}.", fileMeta.getFilePath(), throwable); + } + } + + @Override + public void setConsumable(boolean isConsumable) { + this.isConsumable = isConsumable; + + // delete the physical files eagerly + if (!isConsumable) { + CommonUtils.runQuietly(this::closeFileChannels); + CommonUtils.runQuietly(this::deleteFile); + } + } + + @Nullable + public FileChannel getDataReadingChannel() { + return dataReadingChannel; + } + + @Nullable + public FileChannel getIndexReadingChannel() { + return indexReadingChannel; + } + + /** Returns the total size of all index entries for a data region. */ + public long getIndexRegionSize() { + return fileMeta.getNumReducePartitions() * (long) INDEX_ENTRY_SIZE; + } + + /** Returns the offset in the index file of the target index entry. */ + public long getIndexEntryOffset(int regionIndex, int targetPartitionIndex) { + return regionIndex * getIndexRegionSize() + targetPartitionIndex * (long) INDEX_ENTRY_SIZE; + } + + public static Path getDataFilePath(String baseFilePath) { + CommonUtils.checkArgument(baseFilePath != null, "Must be not null."); + + return new File(baseFilePath + DATA_FILE_SUFFIX).toPath(); + } + + public static Path getIndexFilePath(String baseFilePath) { + CommonUtils.checkArgument(baseFilePath != null, "Must be not null."); + + return new File(baseFilePath + INDEX_FILE_SUFFIX).toPath(); + } + + public static Path getPartialDataFilePath(String baseFilePath) { + CommonUtils.checkArgument(baseFilePath != null, "Must be not null."); + + return new File(baseFilePath + PARTIAL_DATA_FILE_SUFFIX).toPath(); + } + + public static Path getPartialIndexFilePath(String baseFilePath) { + CommonUtils.checkArgument(baseFilePath != null, "Must be not null."); + + return new File(baseFilePath + PARTIAL_INDEX_FILE_SUFFIX).toPath(); + } + + /** Closes the opened data file channel and index file channel. */ + private void closeFileChannels() throws Exception { + Throwable exception = null; + try { + CommonUtils.closeWithRetry(dataReadingChannel); + dataReadingChannel = null; + } catch (Throwable throwable) { + exception = throwable; + LOG.error( + "Failed to close data file channel: {}.", + fileMeta.getDataFilePath(), + throwable); + } + + try { + CommonUtils.closeWithRetry(indexReadingChannel); + indexReadingChannel = null; + } catch (Throwable throwable) { + exception = exception != null ? 
exception : throwable; + LOG.error( + "Failed to close index file channel: {}.", + fileMeta.getIndexFilePath(), + throwable); + } + + if (exception != null) { + ExceptionUtils.rethrowException(exception); + } + } + + private void verifyIndexDataChecksum() throws Exception { + FileChannel fileChannel = null; + try { + fileChannel = IOUtils.openReadableFileChannel(fileMeta.getIndexFilePath()); + int numPartitions = fileMeta.getNumReducePartitions(); + int indexRegionSize = IOUtils.calculateIndexRegionSize(numPartitions); + + long remainingDataSize = fileChannel.size() - INDEX_DATA_CHECKSUM_SIZE; + if (remainingDataSize % indexRegionSize != 0) { + throw new FileCorruptedException(); + } + + Checksum checksum = new CRC32(); + int numRegions = CommonUtils.checkedDownCast(remainingDataSize / indexRegionSize); + ByteBuffer indexBuffer = IOUtils.allocateIndexBuffer(numPartitions); + + while (remainingDataSize > 0) { + long length = Math.min(remainingDataSize, indexBuffer.capacity()); + IOUtils.readBuffer(fileChannel, indexBuffer, CommonUtils.checkedDownCast(length)); + remainingDataSize -= length; + while (indexBuffer.hasRemaining()) { + checksum.update(indexBuffer.get()); + } + } + + IOUtils.readBuffer(fileChannel, indexBuffer, INDEX_DATA_CHECKSUM_SIZE); + if (numRegions != indexBuffer.getLong() + || checksum.getValue() != indexBuffer.getLong()) { + throw new FileCorruptedException(); + } + } catch (Throwable throwable) { + setConsumable(false); + LOG.error("Failed to verify index data checksum, releasing partition data.", throwable); + throw throwable; + } finally { + CommonUtils.closeWithRetry(fileChannel); + } + } +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalMapPartitionFileMeta.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalMapPartitionFileMeta.java new file mode 100644 index 00000000..13bf4c01 --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalMapPartitionFileMeta.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.config.ConfigOption; +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.config.StorageOptions; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.nio.file.Path; +import java.util.Objects; + +/** {@link PersistentFileMeta} of {@link LocalMapPartitionFile}. 
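[Editor's note: the index-file format implemented above (20-byte entries of 8-byte data offset + 8-byte length + 4-byte magic number, one index region per data region, and a 16-byte numRegions/CRC32 tail, as checked by verifyIndexDataChecksum) can be illustrated with a small, self-contained sketch. MAGIC_NUMBER below is a placeholder; the real constant is not shown in this diff:]

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.zip.CRC32;

// Sketch only: index entry layout, entry offset arithmetic and the checksum tail.
class IndexFileFormatSketch {
    static final int INDEX_ENTRY_SIZE = 8 + 8 + 4;
    static final int MAGIC_NUMBER = 0xCAFEBABE; // placeholder value

    static ByteBuffer writeEntry(long dataOffset, long numBytes) {
        ByteBuffer entry = ByteBuffer.allocate(INDEX_ENTRY_SIZE);
        entry.putLong(dataOffset).putLong(numBytes).putInt(MAGIC_NUMBER);
        entry.flip();
        return entry;
    }

    // index regions are laid out back to back, each holding one entry per reduce partition
    static long entryOffset(int regionIndex, int partitionIndex, int numReducePartitions) {
        return regionIndex * (long) numReducePartitions * INDEX_ENTRY_SIZE
                + partitionIndex * (long) INDEX_ENTRY_SIZE;
    }

    static ByteBuffer appendChecksumTail(byte[] indexData, long numRegions) {
        CRC32 checksum = new CRC32();
        checksum.update(indexData, 0, indexData.length);
        ByteBuffer buffer = ByteBuffer.allocate(indexData.length + 16);
        buffer.put(indexData).putLong(numRegions).putLong(checksum.getValue());
        buffer.flip();
        return buffer;
    }

    static void verify(ByteBuffer fileContent, long expectedRegions) throws IOException {
        byte[] indexData = new byte[fileContent.remaining() - 16];
        fileContent.get(indexData);
        CRC32 checksum = new CRC32();
        checksum.update(indexData, 0, indexData.length);
        if (fileContent.getLong() != expectedRegions
                || fileContent.getLong() != checksum.getValue()) {
            throw new IOException("Corrupted index file.");
        }
    }
}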
*/ +public class LocalMapPartitionFileMeta implements PersistentFileMeta { + + private static final long serialVersionUID = -6682157834905760822L; + + /** File name (without suffix) of both the data file and index file. */ + private final String filePath; + + /** Number of reduce partitions in the corresponding partition file. */ + private final int numReducePartitions; + + /** Path of the data file. */ + private final Path dataFilePath; + + /** Path of the index file. */ + private final Path indexFilePath; + + /** Path of the in-progress partial data file. */ + private final Path partialDataFilePath; + + /** Path of the in-progress partial index file. */ + private final Path partialIndexFilePath; + + /** + * Storage version of the target {@link PersistentFile}. Different versions may need different + * processing logic. + */ + private final int storageVersion; + + public LocalMapPartitionFileMeta(String filePath, int numReducePartitions, int storageVersion) { + CommonUtils.checkArgument(filePath != null, "Must be not null."); + CommonUtils.checkArgument(numReducePartitions > 0, "Must be positive."); + + this.filePath = filePath; + this.numReducePartitions = numReducePartitions; + this.storageVersion = storageVersion; + + this.partialDataFilePath = LocalMapPartitionFile.getPartialDataFilePath(filePath); + this.partialIndexFilePath = LocalMapPartitionFile.getPartialIndexFilePath(filePath); + this.dataFilePath = LocalMapPartitionFile.getDataFilePath(filePath); + this.indexFilePath = LocalMapPartitionFile.getIndexFilePath(filePath); + } + + @Override + public int getStorageVersion() { + return storageVersion; + } + + /** + * Reconstructs the {@link LocalMapPartitionFileMeta} instance from the {@link DataInput} when + * recovering from failure. + */ + public static LocalMapPartitionFileMeta readFrom(DataInput dataInput) throws IOException { + int storageVersion = dataInput.readInt(); + return readFrom(dataInput, storageVersion); + } + + private static LocalMapPartitionFileMeta readFrom(DataInput dataInput, int storageVersion) + throws IOException { + if (storageVersion <= 1) { + int numReducePartitions = dataInput.readInt(); + String filePath = dataInput.readUTF(); + return new LocalMapPartitionFileMeta(filePath, numReducePartitions, storageVersion); + } + + throw new ShuffleException( + String.format( + "Illegal storage version, data format version: %d, supported version: %d.", + storageVersion, LocalMapPartitionFile.LATEST_STORAGE_VERSION)); + } + + public String getFilePath() { + return filePath; + } + + public Path getDataFilePath() { + return dataFilePath; + } + + public Path getIndexFilePath() { + return indexFilePath; + } + + public Path getPartialDataFilePath() { + return partialDataFilePath; + } + + public Path getPartialIndexFilePath() { + return partialIndexFilePath; + } + + public int getNumReducePartitions() { + return numReducePartitions; + } + + @Override + public LocalMapPartitionFile createPersistentFile(Configuration configuration) { + ConfigOption<Integer> configOption = StorageOptions.STORAGE_FILE_TOLERABLE_FAILURES; + int tolerableFailures = CommonUtils.checkNotNull(configuration.getInteger(configOption)); + + LocalMapPartitionFile partitionFile = + new LocalMapPartitionFile(this, tolerableFailures, false); + partitionFile.setConsumable(true); + return partitionFile; + } + + @Override + public void writeTo(DataOutput dataOutput) throws Exception { + dataOutput.writeInt(storageVersion); + dataOutput.writeInt(numReducePartitions); + dataOutput.writeUTF(filePath); + } + +
@Override + public boolean equals(Object that) { + if (this == that) { + return true; + } + + if (!(that instanceof LocalMapPartitionFileMeta)) { + return false; + } + + LocalMapPartitionFileMeta fileMeta = (LocalMapPartitionFileMeta) that; + return numReducePartitions == fileMeta.numReducePartitions + && storageVersion == fileMeta.storageVersion + && Objects.equals(filePath, fileMeta.filePath); + } + + @Override + public int hashCode() { + return Objects.hash(filePath, numReducePartitions, storageVersion); + } + + @Override + public String toString() { + return "LocalMapPartitionFileMeta{" + + "FilePath=" + + filePath + + ", NumReducePartitions=" + + numReducePartitions + + ", StorageVersion=" + + storageVersion + + '}'; + } +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalMapPartitionFileReader.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalMapPartitionFileReader.java new file mode 100644 index 00000000..c4c42841 --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalMapPartitionFileReader.java @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.storage.exception.FileCorruptedException; +import com.alibaba.flink.shuffle.storage.utils.IOUtils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.concurrent.NotThreadSafe; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.Files; + +/** + * File reader for the {@link LocalMapPartitionFile}. Each {@link LocalMapPartitionFileReader} only + * reads the data belonging to the target reduce partition range ({@link #startPartitionIndex} to + * {@link #endPartitionIndex}). + */ +@NotThreadSafe +public class LocalMapPartitionFileReader { + + private static final Logger LOG = LoggerFactory.getLogger(LocalMapPartitionFileReader.class); + + /** Buffer for reading an index entry to memory. */ + private final ByteBuffer indexBuffer; + + /** Buffer for reading buffer header to memory. */ + private final ByteBuffer headerBuffer; + + /** Target partition file to read data from. */ + private final LocalMapPartitionFile partitionFile; + + /** First index (inclusive) of the target reduce partitions to be read. */ + private final int startPartitionIndex; + + /** Last index (inclusive) of the target reduce partitions to be read. */ + private final int endPartitionIndex; + + /** Whether to enable data checksum or not. */ + private final boolean dataChecksumEnabled; + + /** Number of data regions in the target file.
*/ + private int numRegions; + + /** Opened data file channel to read data from. */ + private FileChannel dataFileChannel; + + /** Opened index file channel to read index from. */ + private FileChannel indexFileChannel; + + /** Number of remaining reduce partitions to read in the current data region. */ + private int numRemainingPartitions; + + /** Current data region index to read data from. */ + private int currentDataRegion = -1; + + /** File offset in data file to read data from. */ + private long dataConsumingOffset; + + /** Number of remaining bytes to read from the current reduce partition. */ + private long currentPartitionRemainingBytes; + + /** Whether this file reader is closed or not. A closed file reader cannot read any data. */ + private boolean isClosed; + + /** Whether this file reader is opened or not. */ + private boolean isOpened; + + public LocalMapPartitionFileReader( + boolean dataChecksumEnabled, + int startPartitionIndex, + int endPartitionIndex, + LocalMapPartitionFile partitionFile) { + CommonUtils.checkArgument(partitionFile != null, "Must be not null."); + CommonUtils.checkArgument(startPartitionIndex >= 0, "Must be non-negative."); + CommonUtils.checkArgument( + endPartitionIndex >= startPartitionIndex, + "Ending partition index must be no smaller than starting partition index."); + CommonUtils.checkState( + endPartitionIndex < partitionFile.getFileMeta().getNumReducePartitions(), + "Ending partition index must be smaller than number of reduce partitions."); + + this.partitionFile = partitionFile; + this.startPartitionIndex = startPartitionIndex; + this.endPartitionIndex = endPartitionIndex; + this.dataChecksumEnabled = dataChecksumEnabled; + + int indexBufferSize = + LocalMapPartitionFile.INDEX_ENTRY_SIZE + * (endPartitionIndex - startPartitionIndex + 1); + this.indexBuffer = CommonUtils.allocateDirectByteBuffer(indexBufferSize); + this.headerBuffer = + CommonUtils.allocateDirectByteBuffer( + IOUtils.getHeaderBufferSizeOfLocalMapPartitionFile( + partitionFile.getFileMeta().getStorageVersion())); + } + + public void open() throws Exception { + CommonUtils.checkState(!isOpened, "Partition file reader has been opened."); + CommonUtils.checkState(!isClosed, "Partition file reader has been closed."); + + try { + isOpened = true; + partitionFile.openFile(this); + dataFileChannel = CommonUtils.checkNotNull(partitionFile.getDataReadingChannel()); + indexFileChannel = CommonUtils.checkNotNull(partitionFile.getIndexReadingChannel()); + + long indexFileSize = indexFileChannel.size(); + long indexRegionSize = partitionFile.getIndexRegionSize(); + + // if this check fails, the partition file must be corrupted + if (indexFileSize % indexRegionSize != LocalMapPartitionFile.INDEX_DATA_CHECKSUM_SIZE) { + throw new FileCorruptedException(); + } + + numRegions = CommonUtils.checkedDownCast(indexFileChannel.size() / indexRegionSize); + updateConsumingOffset(); + } catch (Throwable throwable) { + CommonUtils.runQuietly(this::close); + partitionFile.onError(throwable); + + LOG.debug("Failed to open partition file.", throwable); + throw throwable; + } + } + + /** + * Reads data to the target buffer and returns true if there is remaining data in the current + * data region. The caller is responsible for recycling the target buffer if any exception occurs.
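+ * + * <p>A minimal reading-loop sketch (the buffer allocation and consumption helpers are + * hypothetical): + * + * <pre>{@code + * reader.open(); + * while (reader.hasRemaining()) { + *     ByteBuffer buffer = allocateBuffer(); // hypothetical allocator + *     reader.readBuffer(buffer); + *     consume(buffer); // hypothetical consumer + * } + * reader.finishReading(); + * }</pre>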
+ */ + public boolean readBuffer(ByteBuffer buffer) throws Exception { + CommonUtils.checkArgument(buffer != null, "Must be not null."); + + CommonUtils.checkState(isOpened, "Partition file reader is not opened."); + CommonUtils.checkState(!isClosed, "Partition file reader has been closed."); + // one must check remaining before reading + CommonUtils.checkState(hasRemaining(), "No remaining data to read."); + + try { + dataFileChannel.position(dataConsumingOffset); + currentPartitionRemainingBytes -= + IOUtils.readBuffer( + dataFileChannel, + headerBuffer, + buffer, + headerBuffer.capacity(), + dataChecksumEnabled); + + // if this check fails, the partition file must be corrupted + if (currentPartitionRemainingBytes < 0) { + throw new FileCorruptedException(); + } else if (currentPartitionRemainingBytes == 0) { + int prevDataRegion = currentDataRegion; + updateConsumingOffset(); + return prevDataRegion == currentDataRegion && currentPartitionRemainingBytes > 0; + } + + dataConsumingOffset = dataFileChannel.position(); + return true; + } catch (Throwable throwable) { + partitionFile.onError(throwable); + LOG.debug("Failed to read partition file.", throwable); + throw throwable; + } + } + + /** Returns true if there are remaining bytes in the target partition file. */ + public boolean hasRemaining() { + return currentPartitionRemainingBytes > 0; + } + + /** Returns the next data reading offset in the target data partition file. */ + public long geConsumingOffset() { + return dataConsumingOffset; + } + + /** Advances to the next data reading offset in the target data partition file. */ + private void updateConsumingOffset() throws IOException { + while (currentPartitionRemainingBytes == 0 + && (currentDataRegion < numRegions - 1 || numRemainingPartitions > 0)) { + if (numRemainingPartitions <= 0) { + ++currentDataRegion; + numRemainingPartitions = endPartitionIndex - startPartitionIndex + 1; + + // read the target index entry to the target index buffer + indexFileChannel.position( + partitionFile.getIndexEntryOffset(currentDataRegion, startPartitionIndex)); + IOUtils.readBuffer(indexFileChannel, indexBuffer, indexBuffer.capacity()); + } + + // get the data file offset and the data size + dataConsumingOffset = indexBuffer.getLong(); + currentPartitionRemainingBytes = indexBuffer.getLong(); + int magicNumber = indexBuffer.getInt(); + --numRemainingPartitions; + + // if these checks fail, the partition file must be corrupted + if (dataConsumingOffset < 0 + || dataConsumingOffset + currentPartitionRemainingBytes > dataFileChannel.size() + || currentPartitionRemainingBytes < 0 + || magicNumber != IOUtils.MAGIC_NUMBER) { + throw new FileCorruptedException(); + } + } + } + + public void finishReading() throws Exception { + close(); + + LocalMapPartitionFileMeta fileMeta = partitionFile.getFileMeta(); + CommonUtils.checkState( + Files.exists(fileMeta.getDataFilePath()), "Data file has been deleted."); + CommonUtils.checkState( + Files.exists(fileMeta.getIndexFilePath()), "Index file has been deleted."); + } + + /** Closes this partition file reader. Once closed, no more data can be read.
*/ + public void close() throws Exception { + isClosed = true; + partitionFile.closeFile(this); + } + + public boolean isOpened() { + return isOpened; + } +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalMapPartitionFileWriter.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalMapPartitionFileWriter.java new file mode 100644 index 00000000..ddea1899 --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/LocalMapPartitionFileWriter.java @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; +import com.alibaba.flink.shuffle.storage.utils.IOUtils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.zip.CRC32; +import java.util.zip.Checksum; + +/** File writer for the {@link LocalMapPartitionFile}. */ +public class LocalMapPartitionFileWriter { + + private static final Logger LOG = LoggerFactory.getLogger(LocalMapPartitionFileWriter.class); + + /** Maximum number of data buffers that can be cached in {@link #dataBuffers} before flushing. */ + private final int dataBufferCacheSize; + + /** + * All pending {@link BufferOrMarker.DataBuffer}s to be written. This list enables batch + * writing, which is better for IO performance. + */ + protected final List<BufferOrMarker.DataBuffer> dataBuffers; + + /** Target {@link LocalMapPartitionFile} to write data to. */ + private final LocalMapPartitionFile partitionFile; + + /** Number of bytes of all reduce partitions in the current data region. */ + private final long[] numReducePartitionBytes; + + /** Caches the index data before flushing the data to the target index file. */ + private final ByteBuffer indexBuffer; + + /** Opened data file channel to write data buffers to. */ + private FileChannel dataFileChannel; + + /** Opened index file channel to write index info to. */ + private FileChannel indexFileChannel; + + /** Current reduce partition index to which the data buffer is written. */ + private int currentReducePartition; + + /** Total bytes of data that have been written to the target partition file. */ + private long totalBytes; + + /** Starting offset in the target data file of the current data region. */ + private long regionStartingOffset; + + /** + * Whether current data region is a broadcast region or not.
If true, buffers added to this + * region will be written to all reduce partitions. + */ + private boolean isBroadcastRegion; + + /** Whether this file writer has been closed or not. */ + private boolean isClosed; + + /** Whether this file writer has been opened or not. */ + private boolean isOpened; + + /** Number of finished data regions in the target {@link LocalMapPartitionFile} currently. */ + private long numDataRegions; + + /** + * Checksum util to calculate the checksum value of the index data. The completeness of the + * index data is important because it is used to index the real data. The loss of the index + * data means the loss of the real data. + */ + private final Checksum checksum = new CRC32(); + + /** Whether to enable data checksum or not. */ + private final boolean dataChecksumEnabled; + + public LocalMapPartitionFileWriter( + LocalMapPartitionFile partitionFile, + int dataBufferCacheSize, + boolean dataChecksumEnabled) { + CommonUtils.checkArgument(partitionFile != null, "Must be not null."); + CommonUtils.checkArgument(dataBufferCacheSize > 0, "Must be positive."); + + this.partitionFile = partitionFile; + this.dataBufferCacheSize = dataBufferCacheSize; + this.dataBuffers = new ArrayList<>(2 * dataBufferCacheSize); + this.dataChecksumEnabled = dataChecksumEnabled; + + LocalMapPartitionFileMeta fileMeta = partitionFile.getFileMeta(); + int numReducePartitions = fileMeta.getNumReducePartitions(); + this.numReducePartitionBytes = new long[numReducePartitions]; + this.indexBuffer = IOUtils.allocateIndexBuffer(numReducePartitions); + } + + public void open() throws Exception { + CommonUtils.checkState(!isOpened, "Partition file writer has been opened."); + CommonUtils.checkState(!isClosed, "Partition file writer has been closed."); + + try { + isOpened = true; + Path dataFilePath = partitionFile.getFileMeta().getPartialDataFilePath(); + dataFileChannel = IOUtils.createWritableFileChannel(dataFilePath); + + Path indexFilePath = partitionFile.getFileMeta().getPartialIndexFilePath(); + indexFileChannel = IOUtils.createWritableFileChannel(indexFilePath); + } catch (Throwable throwable) { + CommonUtils.runQuietly(this::close); + throw throwable; + } + } + + /** + * Writes the given data buffer of the corresponding reduce partition to the target {@link + * LocalMapPartitionFile}.
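+ * + * <p>A minimal writer lifecycle sketch (illustrative; the data buffer below is assumed to be + * an existing {@link BufferOrMarker.DataBuffer}): + * + * <pre>{@code + * writer.open(); + * writer.startRegion(false); // not a broadcast region + * writer.writeBuffer(dataBuffer); // buffers must arrive in reduce partition index order + * writer.finishRegion(); + * writer.finishWriting(); // renames partial files and marks the file consumable + * }</pre>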
+ */ + public void writeBuffer(BufferOrMarker.DataBuffer dataBuffer) throws IOException { + CommonUtils.checkArgument(dataBuffer != null, "Must be not null."); + + CommonUtils.checkState(isOpened, "Partition file writer is not opened."); + CommonUtils.checkState(!isClosed, "Partition file writer has been closed."); + + if (!dataBuffer.getBuffer().isReadable()) { + dataBuffer.release(); + return; + } + + dataBuffers.add(dataBuffer); + if (dataBuffers.size() >= dataBufferCacheSize) { + flushDataBuffers(); + CommonUtils.checkState( + dataBuffers.isEmpty(), + "Leaking buffers, some buffers are not released after flush."); + } + } + + private void flushDataBuffers() throws IOException { + try { + checkNotClosed(); + + if (!dataBuffers.isEmpty()) { + ByteBuffer[] bufferWithHeaders = collectBufferWithHeaders(); + IOUtils.writeBuffers(dataFileChannel, bufferWithHeaders); + } + } finally { + releaseAllDataBuffers(); + } + } + + private ByteBuffer[] collectBufferWithHeaders() { + int index = 0; + ByteBuffer[] bufferWithHeaders = new ByteBuffer[2 * dataBuffers.size()]; + + for (BufferOrMarker.DataBuffer dataBuffer : dataBuffers) { + int reducePartitionIndex = dataBuffer.getReducePartitionID().getPartitionIndex(); + CommonUtils.checkState( + reducePartitionIndex >= currentReducePartition, + "Must write data in reduce partition index order."); + CommonUtils.checkState( + !isBroadcastRegion || reducePartitionIndex == 0, + "Reduce partition index must be 0 for broadcast region."); + + if (reducePartitionIndex > currentReducePartition) { + currentReducePartition = reducePartitionIndex; + } + + ByteBuffer data = dataBuffer.getBuffer().nioBuffer(); + ByteBuffer header = IOUtils.getHeaderBuffer(data, dataChecksumEnabled); + + long length = data.remaining() + header.remaining(); + totalBytes += length; + numReducePartitionBytes[reducePartitionIndex] += length; + + bufferWithHeaders[index] = header; + bufferWithHeaders[index + 1] = data; + index += 2; + } + return bufferWithHeaders; + } + + /** + * Marks that a new data region has been started. If the new data region is a broadcast region, + * buffers added to this region will be written to all reduce partitions. + */ + public void startRegion(boolean isBroadcastRegion) { + checkNotClosed(); + checkRegionFinished(); + + currentReducePartition = 0; + this.isBroadcastRegion = isBroadcastRegion; + } + + /** + * Marks that the current data region has been finished and flushes the index region to the + * index file.
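+ * + * <p>As the write logic below shows, each index region holds one entry per reduce partition: + * an 8-byte file offset, an 8-byte length of the partition's data and a 4-byte magic number + * used for integrity checking.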
+ */ + public void finishRegion() throws IOException { + checkNotClosed(); + + flushDataBuffers(); + if (regionStartingOffset == totalBytes) { + return; + } + + int numReducePartitions = partitionFile.getFileMeta().getNumReducePartitions(); + long fileOffset = regionStartingOffset; + + // write the index information of the current data region + for (int partitionIndex = 0; partitionIndex < numReducePartitions; ++partitionIndex) { + indexBuffer.putLong(fileOffset); + if (!isBroadcastRegion) { + indexBuffer.putLong(numReducePartitionBytes[partitionIndex]); + fileOffset += numReducePartitionBytes[partitionIndex]; + } else { + indexBuffer.putLong(numReducePartitionBytes[0]); + } + indexBuffer.putInt(IOUtils.MAGIC_NUMBER); + } + + if (!indexBuffer.hasRemaining()) { + flushIndexBuffer(); + } + + ++numDataRegions; + regionStartingOffset = totalBytes; + Arrays.fill(numReducePartitionBytes, 0); + } + + private void flushIndexBuffer() throws IOException { + indexBuffer.flip(); + if (indexBuffer.hasRemaining()) { + for (int index = 0; index < indexBuffer.limit(); ++index) { + checksum.update(indexBuffer.get(index)); + } + IOUtils.writeBuffer(indexFileChannel, indexBuffer); + } + indexBuffer.clear(); + } + + /** + * Closes this partition file writer and marks the target {@link LocalMapPartitionFile} as + * consumable after finishing data writing. + */ + public void finishWriting() throws Exception { + // handle empty data file + if (!isOpened()) { + open(); + } + + checkNotClosed(); + checkRegionFinished(); + + flushIndexBuffer(); + + // flush the number of data regions and the index data checksum for integrity checking + indexBuffer.putLong(numDataRegions); + indexBuffer.putLong(checksum.getValue()); + indexBuffer.flip(); + IOUtils.writeBuffer(indexFileChannel, indexBuffer); + close(); + + LocalMapPartitionFileMeta fileMeta = partitionFile.getFileMeta(); + File dataFile = fileMeta.getDataFilePath().toFile(); + renameFile(fileMeta.getPartialDataFilePath().toFile(), dataFile); + + File indexFile = fileMeta.getIndexFilePath().toFile(); + renameFile(fileMeta.getPartialIndexFilePath().toFile(), indexFile); + + CommonUtils.checkState(dataFile.exists(), "Data file has been deleted."); + CommonUtils.checkState(indexFile.exists(), "Index file has been deleted."); + partitionFile.setConsumable(true); + } + + private void renameFile(File sourceFile, File targetFile) throws IOException { + CommonUtils.checkArgument(sourceFile != null, "Must be not null."); + CommonUtils.checkArgument(targetFile != null, "Must be not null."); + + if (!sourceFile.renameTo(targetFile)) { + throw new IOException( + String.format( + "Failed to rename file %s to file %s.", + sourceFile.getAbsolutePath(), targetFile.getAbsolutePath())); + } + } + + /** Releases this partition file writer when any exception occurs. */ + public void close() throws Exception { + isClosed = true; + Throwable exception = null; + + try { + CommonUtils.closeWithRetry(dataFileChannel); + } catch (Throwable throwable) { + exception = throwable; + Path dataFilePath = partitionFile.getFileMeta().getDataFilePath(); + LOG.error("Failed to close data file channel: {}.", dataFilePath, throwable); + } + + try { + CommonUtils.closeWithRetry(indexFileChannel); + } catch (Throwable throwable) { + exception = exception != null ? 
exception : throwable; + Path indexFilePath = partitionFile.getFileMeta().getIndexFilePath(); + LOG.error("Failed to close index file channel: {}.", indexFilePath, throwable); + } + + try { + releaseAllDataBuffers(); + } catch (Throwable throwable) { + exception = exception != null ? exception : throwable; + LOG.error("Failed to release the pending data buffers.", throwable); + } + + if (exception != null) { + ExceptionUtils.rethrowException(exception); + } + } + + public boolean isOpened() { + return isOpened; + } + + private void releaseAllDataBuffers() { + for (BufferOrMarker.DataBuffer dataBuffer : dataBuffers) { + BufferOrMarker.releaseBuffer(dataBuffer); + } + dataBuffers.clear(); + } + + private void checkRegionFinished() { + CommonUtils.checkState( + regionStartingOffset == totalBytes, + "Must finish the current data region before starting a new one."); + } + + private void checkNotClosed() { + CommonUtils.checkState(!isClosed, "Partition file writer has been closed."); + } +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/PartitionProcessingTask.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/PartitionProcessingTask.java new file mode 100644 index 00000000..f598e119 --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/PartitionProcessingTask.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.partition; + +/** + * Interface of data partition processing task. All processing logic of a data partition should be + * encapsulated in {@link PartitionProcessingTask}s. + */ +public interface PartitionProcessingTask { + + void process(); +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/PersistentFile.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/PersistentFile.java new file mode 100644 index 00000000..911ab738 --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/PersistentFile.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.partition; + +/** {@link PersistentFile} is the interface for persistent data partition file. */ +public interface PersistentFile { + + /** + * Returns the latest storage version of this persistent file. This is for storage format + * evolution and backward compatibility. + */ + int getLatestStorageVersion(); + + /** Checks whether this persistent file is consumable or not and returns true if so. */ + boolean isConsumable(); + + /** Gets the corresponding meta of this persistent file. */ + PersistentFileMeta getFileMeta(); + + /** Deletes this persistent file and throws the exception if any failure occurs. */ + void deleteFile() throws Exception; + + /** Notifies that an error happens while reading data from this persistent file. */ + void onError(Throwable throwable); + + /** Changes the consumable state of this persistent file. */ + void setConsumable(boolean consumable); +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/PersistentFileMeta.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/PersistentFileMeta.java new file mode 100644 index 00000000..5c7e50e3 --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/PersistentFileMeta.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.config.Configuration; + +import java.io.DataOutput; +import java.io.Serializable; + +/** Meta information of {@link PersistentFile}. */ +public interface PersistentFileMeta extends Serializable { + + /** Creates the corresponding {@link PersistentFile} instance represented by this file meta. */ + PersistentFile createPersistentFile(Configuration configuration); + + /** + * Returns the storage version of the target {@link PersistentFile} for backward compatibility. + */ + int getStorageVersion(); + + /** + * Writes all meta information to the target {@link DataOutput} which can be used to recover + * data after failure. 
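+ * + * <p>A recovery round-trip sketch (the streams are assumed to exist; the concrete meta class is + * used only for illustration): + * + * <pre>{@code + * fileMeta.writeTo(dataOutput); + * // later, during recovery: + * LocalMapPartitionFileMeta restored = LocalMapPartitionFileMeta.readFrom(dataInput); + * }</pre>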
+ */ + void writeTo(DataOutput dataOutput) throws Exception; +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/SSDOnlyLocalFileMapPartitionFactory.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/SSDOnlyLocalFileMapPartitionFactory.java new file mode 100644 index 00000000..73daf708 --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/partition/SSDOnlyLocalFileMapPartitionFactory.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.exception.ConfigurationException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.config.StorageOptions; +import com.alibaba.flink.shuffle.core.storage.StorageMeta; + +/** + * A {@link LocalFileMapPartitionFactory} variant which only uses SSD to store data partition data. + */ +public class SSDOnlyLocalFileMapPartitionFactory extends LocalFileMapPartitionFactory { + + @Override + public void initialize(Configuration configuration) { + super.initialize(configuration); + + if (ssdStorageMetas.isEmpty()) { + throw new ConfigurationException( + String.format( + "No valid data dir of SSD storage type is configured for %s.", + StorageOptions.STORAGE_LOCAL_DATA_DIRS.key())); + } + } + + @Override + protected StorageMeta getNextDataStorageMeta() { + StorageMeta storageMeta = CommonUtils.checkNotNull(ssdStorageMetas.poll()); + ssdStorageMetas.add(storageMeta); + return storageMeta; + } +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/utils/DataPartitionUtils.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/utils/DataPartitionUtils.java new file mode 100644 index 00000000..88e8f728 --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/utils/DataPartitionUtils.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.utils; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.listener.PartitionStateListener; +import com.alibaba.flink.shuffle.core.storage.DataPartition; +import com.alibaba.flink.shuffle.core.storage.DataPartitionFactory; +import com.alibaba.flink.shuffle.core.storage.DataPartitionMeta; +import com.alibaba.flink.shuffle.core.storage.DataPartitionReader; +import com.alibaba.flink.shuffle.core.storage.DataPartitionWriter; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.io.DataInput; +import java.io.DataOutput; +import java.util.Collection; +import java.util.concurrent.CompletableFuture; + +/** Utility methods to manipulate {@link DataPartition}s. */ +public class DataPartitionUtils { + + private static final Logger LOG = LoggerFactory.getLogger(DataPartitionUtils.class); + + /** + * Helper method which releases the target {@link DataPartitionWriter} and logs the encountered + * exception if any. + */ + public static void releaseDataPartitionWriter( + @Nullable DataPartitionWriter writer, @Nullable Throwable releaseCause) { + if (writer == null) { + return; + } + + try { + writer.release(releaseCause); + } catch (Throwable throwable) { + LOG.error("Fatal: failed to release data partition writer: {}.", writer, throwable); + } + } + + /** + * Helper method which releases the target {@link DataPartitionReader} and logs the encountered + * exception if any. + */ + public static void releaseDataPartitionReader( + @Nullable DataPartitionReader reader, @Nullable Throwable releaseCause) { + if (reader == null) { + return; + } + + try { + reader.release(releaseCause); + } catch (Throwable throwable) { + LOG.error("Fatal: failed to release data partition reader: {}.", reader, throwable); + } + } + + /** + * Helper method which releases all the given {@link DataPartitionReader}s and logs the + * encountered exception if any. + */ + public static void releaseDataPartitionReaders( + @Nullable Collection readers, @Nullable Throwable releaseCause) { + if (readers == null) { + return; + } + + for (DataPartitionReader partitionReader : readers) { + releaseDataPartitionReader(partitionReader, releaseCause); + } + // clear method is not supported by all collections + CommonUtils.runQuietly(readers::clear); + } + + /** + * Helper method which releases the target {@link DataPartition} and logs the encountered + * exception if any. + */ + public static CompletableFuture releaseDataPartition( + @Nullable DataPartition dataPartition, @Nullable Throwable releaseCause) { + if (dataPartition == null) { + return CompletableFuture.completedFuture(null); + } + + DataPartitionMeta partitionMeta = dataPartition.getPartitionMeta(); + try { + return dataPartition.releasePartition(releaseCause); + } catch (Throwable throwable) { + LOG.error("Fatal: failed to release data partition: {}.", partitionMeta, throwable); + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(throwable); + return future; + } + } + + /** + * Helper method which releases all the given {@link DataPartition}s and logs the encountered + * exception if any. 
+ */ + public static void releaseDataPartitions( + @Nullable Collection dataPartitions, + @Nullable Throwable releaseCause, + PartitionStateListener partitionStateListener) { + if (dataPartitions == null) { + return; + } + + CommonUtils.checkArgument(partitionStateListener != null, "Must be not null."); + for (DataPartition dataPartition : dataPartitions) { + CommonUtils.runQuietly( + () -> { + releaseDataPartition(dataPartition, releaseCause).get(); + partitionStateListener.onPartitionRemoved(dataPartition.getPartitionMeta()); + }, + true); + } + // clear method is not supported by all collections + CommonUtils.runQuietly(dataPartitions::clear); + } + + /** + * Helper method which serializes the given {@link DataPartitionMeta} to the given {@link + * DataOutput} which can be used to reconstruct lost {@link DataPartition}s. + */ + public static void serializePartitionMeta( + DataPartitionMeta partitionMeta, DataOutput dataOutput) throws Exception { + dataOutput.writeUTF(partitionMeta.getPartitionFactoryClassName()); + partitionMeta.writeTo(dataOutput); + } + + /** + * Helper method which deserializes and creates a new {@link DataPartitionMeta} instance from + * the given {@link DataInput}. The created {@link DataPartitionMeta} can be used to reconstruct + * lost {@link DataPartition}s. + */ + public static DataPartitionMeta deserializePartitionMeta(DataInput dataInput) throws Exception { + String partitionFactoryClassName = dataInput.readUTF(); + Class factoryClass = Class.forName(partitionFactoryClassName); + DataPartitionFactory factory = (DataPartitionFactory) factoryClass.newInstance(); + return factory.recoverDataPartitionMeta(dataInput); + } +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/utils/IOUtils.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/utils/IOUtils.java new file mode 100644 index 00000000..10d304e6 --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/utils/IOUtils.java @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.utils; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.storage.exception.FileCorruptedException; +import com.alibaba.flink.shuffle.storage.partition.LocalMapPartitionFile; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.util.zip.CRC32; +import java.util.zip.Checksum; + +/** Utility methods for IO. 
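+ * + * <p>Data buffers are written with a 16-byte header: a 4-byte buffer length, a 4-byte magic + * number and an 8-byte CRC32 checksum (see {@link #HEADER_BUFFER_SIZE}).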
*/ +public class IOUtils { + + private static final Logger LOG = LoggerFactory.getLogger(IOUtils.class); + + /** + * Size of buffer header: 4 bytes for buffer length, 4 bytes for magic number and 8 bytes for + * checksum. + */ + public static final int HEADER_BUFFER_SIZE = 4 + 4 + 8; + + /** + * Magic number used to check whether the data has been corrupted or not. Note that the data is + * not guaranteed to be in good state even when the magic number is correct. + */ + public static final int MAGIC_NUMBER = 1431655765; // 01010101010101010101010101010101 + + /** + * Magic number used to check whether the data has been corrupted or not. Note that the data is + * not guaranteed to be in good state even when the magic number is correct. + */ + public static final int MAGIC_NUMBER_WITHOUT_CHECKSUM = + MAGIC_NUMBER; // 01010101010101010101010101010101 + + /** + * Magic number used to check whether the data has been corrupted or not. Note that the data is + * not guaranteed to be in good state even when the magic number is correct. + */ + public static final int MAGIC_NUMBER_WITH_CHECKSUM = + ~MAGIC_NUMBER; // 10101010101010101010101010101010 + + /** Opens a {@link FileChannel} for writing, will fail if the file already exists. */ + public static FileChannel createWritableFileChannel(Path path) throws IOException { + CommonUtils.checkArgument(path != null, "Must be not null."); + + return FileChannel.open(path, StandardOpenOption.CREATE_NEW, StandardOpenOption.WRITE); + } + + /** Opens a {@link FileChannel} for reading. */ + public static FileChannel openReadableFileChannel(Path path) throws IOException { + CommonUtils.checkArgument(path != null, "Must be not null."); + + return FileChannel.open(path, StandardOpenOption.READ); + } + + /** Writes all data of the given {@link ByteBuffer} to the target {@link FileChannel}. */ + public static void writeBuffer(FileChannel fileChannel, ByteBuffer buffer) throws IOException { + CommonUtils.checkArgument(fileChannel != null, "Must be not null."); + CommonUtils.checkArgument(buffer != null, "Must be not null."); + + while (buffer.hasRemaining()) { + fileChannel.write(buffer); + } + } + + /** Writes a collection of {@link ByteBuffer}s to the target {@link FileChannel}. */ + public static void writeBuffers(FileChannel fileChannel, ByteBuffer[] buffers) + throws IOException { + CommonUtils.checkArgument(fileChannel != null, "Must be not null."); + CommonUtils.checkArgument(buffers != null, "Must be not null."); + CommonUtils.checkArgument(buffers.length > 0, "No buffer to write."); + + long expectedBytes = 0; + for (ByteBuffer buffer : buffers) { + expectedBytes += buffer.remaining(); + } + + long bytesWritten = fileChannel.write(buffers); + while (bytesWritten < expectedBytes) { + int bufferOffset = 0; + for (ByteBuffer buffer : buffers) { + if (buffer.hasRemaining()) { + break; + } + ++bufferOffset; + } + bytesWritten += fileChannel.write(buffers, bufferOffset, buffers.length - bufferOffset); + } + } + + /** + * Creates and returns the corresponding header {@link ByteBuffer} of the target data {@link + * ByteBuffer}.
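+ * + * <p>A write-side pairing sketch (the file channel and data buffer are assumed to exist): + * + * <pre>{@code + * ByteBuffer header = IOUtils.getHeaderBuffer(data, true); // checksum enabled + * IOUtils.writeBuffers(fileChannel, new ByteBuffer[] {header, data}); + * }</pre>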
+ */ + public static ByteBuffer getHeaderBuffer(ByteBuffer buffer, boolean dataChecksumEnabled) { + CommonUtils.checkArgument(buffer != null, "Must be not null."); + + ByteBuffer header = allocateHeaderBuffer(); + header.putInt(buffer.remaining()); + if (!dataChecksumEnabled) { + header.putInt(MAGIC_NUMBER_WITHOUT_CHECKSUM); + header.putLong(0L); + } else { + header.putInt(MAGIC_NUMBER_WITH_CHECKSUM); + Checksum checksum = new CRC32(); + for (int i = 0; i < buffer.remaining(); ++i) { + checksum.update(buffer.get(i)); + } + header.putLong(checksum.getValue()); + } + header.flip(); + return header; + } + + /** + * Reads the target length of data from the given {@link FileChannel} to the target {@link + * ByteBuffer}. + */ + public static void readBuffer(FileChannel fileChannel, ByteBuffer buffer, int length) + throws IOException { + CommonUtils.checkArgument(length <= buffer.capacity(), "Too many bytes to read."); + + long remainingBytes = fileChannel.size() - fileChannel.position(); + if (remainingBytes < length) { + LOG.error( + String.format( + "File remaining bytes not enough, remaining: %d, wanted: %d.", + remainingBytes, length)); + throw new FileCorruptedException(); + } + + buffer.clear(); + buffer.limit(length); + + while (buffer.hasRemaining()) { + fileChannel.read(buffer); + } + buffer.flip(); + } + + /** + * Reads data with header from the given {@link FileChannel} to the target {@link ByteBuffer}. + */ + public static int readBuffer( + FileChannel fileChannel, + ByteBuffer header, + ByteBuffer buffer, + int headerSize, + boolean dataChecksumEnabled) + throws IOException { + CommonUtils.checkArgument(fileChannel != null, "Must be not null."); + CommonUtils.checkArgument(header != null, "Must be not null."); + CommonUtils.checkArgument(buffer != null, "Must be not null."); + CommonUtils.checkArgument(header.capacity() >= headerSize, "Illegal header buffer."); + + readBuffer(fileChannel, header, headerSize); + int bufferLength = header.getInt(); + int magicNumber = header.getInt(); + if ((magicNumber != MAGIC_NUMBER_WITHOUT_CHECKSUM + && magicNumber != MAGIC_NUMBER_WITH_CHECKSUM) + || bufferLength <= 0 + || bufferLength > buffer.capacity()) { + LOG.error( + String.format( + "Incorrect buffer header, magic number: %d, buffer length: %d.", + magicNumber, bufferLength)); + throw new FileCorruptedException(); + } + + readBuffer(fileChannel, buffer, bufferLength); + if (dataChecksumEnabled && magicNumber == MAGIC_NUMBER_WITH_CHECKSUM) { + Checksum checksum = new CRC32(); + for (int i = 0; i < buffer.remaining(); ++i) { + checksum.update(buffer.get(i)); + } + if (checksum.getValue() != header.getLong()) { + LOG.error("Data checksum verification failed."); + throw new FileCorruptedException(); + } + } + return bufferLength + headerSize; + } + + /** + * Allocates a piece of unmanaged direct {@link ByteBuffer} as header buffer which can be reused + * multiple times. + */ + public static ByteBuffer allocateHeaderBuffer() { + return CommonUtils.allocateDirectByteBuffer(HEADER_BUFFER_SIZE); + } + + /** + * Allocates a piece of unmanaged direct {@link ByteBuffer} for index data writing/reading. The + * minimum index buffer size returned is 4096 bytes.
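+ * + * <p>For example, with 100 reduce partitions and the 20-byte index entry written by {@link + * LocalMapPartitionFileWriter} (8-byte offset + 8-byte length + 4-byte magic number), one + * index region is 2000 bytes; that is below the 4096-byte minimum, so the returned buffer is + * rounded up to 3 whole regions, i.e. 6000 bytes.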
+ */ + public static ByteBuffer allocateIndexBuffer(int numPartitions) { + CommonUtils.checkArgument(numPartitions > 0, "Must be positive."); + + // the returned buffer size is no smaller than 4096 bytes to improve disk IO performance + int minBufferSize = 4096; + + int indexRegionSize = calculateIndexRegionSize(numPartitions); + if (indexRegionSize >= minBufferSize) { + return CommonUtils.allocateDirectByteBuffer(indexRegionSize); + } + + int numRegions = minBufferSize / indexRegionSize; + if (minBufferSize % indexRegionSize != 0) { + ++numRegions; + } + return CommonUtils.allocateDirectByteBuffer(numRegions * indexRegionSize); + } + + /** + * Allocates a piece of unmanaged direct {@link ByteBuffer} for index data checksum writing and + * reading. + */ + public static ByteBuffer allocateIndexDataChecksumBuffer() { + return CommonUtils.allocateDirectByteBuffer(LocalMapPartitionFile.INDEX_DATA_CHECKSUM_SIZE); + } + + /** Calculates and returns the size of index region in bytes. */ + public static int calculateIndexRegionSize(int numPartitions) { + return CommonUtils.checkedDownCast( + (long) numPartitions * LocalMapPartitionFile.INDEX_ENTRY_SIZE); + } + + /** Return the header buffer size. For different versions, this value may be different. */ + public static int getHeaderBufferSizeOfLocalMapPartitionFile(int storageVersion) { + switch (storageVersion) { + case 0: + return 8; + case 1: + return 16; + default: + throw new IllegalArgumentException("Unknown storage version: " + storageVersion); + } + } +} diff --git a/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/utils/StorageConfigParseUtils.java b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/utils/StorageConfigParseUtils.java new file mode 100644 index 00000000..f0a0e440 --- /dev/null +++ b/shuffle-storage/src/main/java/com/alibaba/flink/shuffle/storage/utils/StorageConfigParseUtils.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.utils; + +import com.alibaba.flink.shuffle.common.exception.ConfigurationException; +import com.alibaba.flink.shuffle.core.config.StorageOptions; + +import java.io.File; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** Utilities to parse the configuration of Storage layer. */ +public class StorageConfigParseUtils { + + /** The result of parsing the configured paths. 
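+ * + * <p>All paths are normalized to end with a trailing slash and are grouped into SSD paths, + * HDD paths and the combined list of all paths.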
*/ + public static class ParsedPathLists { + + private final List<String> ssdPaths; + + private final List<String> hddPaths; + + private final List<String> allPaths; + + public ParsedPathLists( + List<String> ssdPaths, List<String> hddPaths, List<String> allPaths) { + this.ssdPaths = checkNotNull(ssdPaths); + this.hddPaths = checkNotNull(hddPaths); + this.allPaths = checkNotNull(allPaths); + } + + public List<String> getSsdPaths() { + return ssdPaths; + } + + public List<String> getHddPaths() { + return hddPaths; + } + + public List<String> getAllPaths() { + return allPaths; + } + } + + /** + * Parses the base paths configured by {@link StorageOptions#STORAGE_LOCAL_DATA_DIRS}. + * + *
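<p>For example (illustrative values), "[SSD]/data/ssd1,[HDD]/data/hdd1,/data/hdd2" is parsed + * into ssdPaths = ["/data/ssd1/"] and hddPaths = ["/data/hdd1/", "/data/hdd2/"]; paths without + * an explicit storage type default to HDD, and allPaths keeps all three normalized paths. + * + *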

TODO: Will be replaced with a formal configuration object in the future. + */ + public static ParsedPathLists parseStoragePaths(String directories) { + List ssdPaths = new ArrayList<>(); + List hddPaths = new ArrayList<>(); + List allPaths = new ArrayList<>(); + + String[] paths = directories.split(","); + for (String pathString : paths) { + pathString = pathString.trim(); + if (pathString.equals("")) { + continue; + } + + if (pathString.startsWith("[SSD]")) { + pathString = pathString.substring(5); + pathString = pathString.endsWith("/") ? pathString : pathString + "/"; + ssdPaths.add(pathString); + } else if (pathString.startsWith("[HDD]")) { + pathString = pathString.substring(5); + pathString = pathString.endsWith("/") ? pathString : pathString + "/"; + hddPaths.add(pathString); + } else { + // if no storage type is configured, HDD will be the default + pathString = pathString.endsWith("/") ? pathString : pathString + "/"; + hddPaths.add(pathString); + } + allPaths.add(pathString); + + Path path = new File(pathString).toPath(); + if (!Files.exists(path)) { + throw new ConfigurationException( + String.format( + "The data dir '%s' configured by '%s' does not exist.", + pathString, StorageOptions.STORAGE_LOCAL_DATA_DIRS.key())); + } + + if (!Files.isDirectory(path)) { + throw new ConfigurationException( + String.format( + "The data dir '%s' configured by '%s' is not a directory.", + pathString, StorageOptions.STORAGE_LOCAL_DATA_DIRS.key())); + } + } + + return new ParsedPathLists(ssdPaths, hddPaths, allPaths); + } +} diff --git a/shuffle-storage/src/main/resources/META-INF/services/com.alibaba.flink.shuffle.core.storage.DataPartitionFactory b/shuffle-storage/src/main/resources/META-INF/services/com.alibaba.flink.shuffle.core.storage.DataPartitionFactory new file mode 100644 index 00000000..a92bd370 --- /dev/null +++ b/shuffle-storage/src/main/resources/META-INF/services/com.alibaba.flink.shuffle.core.storage.DataPartitionFactory @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +com.alibaba.flink.shuffle.storage.partition.LocalFileMapPartitionFactory +com.alibaba.flink.shuffle.storage.partition.SSDOnlyLocalFileMapPartitionFactory +com.alibaba.flink.shuffle.storage.partition.HDDOnlyLocalFileMapPartitionFactory diff --git a/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/datastore/NoOpDataPartitionReader.java b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/datastore/NoOpDataPartitionReader.java new file mode 100644 index 00000000..a095a88f --- /dev/null +++ b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/datastore/NoOpDataPartitionReader.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.datastore; + +import com.alibaba.flink.shuffle.core.memory.BufferRecycler; +import com.alibaba.flink.shuffle.core.storage.BufferQueue; +import com.alibaba.flink.shuffle.core.storage.BufferWithBacklog; +import com.alibaba.flink.shuffle.core.storage.DataPartitionReader; + +/** A no-op {@link DataPartitionReader} implementation for tests. */ +public class NoOpDataPartitionReader implements DataPartitionReader { + + @Override + public void open() {} + + @Override + public boolean readData(BufferQueue buffers, BufferRecycler recycler) { + return false; + } + + @Override + public BufferWithBacklog nextBuffer() { + return null; + } + + @Override + public void release(Throwable throwable) {} + + @Override + public boolean isFinished() { + return false; + } + + @Override + public long getPriority() { + return 0; + } + + @Override + public void onError(Throwable throwable) {} + + @Override + public boolean isOpened() { + return false; + } +} diff --git a/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/datastore/NoOpDataPartitionWriter.java b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/datastore/NoOpDataPartitionWriter.java new file mode 100644 index 00000000..642d0fb0 --- /dev/null +++ b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/datastore/NoOpDataPartitionWriter.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.datastore; + +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.ids.ReducePartitionID; +import com.alibaba.flink.shuffle.core.listener.DataCommitListener; +import com.alibaba.flink.shuffle.core.memory.Buffer; +import com.alibaba.flink.shuffle.core.memory.BufferRecycler; +import com.alibaba.flink.shuffle.core.storage.BufferQueue; +import com.alibaba.flink.shuffle.core.storage.DataPartitionWriter; + +/** A no-op {@link DataPartitionWriter} implementation for tests. 
*/ +public class NoOpDataPartitionWriter implements DataPartitionWriter { + + @Override + public Buffer pollBuffer() { + return null; + } + + @Override + public MapPartitionID getMapPartitionID() { + return null; + } + + @Override + public boolean writeData() { + return false; + } + + @Override + public void addBuffer(ReducePartitionID reducePartitionID, Buffer buffer) {} + + @Override + public void startRegion(int dataRegionIndex, boolean isBroadcastRegion) {} + + @Override + public void finishRegion() {} + + @Override + public void finishDataInput(DataCommitListener commitListener) {} + + @Override + public boolean assignCredits(BufferQueue credits, BufferRecycler recycler) { + return false; + } + + @Override + public void onError(Throwable throwable) {} + + @Override + public void release(Throwable throwable) {} +} diff --git a/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/datastore/NoOpPartitionedDataStore.java b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/datastore/NoOpPartitionedDataStore.java new file mode 100644 index 00000000..5595ecfb --- /dev/null +++ b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/datastore/NoOpPartitionedDataStore.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.datastore; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.core.executor.SimpleSingleThreadExecutorPool; +import com.alibaba.flink.shuffle.core.executor.SingleThreadExecutorPool; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.memory.BufferDispatcher; +import com.alibaba.flink.shuffle.core.storage.DataPartitionMeta; +import com.alibaba.flink.shuffle.core.storage.DataPartitionReadingView; +import com.alibaba.flink.shuffle.core.storage.DataPartitionWritingView; +import com.alibaba.flink.shuffle.core.storage.PartitionedDataStore; +import com.alibaba.flink.shuffle.core.storage.ReadingViewContext; +import com.alibaba.flink.shuffle.core.storage.StorageMeta; +import com.alibaba.flink.shuffle.core.storage.WritingViewContext; +import com.alibaba.flink.shuffle.storage.utils.StorageTestUtils; + +import javax.annotation.Nullable; + +import java.util.Properties; + +/** A no-op {@link PartitionedDataStore} implementation for tests. 
*/ +public class NoOpPartitionedDataStore implements PartitionedDataStore { + + private final BufferDispatcher writingBufferDispatcher = + new BufferDispatcher( + "Test Writing Buffer Pool", 1024, StorageTestUtils.DATA_BUFFER_SIZE); + + private final BufferDispatcher readingBufferDispatcher = + new BufferDispatcher( + "Test Reading Buffer Pool", 1024, StorageTestUtils.DATA_BUFFER_SIZE); + + private final SingleThreadExecutorPool executorPool = + new SimpleSingleThreadExecutorPool(4, "Test Executor Pool"); + + @Override + public DataPartitionWritingView createDataPartitionWritingView(WritingViewContext context) { + return null; + } + + @Override + public DataPartitionReadingView createDataPartitionReadingView(ReadingViewContext context) { + return null; + } + + @Override + public boolean isDataPartitionConsumable(DataPartitionMeta partitionMeta) { + return false; + } + + @Override + public void addDataPartition(DataPartitionMeta partitionMeta) {} + + @Override + public void removeDataPartition(DataPartitionMeta partitionMeta) {} + + @Override + public void releaseDataPartition( + DataSetID dataSetID, DataPartitionID partitionID, @Nullable Throwable throwable) {} + + @Override + public void releaseDataSet(DataSetID dataSetID, @Nullable Throwable throwable) {} + + @Override + public void releaseDataByJobID(JobID jobID, @Nullable Throwable throwable) {} + + @Override + public void shutDown(boolean releaseData) {} + + @Override + public boolean isShutDown() { + return false; + } + + @Override + public Configuration getConfiguration() { + return new Configuration(new Properties()); + } + + @Override + public BufferDispatcher getWritingBufferDispatcher() { + return writingBufferDispatcher; + } + + @Override + public BufferDispatcher getReadingBufferDispatcher() { + return readingBufferDispatcher; + } + + @Override + public SingleThreadExecutorPool getExecutorPool(StorageMeta storageMeta) { + return executorPool; + } +} diff --git a/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/datastore/PartitionReadingViewImplTest.java b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/datastore/PartitionReadingViewImplTest.java new file mode 100644 index 00000000..809fbc04 --- /dev/null +++ b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/datastore/PartitionReadingViewImplTest.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.datastore; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.core.storage.DataPartitionReader; + +import org.junit.Test; + +/** Tests for {@link PartitionReadingViewImpl}. 
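+ * Each case verifies that the view rejects further calls once onError has been invoked.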
*/ +public class PartitionReadingViewImplTest { + + private final DataPartitionReader partitionReader = new NoOpDataPartitionReader(); + + @Test(expected = IllegalStateException.class) + public void testReadAfterError() throws Exception { + PartitionReadingViewImpl readingView = new PartitionReadingViewImpl(partitionReader); + readingView.onError(new ShuffleException("Test exception.")); + readingView.nextBuffer(); + } + + @Test(expected = IllegalStateException.class) + public void testOnErrorAfterError() { + PartitionReadingViewImpl readingView = new PartitionReadingViewImpl(partitionReader); + readingView.onError(new ShuffleException("Test exception.")); + readingView.onError(new ShuffleException("Test exception.")); + } + + @Test(expected = IllegalStateException.class) + public void testCheckFinishAfterError() { + PartitionReadingViewImpl readingView = new PartitionReadingViewImpl(partitionReader); + readingView.onError(new ShuffleException("Test exception.")); + readingView.isFinished(); + } +} diff --git a/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/datastore/PartitionWritingViewImplTest.java b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/datastore/PartitionWritingViewImplTest.java new file mode 100644 index 00000000..f8ca99fb --- /dev/null +++ b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/datastore/PartitionWritingViewImplTest.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.datastore; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.core.ids.ReducePartitionID; +import com.alibaba.flink.shuffle.core.listener.DataCommitListener; +import com.alibaba.flink.shuffle.core.memory.Buffer; +import com.alibaba.flink.shuffle.core.storage.DataPartitionWriter; +import com.alibaba.flink.shuffle.core.storage.DataPartitionWritingView; + +import org.junit.Test; + +import java.nio.ByteBuffer; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; + +/** Tests for {@link PartitionWritingViewImpl}. 
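+ * Verifies the view's state machine: illegal call orderings must fail fast and any buffer already handed over must be recycled.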
*/ +public class PartitionWritingViewImplTest { + + private final DataPartitionWriter partitionWriter = new NoOpDataPartitionWriter(); + + private final DataCommitListener noOpDataCommitListener = () -> {}; +
+ @Test(expected = IllegalStateException.class) + public void testOnErrorAfterFinish() { + DataPartitionWritingView writingView = new PartitionWritingViewImpl(partitionWriter); + + writingView.finish(noOpDataCommitListener); + writingView.onError(new ShuffleException("Test exception.")); + } +
+ @Test(expected = IllegalStateException.class) + public void testOnErrorAfterError() { + DataPartitionWritingView writingView = new PartitionWritingViewImpl(partitionWriter); + + writingView.onError(new ShuffleException("Test exception.")); + writingView.onError(new ShuffleException("Test exception.")); + } +
+ @Test(expected = IllegalStateException.class) + public void testFinishAfterFinish() { + DataPartitionWritingView writingView = new PartitionWritingViewImpl(partitionWriter); + + writingView.finish(noOpDataCommitListener); + writingView.finish(noOpDataCommitListener); + } +
+ @Test(expected = IllegalStateException.class) + public void testFinishAfterError() { + DataPartitionWritingView writingView = new PartitionWritingViewImpl(partitionWriter); + + writingView.onError(new ShuffleException("Test exception.")); + writingView.finish(noOpDataCommitListener); + } +
+ @Test(expected = IllegalArgumentException.class) + public void testNullDataCommitListener() { + DataPartitionWritingView writingView = new PartitionWritingViewImpl(partitionWriter); + writingView.finish(null); + } +
+ @Test(expected = IllegalStateException.class) + public void testRegionFinishAfterFinish() { + DataPartitionWritingView writingView = new PartitionWritingViewImpl(partitionWriter); + + writingView.finish(noOpDataCommitListener); + writingView.regionFinished(); + } +
+ @Test(expected = IllegalStateException.class) + public void testRegionFinishAfterError() { + DataPartitionWritingView writingView = new PartitionWritingViewImpl(partitionWriter); + + writingView.onError(new ShuffleException("Test exception.")); + writingView.regionFinished(); + } +
+ @Test(expected = IllegalStateException.class) + public void testRegionStartAfterFinish() { + DataPartitionWritingView writingView = new PartitionWritingViewImpl(partitionWriter); + + writingView.finish(noOpDataCommitListener); + writingView.regionStarted(0, false); + } +
+ @Test(expected = IllegalStateException.class) + public void testRegionStartAfterError() { + DataPartitionWritingView writingView = new PartitionWritingViewImpl(partitionWriter); + + writingView.onError(new ShuffleException("Test exception.")); + writingView.regionStarted(0, false); + } +
+ @Test(expected = IllegalArgumentException.class) + public void testIllegalDataRegionIndex() { + DataPartitionWritingView writingView = new PartitionWritingViewImpl(partitionWriter); + writingView.regionStarted(-1, false); + } +
+ @Test + public void testOnBufferAfterError() { + AtomicReference<ByteBuffer> recycleBuffer = new AtomicReference<>(); + DataPartitionWritingView writingView = new PartitionWritingViewImpl(partitionWriter); + writingView.regionStarted(10, true); + + writingView.onError(new ShuffleException("Test exception.")); + Buffer buffer = new Buffer(ByteBuffer.allocateDirect(1024), recycleBuffer::set, 1024); + assertThrows( + IllegalStateException.class, + () -> writingView.onBuffer(buffer, new ReducePartitionID(0))); + + assertNotNull(recycleBuffer.get()); + } +
+ @Test + public void testOnBufferAfterFinish() { +
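// The buffer passed in after finish() should be rejected and recycled rather than leaked. +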
AtomicReference<ByteBuffer> recycleBuffer = new AtomicReference<>(); + DataPartitionWritingView writingView = new PartitionWritingViewImpl(partitionWriter); + + writingView.finish(noOpDataCommitListener); + Buffer buffer = new Buffer(ByteBuffer.allocateDirect(1024), recycleBuffer::set, 1024); + assertThrows( + IllegalStateException.class, + () -> writingView.onBuffer(buffer, new ReducePartitionID(0))); + + assertNotNull(recycleBuffer.get()); + } +
+ @Test + public void testOnBufferWithNullReducePartitionID() { + AtomicReference<ByteBuffer> recycleBuffer = new AtomicReference<>(); + DataPartitionWritingView writingView = new PartitionWritingViewImpl(partitionWriter); + writingView.regionStarted(10, true); + + Buffer buffer = new Buffer(ByteBuffer.allocateDirect(1024), recycleBuffer::set, 1024); + assertThrows(IllegalArgumentException.class, () -> writingView.onBuffer(buffer, null)); + + assertNotNull(recycleBuffer.get()); + } +
+ @Test(expected = IllegalArgumentException.class) + public void testOnNullBuffer() { + DataPartitionWritingView writingView = new PartitionWritingViewImpl(partitionWriter); + writingView.onBuffer(null, new ReducePartitionID(0)); + } +
+ @Test + public void testOnBufferBeforeRegionStart() { + AtomicReference<ByteBuffer> recycleBuffer = new AtomicReference<>(); + DataPartitionWritingView writingView = new PartitionWritingViewImpl(partitionWriter); + + Buffer buffer = new Buffer(ByteBuffer.allocateDirect(1024), recycleBuffer::set, 1024); + assertThrows(IllegalStateException.class, () -> writingView.onBuffer(buffer, null)); + + assertNotNull(recycleBuffer.get()); + } +
+ @Test(expected = IllegalStateException.class) + public void testRegionFinishBeforeStart() { + DataPartitionWritingView writingView = new PartitionWritingViewImpl(partitionWriter); + writingView.regionFinished(); + } +
+ @Test(expected = IllegalStateException.class) + public void testRegionStartBeforeFinish() { + DataPartitionWritingView writingView = new PartitionWritingViewImpl(partitionWriter); + writingView.regionStarted(10, true); + writingView.regionStarted(10, true); + } +
+ @Test(expected = IllegalStateException.class) + public void testFinishBeforeRegionFinish() { + DataPartitionWritingView writingView = new PartitionWritingViewImpl(partitionWriter); + writingView.regionStarted(10, true); + writingView.finish(noOpDataCommitListener); + } +} diff --git a/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/datastore/PartitionedDataStoreImplTest.java b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/datastore/PartitionedDataStoreImplTest.java new file mode 100644 index 00000000..f80d5d47 --- /dev/null +++ b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/datastore/PartitionedDataStoreImplTest.java @@ -0,0 +1,904 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.datastore; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.exception.DuplicatedPartitionException; +import com.alibaba.flink.shuffle.core.exception.PartitionNotFoundException; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.ids.ReducePartitionID; +import com.alibaba.flink.shuffle.core.memory.Buffer; +import com.alibaba.flink.shuffle.core.storage.BufferWithBacklog; +import com.alibaba.flink.shuffle.core.storage.DataPartitionMeta; +import com.alibaba.flink.shuffle.core.storage.DataPartitionReadingView; +import com.alibaba.flink.shuffle.core.storage.DataPartitionWritingView; +import com.alibaba.flink.shuffle.core.storage.ReadingViewContext; +import com.alibaba.flink.shuffle.core.storage.StorageMeta; +import com.alibaba.flink.shuffle.core.storage.WritingViewContext; +import com.alibaba.flink.shuffle.core.utils.BufferUtils; +import com.alibaba.flink.shuffle.storage.exception.ConcurrentWriteException; +import com.alibaba.flink.shuffle.storage.partition.LocalFileMapPartitionMeta; +import com.alibaba.flink.shuffle.storage.partition.LocalMapPartitionFile; +import com.alibaba.flink.shuffle.storage.partition.LocalMapPartitionFileMeta; +import com.alibaba.flink.shuffle.storage.utils.StorageTestUtils; +import com.alibaba.flink.shuffle.storage.utils.TestDataCommitListener; +import com.alibaba.flink.shuffle.storage.utils.TestDataListener; +import com.alibaba.flink.shuffle.storage.utils.TestDataRegionCreditListener; +import com.alibaba.flink.shuffle.storage.utils.TestFailureListener; +import com.alibaba.flink.shuffle.storage.utils.TestPartitionStateListener; + +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.rules.Timeout; + +import java.io.File; +import java.nio.ByteBuffer; +import java.nio.file.Files; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +/** Tests for {@link PartitionedDataStoreImpl}. 
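+ * Covers write/read round trips, partition release paths, and failure handling of the local file backed data store.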
*/ +public class PartitionedDataStoreImplTest { + + @Rule public Timeout timeout = new Timeout(60, TimeUnit.SECONDS); + + @Rule public final TemporaryFolder temporaryFolder = new TemporaryFolder(); + + private static final int numBuffersPerReducePartitions = 100; + + public PartitionedDataStoreImpl dataStore; + + public TestPartitionStateListener partitionStateListener; + + @Before + public void before() { + partitionStateListener = new TestPartitionStateListener(); + dataStore = + StorageTestUtils.createPartitionedDataStore( + temporaryFolder.getRoot().getAbsolutePath(), partitionStateListener); + } + + @After + public void after() { + dataStore.shutDown(true); + } + + @Test + public void testWriteEmptyDataPartition() throws Exception { + StorageTestUtils.createEmptyDataPartition(dataStore); + + StorageTestUtils.assertNoBufferLeaking(dataStore); + assertEquals(2, CommonUtils.checkNotNull(temporaryFolder.getRoot().list()).length); + assertEquals(StorageTestUtils.getDefaultDataPartition(), dataStore.getStoredData()); + } + + @Test + public void testReadEmptyDataPartition() throws Exception { + StorageTestUtils.createEmptyDataPartition(dataStore); + + TestDataListener dataListener = new TestDataListener(); + DataPartitionReadingView readingView = + dataStore.createDataPartitionReadingView( + new ReadingViewContext( + StorageTestUtils.DATA_SET_ID, + StorageTestUtils.MAP_PARTITION_ID, + 0, + 0, + dataListener, + StorageTestUtils.NO_OP_BACKLOG_LISTENER, + StorageTestUtils.NO_OP_FAILURE_LISTENER)); + + dataListener.waitData(60000); + + StorageTestUtils.assertNoBufferLeaking(dataStore); + assertTrue(readingView.isFinished()); + assertNull(readingView.nextBuffer()); + } + + @Test(expected = PartitionNotFoundException.class) + public void testReadNonExistentDataPartition() throws Exception { + StorageTestUtils.createDataPartitionReadingView(dataStore, 0); + } + + @Test(expected = PartitionNotFoundException.class) + public void testDeleteIndexFileOfLocalFileMapPartition() throws Exception { + StorageTestUtils.createEmptyDataPartition(dataStore); + + for (File file : CommonUtils.checkNotNull(temporaryFolder.getRoot().listFiles())) { + if (file.getPath().contains(LocalMapPartitionFile.INDEX_FILE_SUFFIX)) { + Files.delete(file.toPath()); + } + } + + StorageTestUtils.createDataPartitionReadingView(dataStore, 0); + } + + @Test(expected = PartitionNotFoundException.class) + public void testDeleteDataFileOfLocalFileMapPartition() throws Exception { + StorageTestUtils.createEmptyDataPartition(dataStore); + + for (File file : CommonUtils.checkNotNull(temporaryFolder.getRoot().listFiles())) { + if (file.getPath().contains(LocalMapPartitionFile.DATA_FILE_SUFFIX)) { + Files.delete(file.toPath()); + } + } + + StorageTestUtils.createDataPartitionReadingView(dataStore, 0); + } + + @Test(expected = PartitionNotFoundException.class) + public void testReadUnfinishedDataPartition() throws Exception { + DataProducerTask producerTask = new DataProducerTask(dataStore, 120, false); + producerTask.start(); + + StorageTestUtils.createDataPartitionReadingView(dataStore, producerTask.getPartitionID()); + } + + @Test + public void testReleaseWhileWriting() throws Exception { + CountDownLatch latch = produceData(); + + dataStore.releaseDataPartition( + StorageTestUtils.DATA_SET_ID, + StorageTestUtils.MAP_PARTITION_ID, + new ShuffleException("Test.")); + latch.await(); + + StorageTestUtils.assertNoBufferLeaking(dataStore); + assertDataStoreEmpty(dataStore); + assertEquals(0, 
CommonUtils.checkNotNull(temporaryFolder.getRoot().list()).length); + assertEquals(1, partitionStateListener.getNumCreated()); + assertEquals(1, partitionStateListener.getNumRemoved()); + } +
+ @Test + public void testReleaseWhileReading() throws Exception { + int numProducers = 1; + int numRegions = 20; + + MapPartitionID[] mapDataPartitionIDS = + produceData(dataStore, numProducers, numRegions, false); + CountDownLatch latch = consumeData(mapDataPartitionIDS[0]); + + dataStore.releaseDataByJobID(StorageTestUtils.JOB_ID, new ShuffleException("Test.")); + latch.await(); + + StorageTestUtils.assertNoBufferLeaking(dataStore); + assertDataStoreEmpty(dataStore); + assertEquals(0, CommonUtils.checkNotNull(temporaryFolder.getRoot().list()).length); + assertEquals(numProducers, partitionStateListener.getNumCreated()); + assertEquals(numProducers, partitionStateListener.getNumRemoved()); + } +
+ @Test + public void testReleaseDataPartition() throws Exception { + Map<JobID, Map<DataSetID, Set<DataPartitionID>>> dataPartitions = addRandomPartitions(); + Map<JobID, Map<DataSetID, Set<DataPartitionID>>> dataPartitionIDs = new HashMap<>(); + for (JobID jobID : dataPartitions.keySet()) { + dataPartitionIDs.put(jobID, new HashMap<>()); + for (DataSetID dataSetID : dataPartitions.get(jobID).keySet()) { + dataPartitionIDs + .get(jobID) + .put(dataSetID, new HashSet<>(dataPartitions.get(jobID).get(dataSetID))); + } + } + + for (JobID jobID : dataPartitionIDs.keySet()) { + for (DataSetID dataSetID : dataPartitionIDs.get(jobID).keySet()) { + for (DataPartitionID partitionID : dataPartitionIDs.get(jobID).get(dataSetID)) { + dataPartitions.get(jobID).get(dataSetID).remove(partitionID); + if (dataPartitions.get(jobID).get(dataSetID).isEmpty()) { + dataPartitions.get(jobID).remove(dataSetID); + } + if (dataPartitions.get(jobID).isEmpty()) { + dataPartitions.remove(jobID); + } + dataStore.releaseDataPartition( + dataSetID, partitionID, new ShuffleException("Tests.")); + assertStoredDataIsExpected(dataStore, dataPartitions); + } + } + } + assertEquals(0, partitionStateListener.getNumCreated()); + assertEquals(220, partitionStateListener.getNumRemoved()); + } +
+ @Test + public void testReleaseDataSet() throws Exception { + Map<JobID, Map<DataSetID, Set<DataPartitionID>>> dataPartitions = addRandomPartitions(); + Map<JobID, Set<DataSetID>> dataSetIDs = new HashMap<>(); + for (JobID jobID : dataPartitions.keySet()) { + dataSetIDs.put(jobID, new HashSet<>(dataPartitions.get(jobID).keySet())); + } + + for (JobID jobID : dataSetIDs.keySet()) { + for (DataSetID dataSetID : dataSetIDs.get(jobID)) { + dataPartitions.get(jobID).remove(dataSetID); + if (dataPartitions.get(jobID).isEmpty()) { + dataPartitions.remove(jobID); + } + dataStore.releaseDataSet(dataSetID, new ShuffleException("Test.")); + assertStoredDataIsExpected(dataStore, dataPartitions); + } + } + assertEquals(0, partitionStateListener.getNumCreated()); + assertEquals(220, partitionStateListener.getNumRemoved()); + } +
+ @Test + public void testReleaseByJobID() throws Exception { + Map<JobID, Map<DataSetID, Set<DataPartitionID>>> dataPartitions = addRandomPartitions(); + + while (!dataPartitions.isEmpty()) { + Set<JobID> jobIDS = new HashSet<>(dataPartitions.keySet()); + for (JobID jobID : jobIDS) { + dataPartitions.remove(jobID); + dataStore.releaseDataByJobID(jobID, new ShuffleException("Test.")); + assertStoredDataIsExpected(dataStore, dataPartitions); + } + } + assertEquals(0, partitionStateListener.getNumCreated()); + assertEquals(220, partitionStateListener.getNumRemoved()); + } +
+ @Test + public void testOnErrorWhileWritingWithoutData() throws Exception { + DataPartitionWritingView writingView = + CommonUtils.checkNotNull( +
StorageTestUtils.createDataPartitionWritingView(dataStore)); + writingView.onError(new ShuffleException("Test exception.")); + + StorageTestUtils.assertNoBufferLeaking(dataStore); + assertDataStoreEmpty(dataStore); + assertEquals(1, partitionStateListener.getNumCreated()); + assertEquals(1, partitionStateListener.getNumRemoved()); + } +
+ @Test + public void testCreateDataPartitions() throws Exception { + Map<JobID, Map<DataSetID, Set<DataPartitionID>>> dataSetsByJob = + addRandomPartitions( + (jobID, dataSetID, mapPartitionID, partitionID) -> { + DataPartitionWritingView writingView = + dataStore.createDataPartitionWritingView( + new WritingViewContext( + jobID, + dataSetID, + partitionID, + mapPartitionID, + StorageTestUtils.NUM_REDUCE_PARTITIONS, + StorageTestUtils + .LOCAL_FILE_MAP_PARTITION_FACTORY, + StorageTestUtils.NO_OP_CREDIT_LISTENER, + StorageTestUtils.NO_OP_FAILURE_LISTENER)); + TestDataCommitListener commitListener = new TestDataCommitListener(); + writingView.finish(commitListener); + commitListener.waitForDataCommission(); + }); + + Map<JobID, Map<DataSetID, Set<DataPartitionID>>> storedData = dataStore.getStoredData(); + assertEquals(dataSetsByJob, storedData); + + assertEquals(440, CommonUtils.checkNotNull(temporaryFolder.getRoot().list()).length); + assertEquals(220, partitionStateListener.getNumCreated()); + assertEquals(0, partitionStateListener.getNumRemoved()); + } +
+ @Test + public void testCreateExistedDataPartition() throws Exception { + TestFailureListener failureLister = new TestFailureListener(); + DataPartitionWritingView writingView = + CommonUtils.checkNotNull( + StorageTestUtils.createDataPartitionWritingView(dataStore, failureLister)); + writingView.finish(StorageTestUtils.NO_OP_DATA_COMMIT_LISTENER); + + StorageTestUtils.createDataPartitionWritingView(dataStore, failureLister); + assertTrue(failureLister.getFailure().getCause() instanceof ConcurrentWriteException); + + StorageTestUtils.assertNoBufferLeaking(dataStore); + assertEquals(1, partitionStateListener.getNumCreated()); + assertEquals(0, partitionStateListener.getNumRemoved()); + } +
+ @Test + public void testAddDataPartitions() throws Throwable { + Map<JobID, Map<DataSetID, Set<DataPartitionID>>> dataPartitions = addRandomPartitions(); + + Map<JobID, Map<DataSetID, Set<DataPartitionID>>> storedData = dataStore.getStoredData(); + assertEquals(dataPartitions, storedData); + assertEquals(0, partitionStateListener.getNumCreated()); + assertEquals(0, partitionStateListener.getNumRemoved()); + } +
+ @Test + public void testAddExistedDataPartition() throws Exception { + LocalMapPartitionFileMeta fileMeta = + StorageTestUtils.createLocalMapPartitionFileMeta(temporaryFolder, true); + StorageMeta storageMeta = StorageTestUtils.getStorageMeta(temporaryFolder); + DataPartitionMeta partitionMeta = + StorageTestUtils.createLocalFileMapPartitionMeta(fileMeta, storageMeta); + + dataStore.addDataPartition(partitionMeta); + assertThrows( + DuplicatedPartitionException.class, + () -> dataStore.addDataPartition(partitionMeta)); + assertEquals(0, partitionStateListener.getNumCreated()); + while (partitionStateListener.getNumRemoved() != 1) { + // wait until the removal succeeds or the test times out + Thread.sleep(100); + } + assertFalse(dataStore.getStoredData().isEmpty()); + } +
+ @Test + public void testAddFailedDataPartition() throws Exception { + LocalMapPartitionFileMeta fileMeta = + StorageTestUtils.createLocalMapPartitionFileMeta(temporaryFolder, false); + StorageMeta storageMeta = StorageTestUtils.getStorageMeta(temporaryFolder); + DataPartitionMeta partitionMeta = + StorageTestUtils.createLocalFileMapPartitionMeta(fileMeta, storageMeta); + + assertThrows(ShuffleException.class, () ->
dataStore.addDataPartition(partitionMeta)); + assertDataStoreEmpty(dataStore); + assertEquals(0, partitionStateListener.getNumCreated()); + assertEquals(1, partitionStateListener.getNumRemoved()); + } + + @Test + public void testOnErrorWhileWriting() throws Exception { + int numProducers = 10; + int numRegions = 60; + + produceData(dataStore, numProducers, numRegions, true); + assertDataStoreEmpty(dataStore); + assertEquals(numProducers, partitionStateListener.getNumCreated()); + assertEquals(numProducers, partitionStateListener.getNumRemoved()); + } + + @Test + public void testOnErrorWhileReading() throws Exception { + int numProducers = 1; + int numRegions = 10; + + MapPartitionID[] mapDataPartitionIDS = + produceData(dataStore, numProducers, numRegions, false); + + consumeData(dataStore, mapDataPartitionIDS, numRegions, 1, true); + assertEquals(numProducers, partitionStateListener.getNumCreated()); + assertEquals(0, partitionStateListener.getNumRemoved()); + } + + @Test + public void testWriteReadLargeDataVolume() throws Exception { + int numProducers = 1; + int numRegions = 60; + + MapPartitionID[] mapDataPartitionIDS = + produceData(dataStore, numProducers, numRegions, false); + + consumeData(dataStore, mapDataPartitionIDS, numRegions, 1, false); + assertEquals(numProducers, partitionStateListener.getNumCreated()); + assertEquals(0, partitionStateListener.getNumRemoved()); + } + + @Test + public void testWriteReadMultiplePartitions() throws Exception { + int numProducers = 20; + int numRegions = 3; + + MapPartitionID[] mapDataPartitionIDS = + produceData(dataStore, numProducers, numRegions, false); + + consumeData(dataStore, mapDataPartitionIDS, numRegions, 1, false); + assertEquals(numProducers, partitionStateListener.getNumCreated()); + assertEquals(0, partitionStateListener.getNumRemoved()); + } + + @Test + public void testReadMultipleChannels() throws Exception { + int numProducers = 20; + int numRegions = 3; + + MapPartitionID[] mapDataPartitionIDS = + produceData(dataStore, numProducers, numRegions, false); + + consumeData(dataStore, mapDataPartitionIDS, numRegions, 3, false); + assertEquals(numProducers, partitionStateListener.getNumCreated()); + assertEquals(0, partitionStateListener.getNumRemoved()); + } + + @Test + public void testReadDataPartitionError() throws Exception { + int numProducers = 1; + int numRegions = 60; + + MapPartitionID[] mapDataPartitionIDS = + produceData(dataStore, numProducers, numRegions, false); + CountDownLatch latch = consumeData(mapDataPartitionIDS[0]); + + for (File file : CommonUtils.checkNotNull(temporaryFolder.getRoot().listFiles())) { + Files.delete(file.toPath()); + } + latch.await(); + + StorageTestUtils.assertNoBufferLeaking(dataStore); + assertEquals(0, CommonUtils.checkNotNull(temporaryFolder.getRoot().list()).length); + assertEquals(1, dataStore.getStoredData().size()); + + assertThrows(PartitionNotFoundException.class, () -> consumeData(mapDataPartitionIDS[0])); + assertDataStoreEmpty(dataStore); + assertEquals(0, CommonUtils.checkNotNull(temporaryFolder.getRoot().list()).length); + + assertEquals(numProducers, partitionStateListener.getNumCreated()); + assertEquals(numProducers, partitionStateListener.getNumRemoved()); + } + + @Test + public void testWriteDataPartitionError() throws Throwable { + CountDownLatch latch = produceData(); + + File[] files = CommonUtils.checkNotNull(temporaryFolder.getRoot().listFiles()); + while (files.length == 0) { + files = CommonUtils.checkNotNull(temporaryFolder.getRoot().listFiles()); + } + for (File 
file : files) { + Files.delete(file.toPath()); + } + latch.await(); + + StorageTestUtils.assertNoBufferLeaking(dataStore); + assertDataStoreEmpty(dataStore); + assertEquals(0, CommonUtils.checkNotNull(temporaryFolder.getRoot().listFiles()).length); + + assertEquals(1, partitionStateListener.getNumCreated()); + assertEquals(1, partitionStateListener.getNumRemoved()); + } + + private CountDownLatch produceData() throws Exception { + CountDownLatch latch = new CountDownLatch(1); + ReducePartitionID reducePartitionID = new ReducePartitionID(0); + + TestDataRegionCreditListener creditListener = new TestDataRegionCreditListener(); + TestFailureListener failureLister = new TestFailureListener(); + + DataPartitionWritingView writingView = + CommonUtils.checkNotNull( + StorageTestUtils.createDataPartitionWritingView( + dataStore, creditListener, failureLister)); + + Thread writingThread = + new Thread( + () -> { + try { + int totalBuffers = 51200; + int numBuffers = 0; + int regionIndex = 0; + writingView.regionStarted(regionIndex, false); + while (!failureLister.isFailed() && numBuffers++ < totalBuffers) { + Object credit = creditListener.take(100, regionIndex); + if (credit != null) { + Buffer buffer = + writingView.getBufferSupplier().pollBuffer(); + buffer.writeBytes(StorageTestUtils.DATA_BYTES); + writingView.onBuffer(buffer, reducePartitionID); + } + } + writingView.regionFinished(); + writingView.finish(StorageTestUtils.NO_OP_DATA_COMMIT_LISTENER); + } catch (Throwable ignored) { + } + + latch.countDown(); + }); + writingThread.start(); + return latch; + } + + private CountDownLatch consumeData(MapPartitionID mapPartitionID) throws Exception { + int numReducePartitions = StorageTestUtils.NUM_REDUCE_PARTITIONS; + CountDownLatch latch = new CountDownLatch(numReducePartitions); + + DataPartitionReadingView[] readingViews = new DataPartitionReadingView[numReducePartitions]; + for (int reduceIndex = 0; reduceIndex < numReducePartitions; ++reduceIndex) { + TestDataListener dataListener = new TestDataListener(); + TestFailureListener failureLister = new TestFailureListener(); + readingViews[reduceIndex] = + dataStore.createDataPartitionReadingView( + new ReadingViewContext( + StorageTestUtils.DATA_SET_ID, + mapPartitionID, + reduceIndex, + reduceIndex, + dataListener, + StorageTestUtils.NO_OP_BACKLOG_LISTENER, + failureLister)); + + int finalIndex = reduceIndex; + Thread readingThread = + new Thread( + () -> { + DataPartitionReadingView readingView = readingViews[finalIndex]; + try { + while (!readingView.isFinished() && !failureLister.isFailed()) { + BufferWithBacklog buffer = + readingViews[finalIndex].nextBuffer(); + if (buffer != null) { + BufferUtils.recycleBuffer(buffer.getBuffer()); + } + } + } catch (Throwable ignored) { + } + latch.countDown(); + }); + readingThread.start(); + } + return latch; + } + + private MapPartitionID[] produceData( + PartitionedDataStoreImpl dataStore, int numProducers, int numRegions, boolean isError) + throws Exception { + DataProducerTask[] producerTasks = new DataProducerTask[numProducers]; + + for (int i = 0; i < numProducers; ++i) { + DataProducerTask producerTask = new DataProducerTask(dataStore, numRegions, isError); + producerTasks[i] = producerTask; + producerTask.start(); + } + + for (DataProducerTask producerTask : producerTasks) { + producerTask.join(); + assertEquals(isError, producerTask.isFailed()); + assertFalse(producerTask.getFailureListener().isFailed()); + } + + StorageTestUtils.assertNoBufferLeaking(dataStore); + + MapPartitionID[] 
mapDataPartitionIDS = new MapPartitionID[numProducers]; + for (int i = 0; i < numProducers; ++i) { + mapDataPartitionIDS[i] = producerTasks[i].getPartitionID(); + } + + if (isError) { + assertEquals(0, CommonUtils.checkNotNull(temporaryFolder.getRoot().list()).length); + assertEquals(0, dataStore.getStoredData().size()); + } else { + Map<JobID, Map<DataSetID, Set<DataPartitionID>>> dataPartitions = + StorageTestUtils.getDataPartitions(Arrays.asList(mapDataPartitionIDS)); + Map<JobID, Map<DataSetID, Set<DataPartitionID>>> storedData = dataStore.getStoredData(); + assertEquals(dataPartitions, storedData); + } + + return mapDataPartitionIDS; + } +
+ private void consumeData( + PartitionedDataStoreImpl dataStore, + MapPartitionID[] mapDataPartitionIDS, + int numRegions, + int numPartitions, + boolean isError) + throws Exception { + List<DataConsumerTask> consumerTasks = new ArrayList<>(); + for (int partitionIndex = 0; partitionIndex < StorageTestUtils.NUM_REDUCE_PARTITIONS; ) { + DataConsumerTask consumerTask = + new DataConsumerTask( + dataStore, + partitionIndex, + Math.min( + partitionIndex + numPartitions - 1, + StorageTestUtils.NUM_REDUCE_PARTITIONS - 1), + mapDataPartitionIDS, + numRegions, + isError); + partitionIndex += numPartitions; + consumerTasks.add(consumerTask); + consumerTask.start(); + } + + for (DataConsumerTask consumerTask : consumerTasks) { + consumerTask.join(); + assertEquals(isError, consumerTask.isFailed()); + assertFalse(consumerTask.getFailureListener().isFailed()); + } + + StorageTestUtils.assertNoBufferLeaking(dataStore); + } +
+ private Map<JobID, Map<DataSetID, Set<DataPartitionID>>> addRandomPartitions() + throws Exception { + LocalMapPartitionFileMeta fileMeta = + StorageTestUtils.createLocalMapPartitionFileMeta(temporaryFolder, true); + StorageMeta storageMeta = StorageTestUtils.getStorageMeta(temporaryFolder); + + return addRandomPartitions( + (jobID, dataSetID, mapPartitionID, partitionID) -> { + DataPartitionMeta partitionMeta = + new LocalFileMapPartitionMeta( + jobID, dataSetID, mapPartitionID, fileMeta, storageMeta); + dataStore.addDataPartition(partitionMeta); + }); + } +
+ private Map<JobID, Map<DataSetID, Set<DataPartitionID>>> addRandomPartitions( + AddPartitionFunction function) throws Exception { + int numJobs = 10; + Map<JobID, Map<DataSetID, Set<DataPartitionID>>> dataGenerated = new HashMap<>(); + + for (int jobIndex = 1; jobIndex <= numJobs; ++jobIndex) { + JobID jobID = new JobID(CommonUtils.randomBytes(16)); + Map<DataSetID, Set<DataPartitionID>> dataSetPartitions = new HashMap<>(); + dataGenerated.put(jobID, dataSetPartitions); + + for (int dataSetIndex = 1; dataSetIndex <= jobIndex; ++dataSetIndex) { + DataSetID dataSetID = new DataSetID(CommonUtils.randomBytes(16)); + Set<DataPartitionID> dataPartitions = new HashSet<>(); + dataSetPartitions.put(dataSetID, dataPartitions); + + for (int partitionIndex = 1; partitionIndex <= dataSetIndex; ++partitionIndex) { + MapPartitionID partitionID = new MapPartitionID(CommonUtils.randomBytes(16)); + dataPartitions.add(partitionID); + + function.addDataPartition(jobID, dataSetID, partitionID, partitionID); + } + } + } + + return dataGenerated; + } +
+ private void assertDataStoreEmpty(PartitionedDataStoreImpl dataStore) throws Exception { + while (!dataStore.getStoredData().isEmpty()) { + Thread.sleep(100); + } + } +
+ private void assertStoredDataIsExpected( + PartitionedDataStoreImpl dataStore, + Map<JobID, Map<DataSetID, Set<DataPartitionID>>> expected) + throws Exception { + while (!dataStore.getStoredData().equals(expected)) { + Thread.sleep(10); + } + } +
+ private static class DataProducerTask extends Thread { + + private final TestDataCommitListener commitListener = new TestDataCommitListener(); + + private final MapPartitionID partitionID = new MapPartitionID(CommonUtils.randomBytes(16)); + + private final
TestDataRegionCreditListener creditListener = + new TestDataRegionCreditListener(); + + private final TestFailureListener failureListener = new TestFailureListener(); + + private final PartitionedDataStoreImpl dataStore; + + private final boolean triggerWritingViewError; + + private final int numRegions; + + private volatile Throwable failure; +
+ private DataProducerTask( + PartitionedDataStoreImpl dataStore, + int numRegions, + boolean triggerWritingViewError) { + CommonUtils.checkArgument(dataStore != null, "Must be not null."); + CommonUtils.checkArgument(numRegions > 0, "Must be positive."); + + this.dataStore = dataStore; + this.numRegions = numRegions; + this.triggerWritingViewError = triggerWritingViewError; + setName("Data Producer Task"); + } +
+ @Override + public void run() { + try { + DataPartitionWritingView writingView = + dataStore.createDataPartitionWritingView( + new WritingViewContext( + StorageTestUtils.JOB_ID, + StorageTestUtils.DATA_SET_ID, + partitionID, + partitionID, + StorageTestUtils.NUM_REDUCE_PARTITIONS, + StorageTestUtils.LOCAL_FILE_MAP_PARTITION_FACTORY, + creditListener, + failureListener)); + + for (int regionID = 0; regionID < numRegions; ++regionID) { + int numReducePartitions = StorageTestUtils.NUM_REDUCE_PARTITIONS; + writingView.regionStarted(regionID, false); + + for (int reduceID = 0; reduceID < numReducePartitions; ++reduceID) { + for (int i = 0; i < numBuffersPerReducePartitions; ++i) { + creditListener.take(0, regionID); + Buffer buffer = writingView.getBufferSupplier().pollBuffer(); + buffer.writeBytes(StorageTestUtils.DATA_BYTES); + writingView.onBuffer(buffer, new ReducePartitionID(reduceID)); + } + + if (triggerWritingViewError) { + writingView.onError(new ShuffleException("Test exception.")); + break; + } + } + writingView.regionFinished(); + } + writingView.finish(commitListener); + commitListener.waitForDataCommission(); + } catch (Throwable throwable) { + failure = throwable; + } + } +
+ public MapPartitionID getPartitionID() { + return partitionID; + } + + public boolean isFailed() { + return failure != null || failureListener.isFailed(); + } + + public TestFailureListener getFailureListener() { + return failureListener; + } + } +
+ private static class DataConsumerTask extends Thread { + + private final TestDataListener dataListener = new TestDataListener(); + + private final TestFailureListener failureListener = new TestFailureListener(); + + private final PartitionedDataStoreImpl dataStore; + + private final MapPartitionID[] mapDataPartitionIDS; + + private final int startPartitionIndex; + + private final int endPartitionIndex; + + private final int numRegions; + + private final boolean isError; + + private volatile Throwable failure; +
+ private DataConsumerTask( + PartitionedDataStoreImpl dataStore, + int startPartitionIndex, + int endPartitionIndex, + MapPartitionID[] mapDataPartitionIDS, + int numRegions, + boolean isError) { + CommonUtils.checkArgument(dataStore != null, "Must be not null."); + CommonUtils.checkArgument(mapDataPartitionIDS != null, "Must be not null."); + CommonUtils.checkArgument(numRegions > 0, "Must be positive."); + + this.dataStore = dataStore; + this.mapDataPartitionIDS = mapDataPartitionIDS; + this.startPartitionIndex = startPartitionIndex; + this.endPartitionIndex = endPartitionIndex; + this.numRegions = numRegions; + this.isError = isError; + setName("Data Consumer Task"); + } +
+ @Override + public void run() { + try { + Queue<DataPartitionReadingView> readingViews = new ArrayDeque<>(); + Map<DataPartitionReadingView, Integer> bufferCounters = new HashMap<>(); +
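// Open one reading view per producer partition and count how many buffers each view returns. +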
for (MapPartitionID mapDataPartitionID : mapDataPartitionIDS) { + DataPartitionReadingView readingView = + dataStore.createDataPartitionReadingView( + new ReadingViewContext( + StorageTestUtils.DATA_SET_ID, + mapDataPartitionID, + startPartitionIndex, + endPartitionIndex, + dataListener, + StorageTestUtils.NO_OP_BACKLOG_LISTENER, + failureListener)); + readingViews.add(readingView); + bufferCounters.put(readingView, 0); + } + + while (!readingViews.isEmpty()) { + DataPartitionReadingView readingView = readingViews.poll(); + BufferWithBacklog bufferWithBacklog; + + boolean isErrorTriggered = false; + while ((bufferWithBacklog = readingView.nextBuffer()) != null) { + assertEquals( + ByteBuffer.wrap(StorageTestUtils.DATA_BYTES), + bufferWithBacklog.getBuffer().nioBuffer()); + bufferCounters.put(readingView, bufferCounters.get(readingView) + 1); + BufferUtils.recycleBuffer(bufferWithBacklog.getBuffer()); + + if (isError) { + readingView.onError(new ShuffleException("Test exception.")); + isErrorTriggered = true; + break; + } + } + + if (!isErrorTriggered && !readingView.isFinished()) { + dataListener.waitData(100); + readingViews.add(readingView); + } + } + + if (startPartitionIndex == endPartitionIndex) { + int numBuffersPerProducer = numBuffersPerReducePartitions * numRegions; + for (Integer numBuffers : bufferCounters.values()) { + assertEquals((Integer) numBuffersPerProducer, numBuffers); + } + } + } catch (Throwable throwable) { + failure = throwable; + } + } + + public boolean isFailed() { + return failure != null || failureListener.isFailed(); + } + + public TestFailureListener getFailureListener() { + return failureListener; + } + } + + private interface AddPartitionFunction { + + void addDataPartition( + JobID jobID, + DataSetID dataSetID, + MapPartitionID mapPartitionID, + DataPartitionID partitionID) + throws Exception; + } +} diff --git a/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/HDDOnlyLocalFileMapPartitionFactoryTest.java b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/HDDOnlyLocalFileMapPartitionFactoryTest.java new file mode 100644 index 00000000..7e47c236 --- /dev/null +++ b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/HDDOnlyLocalFileMapPartitionFactoryTest.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.exception.ConfigurationException; +import com.alibaba.flink.shuffle.core.config.StorageOptions; +import com.alibaba.flink.shuffle.core.storage.StorageMeta; +import com.alibaba.flink.shuffle.core.storage.StorageType; +import com.alibaba.flink.shuffle.storage.utils.StorageTestUtils; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Properties; + +import static org.junit.Assert.assertEquals; + +/** Tests for {@link HDDOnlyLocalFileMapPartitionFactory}. */ +public class HDDOnlyLocalFileMapPartitionFactoryTest { + + @Rule public final TemporaryFolder temporaryFolder1 = new TemporaryFolder(); + + @Rule public final TemporaryFolder temporaryFolder2 = new TemporaryFolder(); + + @Test(expected = ConfigurationException.class) + public void testWithoutValidHddDataDir() { + HDDOnlyLocalFileMapPartitionFactory partitionFactory = + new HDDOnlyLocalFileMapPartitionFactory(); + Properties properties = new Properties(); + properties.setProperty( + StorageOptions.STORAGE_LOCAL_DATA_DIRS.key(), + "[SSD]" + temporaryFolder1.getRoot().getAbsolutePath()); + partitionFactory.initialize(new Configuration(properties)); + } + + @Test + public void testHDDOnly() { + HDDOnlyLocalFileMapPartitionFactory partitionFactory = + new HDDOnlyLocalFileMapPartitionFactory(); + Properties properties = new Properties(); + properties.setProperty( + StorageOptions.STORAGE_LOCAL_DATA_DIRS.key(), + String.format( + "[SSD]%s,[HDD]%s", + temporaryFolder1.getRoot().getAbsolutePath(), + temporaryFolder2.getRoot().getAbsolutePath())); + partitionFactory.initialize(new Configuration(properties)); + for (int i = 0; i < 100; ++i) { + StorageMeta storageMeta = partitionFactory.getNextDataStorageMeta(); + assertEquals( + StorageTestUtils.getStoragePath(temporaryFolder2), + storageMeta.getStoragePath()); + assertEquals(StorageType.HDD, storageMeta.getStorageType()); + } + } +} diff --git a/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartitionFactoryTest.java b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartitionFactoryTest.java new file mode 100644 index 00000000..97d7ee77 --- /dev/null +++ b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartitionFactoryTest.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.exception.ConfigurationException; +import com.alibaba.flink.shuffle.core.config.StorageOptions; +import com.alibaba.flink.shuffle.core.storage.DataPartition; +import com.alibaba.flink.shuffle.core.storage.StorageMeta; +import com.alibaba.flink.shuffle.core.storage.StorageType; +import com.alibaba.flink.shuffle.storage.utils.StorageTestUtils; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Properties; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** Tests for {@link LocalFileMapPartitionFactory}. */ +public class LocalFileMapPartitionFactoryTest { + + @Rule public final TemporaryFolder temporaryFolder1 = new TemporaryFolder(); + + @Rule public final TemporaryFolder temporaryFolder2 = new TemporaryFolder(); + + @Rule public final TemporaryFolder temporaryFolder3 = new TemporaryFolder(); + + @Test(expected = ConfigurationException.class) + public void testDataDirNotConfigured() { + LocalFileMapPartitionFactory partitionFactory = new LocalFileMapPartitionFactory(); + partitionFactory.initialize(new Configuration(new Properties())); + } + + @Test(expected = ConfigurationException.class) + public void testIllegalDiskType() { + LocalFileMapPartitionFactory partitionFactory = new LocalFileMapPartitionFactory(); + Properties properties = new Properties(); + properties.setProperty(StorageOptions.STORAGE_PREFERRED_TYPE.key(), "Illegal"); + partitionFactory.initialize(new Configuration(properties)); + } + + @Test(expected = ConfigurationException.class) + public void testConfiguredDataDirNotExists() { + LocalFileMapPartitionFactory partitionFactory = new LocalFileMapPartitionFactory(); + Properties properties = new Properties(); + properties.setProperty(StorageOptions.STORAGE_LOCAL_DATA_DIRS.key(), "Illegal"); + partitionFactory.initialize(new Configuration(properties)); + } + + @Test(expected = ConfigurationException.class) + public void testConfiguredDataDirIsNotDirectory() throws IOException { + LocalFileMapPartitionFactory partitionFactory = new LocalFileMapPartitionFactory(); + Properties properties = new Properties(); + properties.setProperty( + StorageOptions.STORAGE_LOCAL_DATA_DIRS.key(), + temporaryFolder1.newFile().getAbsolutePath()); + partitionFactory.initialize(new Configuration(properties)); + } + + @Test(expected = ConfigurationException.class) + public void testNoValidDataDir() { + LocalFileMapPartitionFactory partitionFactory = new LocalFileMapPartitionFactory(); + Properties properties = new Properties(); + properties.setProperty(StorageOptions.STORAGE_LOCAL_DATA_DIRS.key(), " "); + partitionFactory.initialize(new Configuration(properties)); + } + + @Test + public void testInitialization() { + LocalFileMapPartitionFactory partitionFactory = new LocalFileMapPartitionFactory(); + String path1 = temporaryFolder1.getRoot().getAbsolutePath() + "/"; + String path2 = temporaryFolder2.getRoot().getAbsolutePath() + "/"; + String path3 = temporaryFolder3.getRoot().getAbsolutePath() + "/"; + + StorageMeta storageMeta1 = new StorageMeta(path1, StorageType.SSD); + StorageMeta storageMeta2 = new StorageMeta(path2, StorageType.HDD); + StorageMeta storageMeta3 = new StorageMeta(path3, StorageType.HDD); + + Properties properties = new 
Properties(); + properties.setProperty( + StorageOptions.STORAGE_LOCAL_DATA_DIRS.key(), + String.format("[SSD]%s,[HDD]%s,%s", path1, path2, path3)); + properties.setProperty(StorageOptions.STORAGE_PREFERRED_TYPE.key(), "HDD"); + partitionFactory.initialize(new Configuration(properties)); + + assertEquals(1, partitionFactory.getSsdStorageMetas().size()); + assertTrue(partitionFactory.getSsdStorageMetas().contains(storageMeta1)); + + assertEquals(2, partitionFactory.getHddStorageMetas().size()); + assertTrue(partitionFactory.getHddStorageMetas().contains(storageMeta2)); + assertTrue(partitionFactory.getHddStorageMetas().contains(storageMeta3)); + + assertEquals(StorageType.HDD, partitionFactory.getPreferredStorageType()); + } +
+ @Test + public void testFairness() { + LocalFileMapPartitionFactory partitionFactory = new LocalFileMapPartitionFactory(); + String path1 = temporaryFolder1.getRoot().getAbsolutePath() + "/"; + String path2 = temporaryFolder2.getRoot().getAbsolutePath() + "/"; + String path3 = temporaryFolder3.getRoot().getAbsolutePath() + "/"; + + Properties properties = new Properties(); + properties.setProperty( + StorageOptions.STORAGE_LOCAL_DATA_DIRS.key(), + String.format("[SSD]%s,[SSD]%s,[SSD]%s", path1, path2, path3)); + partitionFactory.initialize(new Configuration(properties)); + + List<String> selectedDirs = new ArrayList<>(); + for (int i = 0; i < 30; ++i) { + DataPartition dataPartition = + partitionFactory.createDataPartition( + StorageTestUtils.NO_OP_PARTITIONED_DATA_STORE, + StorageTestUtils.JOB_ID, + StorageTestUtils.DATA_SET_ID, + StorageTestUtils.MAP_PARTITION_ID, + StorageTestUtils.NUM_REDUCE_PARTITIONS); + selectedDirs.add(dataPartition.getPartitionMeta().getStorageMeta().getStoragePath()); + } + + int path1Count = 0; + int path2Count = 0; + int path3Count = 0; + for (String path : selectedDirs) { + if (path.equals(path1)) { + ++path1Count; + } else if (path.equals(path2)) { + ++path2Count; + } else if (path.equals(path3)) { + ++path3Count; + } + } + + assertEquals(10, path1Count); + assertEquals(10, path2Count); + assertEquals(10, path3Count); + } +
+ @Test + public void testPreferredDiskType() { + LocalFileMapPartitionFactory partitionFactory = new LocalFileMapPartitionFactory(); + String path1 = temporaryFolder1.getRoot().getAbsolutePath() + "/"; + String path2 = temporaryFolder2.getRoot().getAbsolutePath() + "/"; + + Properties properties = new Properties(); + properties.setProperty( + StorageOptions.STORAGE_LOCAL_DATA_DIRS.key(), + String.format("[SSD]%s,[HDD]%s", path1, path2)); + properties.setProperty(StorageOptions.STORAGE_PREFERRED_TYPE.key(), "HDD"); + partitionFactory.initialize(new Configuration(properties)); + + for (int i = 0; i < 100; ++i) { + DataPartition dataPartition = + partitionFactory.createDataPartition( + StorageTestUtils.NO_OP_PARTITIONED_DATA_STORE, + StorageTestUtils.JOB_ID, + StorageTestUtils.DATA_SET_ID, + StorageTestUtils.MAP_PARTITION_ID, + StorageTestUtils.NUM_REDUCE_PARTITIONS); + assertEquals(path2, dataPartition.getPartitionMeta().getStorageMeta().getStoragePath()); + } + } +} diff --git a/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartitionReaderTest.java b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartitionReaderTest.java new file mode 100644 index 00000000..19f796b7 --- /dev/null +++ b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartitionReaderTest.java @@ -0,0 +1,323 @@ +/* + * Licensed to the Apache
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.listener.DataListener; +import com.alibaba.flink.shuffle.core.listener.FailureListener; +import com.alibaba.flink.shuffle.core.storage.BufferQueue; +import com.alibaba.flink.shuffle.core.storage.BufferWithBacklog; +import com.alibaba.flink.shuffle.core.utils.BufferUtils; +import com.alibaba.flink.shuffle.storage.utils.StorageTestUtils; +import com.alibaba.flink.shuffle.storage.utils.TestDataListener; +import com.alibaba.flink.shuffle.storage.utils.TestFailureListener; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.DataInputStream; +import java.io.FileInputStream; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +/** Tests for {@link LocalFileMapPartitionReader}. 
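+ * Parameterized on whether data checksums are enabled when the test partition files are written.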
*/ +@RunWith(Parameterized.class) +public class LocalFileMapPartitionReaderTest { + + @Rule public final TemporaryFolder temporaryFolder = new TemporaryFolder(); + + private final boolean dataChecksumEnabled; + + @Parameterized.Parameters + public static Object[] data() { + return new Boolean[] {true, false}; + } + + public LocalFileMapPartitionReaderTest(boolean dataChecksumEnabled) { + this.dataChecksumEnabled = dataChecksumEnabled; + } + + @Test + public void testReadData() throws Throwable { + int numRegions = 10; + int numBuffers = 100; + LocalMapPartitionFile partitionFile = createPartitionFile(); + StorageTestUtils.writeLocalMapPartitionFile( + partitionFile, + numRegions, + StorageTestUtils.NUM_REDUCE_PARTITIONS, + numBuffers, + false, + dataChecksumEnabled); + + int buffersRead = readData(partitionFile, 1); + assertEquals(numRegions * numBuffers * StorageTestUtils.NUM_REDUCE_PARTITIONS, buffersRead); + assertNull(partitionFile.getIndexReadingChannel()); + assertNull(partitionFile.getDataReadingChannel()); + } + + @Test + public void testReadEmptyData() throws Throwable { + LocalMapPartitionFile partitionFile = createPartitionFile(); + StorageTestUtils.writeLocalMapPartitionFile( + partitionFile, + 0, + StorageTestUtils.NUM_REDUCE_PARTITIONS, + 0, + false, + dataChecksumEnabled); + + int buffersRead = readData(partitionFile, 1); + assertEquals(0, buffersRead); + assertNull(partitionFile.getIndexReadingChannel()); + assertNull(partitionFile.getDataReadingChannel()); + } + + @Test + public void testReadWithEmptyReducePartition() throws Throwable { + int numRegions = 10; + int numBuffers = 100; + LocalMapPartitionFile partitionFile = createPartitionFile(); + StorageTestUtils.writeLocalMapPartitionFile( + partitionFile, + numRegions, + StorageTestUtils.NUM_REDUCE_PARTITIONS, + numBuffers, + true, + dataChecksumEnabled); + + int buffersRead = readData(partitionFile, 1); + assertEquals( + numRegions * numBuffers * StorageTestUtils.NUM_REDUCE_PARTITIONS / 2, buffersRead); + assertNull(partitionFile.getIndexReadingChannel()); + assertNull(partitionFile.getDataReadingChannel()); + } + + @Test + public void testReadMultipleReducePartitions() throws Throwable { + int numRegions = 10; + int numBuffers = 100; + LocalMapPartitionFile partitionFile = createPartitionFile(); + StorageTestUtils.writeLocalMapPartitionFile( + partitionFile, + numRegions, + StorageTestUtils.NUM_REDUCE_PARTITIONS, + numBuffers, + false, + dataChecksumEnabled); + + int buffersRead = readData(partitionFile, 3); + assertEquals(numRegions * numBuffers * StorageTestUtils.NUM_REDUCE_PARTITIONS, buffersRead); + assertNull(partitionFile.getIndexReadingChannel()); + assertNull(partitionFile.getDataReadingChannel()); + } + + @Test + public void testReadMultipleReducePartitionsWithEmptyOnes() throws Throwable { + int numRegions = 10; + int numBuffers = 100; + LocalMapPartitionFile partitionFile = createPartitionFile(); + StorageTestUtils.writeLocalMapPartitionFile( + partitionFile, + numRegions, + StorageTestUtils.NUM_REDUCE_PARTITIONS, + numBuffers, + true, + dataChecksumEnabled); + + int buffersRead = readData(partitionFile, 3); + assertEquals( + numRegions * numBuffers * StorageTestUtils.NUM_REDUCE_PARTITIONS / 2, buffersRead); + assertNull(partitionFile.getIndexReadingChannel()); + assertNull(partitionFile.getDataReadingChannel()); + } + + @Test + public void testRelease() throws Throwable { + testReleaseOrOnError(true); + } + + @Test + public void testOnError() throws Throwable { + testReleaseOrOnError(false); + } + + @Test + 
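// Reads a file meta written with the V0 storage format from the checked-in test resources, guarding backward compatibility. + 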
public void testCompatibleWithStorageV0() throws Throwable { + try (FileInputStream fis = + new FileInputStream( + "src/test/resources/data_for_storage_compatibility_test/storageV0.meta"); + DataInputStream dis = new DataInputStream(fis)) { + LocalMapPartitionFileMeta fileMeta = LocalMapPartitionFileMeta.readFrom(dis); + LocalMapPartitionFile partitionFile = + new LocalMapPartitionFile(fileMeta, Integer.MAX_VALUE, false); + int buffersRead = readData(partitionFile, 1); + assertEquals(10 * 10 * StorageTestUtils.NUM_REDUCE_PARTITIONS, buffersRead); + } + } + + private void testReleaseOrOnError(boolean isRelease) throws Throwable { + LocalMapPartitionFile partitionFile = createPartitionFile(); + StorageTestUtils.writeLocalMapPartitionFile( + partitionFile, + 10, + StorageTestUtils.NUM_REDUCE_PARTITIONS, + 100, + false, + dataChecksumEnabled); + + TestDataListener dataListener = new TestDataListener(); + TestFailureListener failureListener = new TestFailureListener(); + LocalFileMapPartitionReader partitionReader = + createPartitionReader(0, 0, dataListener, failureListener, partitionFile); + + BufferQueue bufferQueue = createBufferQueue(20); + partitionReader.readData(bufferQueue, bufferQueue::add); + assertNotNull(dataListener.waitData(0)); + + BufferWithBacklog buffer = partitionReader.nextBuffer(); + assertNotNull(buffer); + assertTrue(buffer.getBacklog() > 0); + BufferUtils.recycleBuffer(buffer.getBuffer()); + + if (isRelease) { + partitionReader.release(new ShuffleException("Test exception.")); + assertTrue(failureListener.isFailed()); + } else { + partitionReader.onError(new ShuffleException("Test exception.")); + assertFalse(failureListener.isFailed()); + } + + assertNull(partitionReader.nextBuffer()); + assertThrows( + Exception.class, () -> partitionReader.readData(bufferQueue, bufferQueue::add)); + assertEquals(20, bufferQueue.size()); + + if (!isRelease) { + partitionReader.release(new ShuffleException("Test exception.")); + assertFalse(failureListener.isFailed()); + } + } + + private LocalMapPartitionFile createPartitionFile() { + return StorageTestUtils.createLocalMapPartitionFile( + temporaryFolder.getRoot().getAbsolutePath()); + } + + private int readData(LocalMapPartitionFile partitionFile, int numPartitions) throws Throwable { + Map<LocalFileMapPartitionReader, TestDataListener> readers = new ConcurrentHashMap<>(); + for (int partitionIndex = 0; partitionIndex < StorageTestUtils.NUM_REDUCE_PARTITIONS; ) { + TestDataListener dataListener = new TestDataListener(); + readers.put( + createPartitionReader( + partitionIndex, + Math.min( + partitionIndex + numPartitions - 1, + StorageTestUtils.NUM_REDUCE_PARTITIONS - 1), + dataListener, + partitionFile), + dataListener); + partitionIndex += numPartitions; + } + + int buffersRead = 0; + BufferQueue bufferQueue = createBufferQueue(10); + + while (!readers.isEmpty()) { + for (LocalFileMapPartitionReader reader : readers.keySet()) { + boolean hasRemaining = reader.readData(bufferQueue, bufferQueue::add); + TestDataListener dataListener = readers.get(reader); + + if (!hasRemaining) { + readers.remove(reader); + } + assertNotNull(dataListener.waitData(0)); + + BufferWithBacklog buffer; + while ((buffer = reader.nextBuffer()) != null) { + ++buffersRead; + buffer.getBuffer().release(); + } + + if (!hasRemaining) { + assertTrue(reader.isFinished()); + } + } + } + + assertEquals(10, bufferQueue.size()); + return buffersRead; + } + + private BufferQueue createBufferQueue(int numBuffers) { + List<ByteBuffer> buffers = new ArrayList<>(); + for (int i = 0; i < numBuffers; ++i) { + 
buffers.add(CommonUtils.allocateDirectByteBuffer(StorageTestUtils.DATA_BUFFER_SIZE)); + } + return new BufferQueue(buffers); + } + + private LocalFileMapPartitionReader createPartitionReader( + int startPartitionIndex, + int endPartitionIndex, + DataListener dataListener, + LocalMapPartitionFile partitionFile) + throws Exception { + return createPartitionReader( + startPartitionIndex, + endPartitionIndex, + dataListener, + StorageTestUtils.NO_OP_FAILURE_LISTENER, + partitionFile); + } + + private LocalFileMapPartitionReader createPartitionReader( + int startPartitionIndex, + int endPartitionIndex, + DataListener dataListener, + FailureListener failureListener, + LocalMapPartitionFile partitionFile) + throws Exception { + LocalMapPartitionFileReader fileReader = + new LocalMapPartitionFileReader( + dataChecksumEnabled, startPartitionIndex, endPartitionIndex, partitionFile); + LocalFileMapPartitionReader partitionReader = + new LocalFileMapPartitionReader( + fileReader, + dataListener, + StorageTestUtils.NO_OP_BACKLOG_LISTENER, + failureListener); + partitionReader.open(); + return partitionReader; + } +} diff --git a/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartitionTest.java b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartitionTest.java new file mode 100644 index 00000000..08753955 --- /dev/null +++ b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartitionTest.java @@ -0,0 +1,378 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.ids.ReducePartitionID; +import com.alibaba.flink.shuffle.core.listener.FailureListener; +import com.alibaba.flink.shuffle.core.memory.Buffer; +import com.alibaba.flink.shuffle.core.storage.BufferWithBacklog; +import com.alibaba.flink.shuffle.core.storage.DataPartitionReader; +import com.alibaba.flink.shuffle.core.storage.DataPartitionWriter; +import com.alibaba.flink.shuffle.core.storage.StorageMeta; +import com.alibaba.flink.shuffle.core.storage.StorageType; +import com.alibaba.flink.shuffle.core.utils.BufferUtils; +import com.alibaba.flink.shuffle.storage.exception.ConcurrentWriteException; +import com.alibaba.flink.shuffle.storage.utils.StorageTestUtils; +import com.alibaba.flink.shuffle.storage.utils.TestDataCommitListener; +import com.alibaba.flink.shuffle.storage.utils.TestDataRegionCreditListener; +import com.alibaba.flink.shuffle.storage.utils.TestFailureListener; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.rules.Timeout; + +import java.io.File; +import java.nio.ByteBuffer; +import java.nio.file.Files; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** Tests for {@link LocalFileMapPartition}. */ +public class LocalFileMapPartitionTest { + + @Rule public Timeout timeout = new Timeout(60, TimeUnit.SECONDS); + + @Rule public final TemporaryFolder temporaryFolder = new TemporaryFolder(); + + @Test + public void testWriteAndReadPartition() throws Exception { + LocalFileMapPartition dataPartition = createLocalFileMapPartition(); + + int buffersWritten = writeLocalFileMapPartition(dataPartition, 10, false, true); + int buffersRead = readLocalFileMapPartition(dataPartition, false); + + assertEquals(buffersWritten, buffersRead); + } + + @Test + public void testWriteAndReadEmptyPartition() throws Exception { + LocalFileMapPartition dataPartition = createLocalFileMapPartition(); + + int buffersWritten = writeLocalFileMapPartition(dataPartition, 0, false, true); + int buffersRead = readLocalFileMapPartition(dataPartition, false); + + assertEquals(0, buffersRead); + assertEquals(0, buffersWritten); + } + + @Test + public void testReleaseWhileWriting() throws Exception { + CountDownLatch latch = new CountDownLatch(1); + + LocalFileMapPartition dataPartition = createLocalFileMapPartition(); + Thread writingThread = + new Thread( + () -> { + CommonUtils.runQuietly( + () -> + writeLocalFileMapPartition( + dataPartition, 10, false, true)); + latch.countDown(); + }); + writingThread.start(); + + Thread.sleep(10); + dataPartition.releasePartition(new ShuffleException("Test exception.")).get(); + latch.await(); + + StorageTestUtils.assertNoBufferLeaking(); + assertEquals(0, CommonUtils.checkNotNull(temporaryFolder.getRoot().listFiles()).length); + } + + @Test + public void testReleaseWhileReading() throws Exception { + CountDownLatch latch = new CountDownLatch(1); + LocalFileMapPartition dataPartition = createLocalFileMapPartition(); + + writeLocalFileMapPartition(dataPartition, 10, false, true); + Thread readingThread = + new Thread( + () -> { + CommonUtils.runQuietly( + () -> 
readLocalFileMapPartition(dataPartition, false)); + latch.countDown(); + }); + readingThread.start(); + + Thread.sleep(10); + dataPartition.releasePartition(new ShuffleException("Test exception.")).get(); + latch.await(); + + StorageTestUtils.assertNoBufferLeaking(); + assertEquals(0, CommonUtils.checkNotNull(temporaryFolder.getRoot().listFiles()).length); + } + + @Test + public void testOnErrorWhileWriting() throws Exception { + LocalFileMapPartition dataPartition = createLocalFileMapPartition(); + writeLocalFileMapPartition(dataPartition, 10, true, true); + + StorageTestUtils.assertNoBufferLeaking(); + } + + @Test + public void testOnErrorWhileReading() throws Exception { + LocalFileMapPartition dataPartition = createLocalFileMapPartition(); + + writeLocalFileMapPartition(dataPartition, 10, false, true); + readLocalFileMapPartition(dataPartition, true); + + StorageTestUtils.assertNoBufferLeaking(); + } + + @Test + public void testIsConsumableOfReleasedPartition() throws Exception { + LocalFileMapPartition dataPartition = createLocalFileMapPartition(); + + writeLocalFileMapPartition(dataPartition, 10, false, true); + dataPartition.releasePartition(new ShuffleException("Test exception.")).get(); + assertFalse(dataPartition.isConsumable()); + StorageTestUtils.assertNoBufferLeaking(); + } + + @Test + public void testIsConsumableOfUnfinishedPartition() throws Exception { + LocalFileMapPartition dataPartition = createLocalFileMapPartition(); + + DataPartitionWriter partitionWriter = + dataPartition.createPartitionWriter( + dataPartition.getPartitionMeta().getDataPartitionID(), + StorageTestUtils.NO_OP_CREDIT_LISTENER, + StorageTestUtils.NO_OP_FAILURE_LISTENER); + partitionWriter.startRegion(10, false); + assertFalse(dataPartition.isConsumable()); + + dataPartition.releasePartition(new ShuffleException("Test exception.")).get(); + StorageTestUtils.assertNoBufferLeaking(); + } + + @Test + public void testWritePartitionFileError() throws Exception { + CountDownLatch latch = new CountDownLatch(1); + + LocalFileMapPartition dataPartition = createLocalFileMapPartition(); + Thread writingThread = + new Thread( + () -> { + CommonUtils.runQuietly( + () -> + writeLocalFileMapPartition( + dataPartition, 10, false, false)); + latch.countDown(); + }); + writingThread.start(); + + Thread.sleep(100); + for (File file : CommonUtils.checkNotNull(temporaryFolder.getRoot().listFiles())) { + Files.deleteIfExists(file.toPath()); + } + latch.await(); + + StorageTestUtils.assertNoBufferLeaking(); + } + + @Test + public void testReadPartitionFileError() throws Exception { + CountDownLatch latch = new CountDownLatch(1); + LocalFileMapPartition dataPartition = createLocalFileMapPartition(); + + writeLocalFileMapPartition(dataPartition, 10, false, true); + Thread readingThread = + new Thread( + () -> { + CommonUtils.runQuietly( + () -> readLocalFileMapPartition(dataPartition, false), true); + latch.countDown(); + }); + readingThread.start(); + + Thread.sleep(10); + for (File file : CommonUtils.checkNotNull(temporaryFolder.getRoot().listFiles())) { + Files.delete(file.toPath()); + } + latch.await(); + + StorageTestUtils.assertNoBufferLeaking(); + } + + @Test + public void testWritePartitionMultipleTimes() throws Exception { + LocalFileMapPartition dataPartition = createLocalFileMapPartition(); + + writeLocalFileMapPartition(dataPartition, 0, false, true); + TestFailureListener failureListener = new TestFailureListener(); + CommonUtils.runQuietly( + () -> writeLocalFileMapPartition(dataPartition, 0, false, false, 
failureListener)); + assertTrue(failureListener.getFailure().getCause() instanceof ConcurrentWriteException); + } + + @Test + public void testDeletePartitionIndexFile() throws Exception { + LocalFileMapPartition dataPartition = createLocalFileMapPartition(); + writeLocalFileMapPartition(dataPartition, 10, false, true); + + for (File file : CommonUtils.checkNotNull(temporaryFolder.getRoot().listFiles())) { + if (file.getPath().contains(LocalMapPartitionFile.INDEX_FILE_SUFFIX)) { + Files.delete(file.toPath()); + } + } + assertFalse(dataPartition.isConsumable()); + } + + @Test + public void testDeletePartitionDataFile() throws Exception { + LocalFileMapPartition dataPartition = createLocalFileMapPartition(); + writeLocalFileMapPartition(dataPartition, 10, false, true); + + for (File file : CommonUtils.checkNotNull(temporaryFolder.getRoot().listFiles())) { + if (file.getPath().contains(LocalMapPartitionFile.DATA_FILE_SUFFIX)) { + Files.delete(file.toPath()); + } + } + assertFalse(dataPartition.isConsumable()); + } + + private LocalFileMapPartition createLocalFileMapPartition() { + return new LocalFileMapPartition( + new StorageMeta(temporaryFolder.getRoot().getAbsolutePath() + "/", StorageType.SSD), + StorageTestUtils.NO_OP_PARTITIONED_DATA_STORE, + StorageTestUtils.JOB_ID, + StorageTestUtils.DATA_SET_ID, + StorageTestUtils.MAP_PARTITION_ID, + StorageTestUtils.NUM_REDUCE_PARTITIONS); + } + + private int writeLocalFileMapPartition( + LocalFileMapPartition dataPartition, + int numRegions, + boolean isError, + boolean waitDataCommission) + throws Exception { + return writeLocalFileMapPartition( + dataPartition, + numRegions, + isError, + waitDataCommission, + StorageTestUtils.NO_OP_FAILURE_LISTENER); + } + + private int writeLocalFileMapPartition( + LocalFileMapPartition dataPartition, + int numRegions, + boolean isError, + boolean waitDataCommission, + FailureListener failureListener) + throws Exception { + TestDataRegionCreditListener creditListener = new TestDataRegionCreditListener(); + DataPartitionWriter partitionWriter = + dataPartition.createPartitionWriter( + dataPartition.getPartitionMeta().getDataPartitionID(), + creditListener, + failureListener); + + int buffersWritten = 0; + int numBuffers = 100; + + for (int regionIndex = 0; regionIndex < numRegions; ++regionIndex) { + partitionWriter.startRegion(regionIndex, false); + for (int reduceIndex = 0; + reduceIndex < StorageTestUtils.NUM_REDUCE_PARTITIONS; + ++reduceIndex) { + for (int bufferIndex = 0; bufferIndex < numBuffers; ++bufferIndex) { + Buffer buffer; + while ((buffer = partitionWriter.pollBuffer()) == null) { + creditListener.take(100, regionIndex); + } + + buffer.writeBytes(StorageTestUtils.DATA_BYTES); + partitionWriter.addBuffer(new ReducePartitionID(reduceIndex), buffer); + ++buffersWritten; + } + + if (isError) { + partitionWriter.onError(new ShuffleException("Test exception.")); + return buffersWritten; + } + } + partitionWriter.finishRegion(); + } + + TestDataCommitListener commitListener = new TestDataCommitListener(); + partitionWriter.finishDataInput(commitListener); + if (waitDataCommission) { + commitListener.waitForDataCommission(); + } + return buffersWritten; + } + + public int readLocalFileMapPartition(LocalFileMapPartition dataPartition, boolean isError) + throws Exception { + ConcurrentHashMap<DataPartitionReader, TestFailureListener> readers = + new ConcurrentHashMap<>(); + for (int reduceIndex = 0; + reduceIndex < StorageTestUtils.NUM_REDUCE_PARTITIONS; + ++reduceIndex) { + TestFailureListener failureListener = new TestFailureListener(); + final int 
finalReduceIndex = reduceIndex; + CommonUtils.runQuietly( + () -> { + DataPartitionReader reader = + dataPartition.createPartitionReader( + finalReduceIndex, + finalReduceIndex, + StorageTestUtils.NO_OP_DATA_LISTENER, + StorageTestUtils.NO_OP_BACKLOG_LISTENER, + failureListener); + readers.put(reader, failureListener); + }); + } + + int buffersRead = 0; + while (!readers.isEmpty()) { + for (DataPartitionReader reader : readers.keySet()) { + BufferWithBacklog buffer; + while ((buffer = reader.nextBuffer()) != null) { + assertEquals( + ByteBuffer.wrap(StorageTestUtils.DATA_BYTES), + buffer.getBuffer().nioBuffer()); + BufferUtils.recycleBuffer(buffer.getBuffer()); + ++buffersRead; + } + + if (reader.isFinished() || readers.get(reader).isFailed()) { + readers.remove(reader); + } + + if (isError) { + reader.onError(new ShuffleException("Test exception.")); + readers.remove(reader); + } + } + } + return buffersRead; + } +} diff --git a/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartitionWriterTest.java b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartitionWriterTest.java new file mode 100644 index 00000000..0b488dbf --- /dev/null +++ b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/LocalFileMapPartitionWriterTest.java @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.core.ids.ReducePartitionID; +import com.alibaba.flink.shuffle.core.listener.DataRegionCreditListener; +import com.alibaba.flink.shuffle.core.listener.FailureListener; +import com.alibaba.flink.shuffle.core.memory.Buffer; +import com.alibaba.flink.shuffle.core.storage.BufferQueue; +import com.alibaba.flink.shuffle.storage.utils.StorageTestUtils; +import com.alibaba.flink.shuffle.storage.utils.TestDataCommitListener; +import com.alibaba.flink.shuffle.storage.utils.TestDataRegionCreditListener; +import com.alibaba.flink.shuffle.storage.utils.TestFailureListener; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +/** Tests for {@link LocalFileMapPartitionWriter}. 
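Covers adding and flushing buffers, credit assignment, error propagation, and release.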
*/ +@RunWith(Parameterized.class) +public class LocalFileMapPartitionWriterTest { + + @Rule public final TemporaryFolder temporaryFolder = new TemporaryFolder(); + + private final boolean dataChecksumEnabled; + + @Parameterized.Parameters + public static Object[] data() { + return new Boolean[] {true, false}; + } + + public LocalFileMapPartitionWriterTest(boolean dataChecksumEnabled) { + this.dataChecksumEnabled = dataChecksumEnabled; + } + + @Test + public void testAddAndWriteData() throws Throwable { + TestMapPartition dataPartition = createBaseMapPartition(); + LocalFileMapPartitionWriter partitionWriter = + createLocalFileMapPartitionWriter(dataPartition); + TestMapPartition.TestPartitionWritingTask writingTask = + dataPartition.getPartitionWritingTask(); + + partitionWriter.startRegion(0, false); + assertEquals(1, writingTask.getNumWritingTriggers()); + assertEquals(1, partitionWriter.getNumPendingBuffers()); + + partitionWriter.addBuffer(new ReducePartitionID(0), createBuffer()); + assertEquals(1, writingTask.getNumWritingTriggers()); + assertEquals(2, partitionWriter.getNumPendingBuffers()); + + partitionWriter.writeData(); + assertEquals(0, partitionWriter.getNumPendingBuffers()); + + partitionWriter.addBuffer(new ReducePartitionID(1), createBuffer()); + assertEquals(2, writingTask.getNumWritingTriggers()); + assertEquals(1, partitionWriter.getNumPendingBuffers()); + + partitionWriter.finishRegion(); + TestDataCommitListener commitListener = new TestDataCommitListener(); + partitionWriter.finishDataInput(commitListener); + assertEquals(2, writingTask.getNumWritingTriggers()); + assertEquals(3, partitionWriter.getNumPendingBuffers()); + + partitionWriter.writeData(); + assertEquals(0, partitionWriter.getNumPendingBuffers()); + commitListener.waitForDataCommission(); + } + + @Test + public void testOnError() throws Throwable { + TestFailureListener failureListener = new TestFailureListener(); + TestMapPartition dataPartition = createBaseMapPartition(); + LocalFileMapPartitionWriter partitionWriter = + createLocalFileMapPartitionWriter(dataPartition, failureListener); + TestMapPartition.TestPartitionWritingTask writingTask = + dataPartition.getPartitionWritingTask(); + + partitionWriter.startRegion(10, false); + partitionWriter.addBuffer(new ReducePartitionID(0), createBuffer()); + assertEquals(2, partitionWriter.getNumPendingBuffers()); + assertEquals(1, writingTask.getNumWritingTriggers()); + + partitionWriter.onError(new ShuffleException("Test.")); + assertEquals(1, partitionWriter.getNumPendingBuffers()); + assertEquals(2, writingTask.getNumWritingTriggers()); + + BufferQueue buffers = new BufferQueue(new ArrayList<>()); + buffers.add(ByteBuffer.wrap(StorageTestUtils.DATA_BYTES)); + partitionWriter.assignCredits(buffers, (ignored) -> {}); + assertEquals(1, buffers.size()); + + partitionWriter.release(new ShuffleException("Test.")); + assertFalse(failureListener.isFailed()); + } + + @Test + public void testAssignCredit() throws Throwable { + int regionIndex = 0; + TestDataRegionCreditListener creditListener = new TestDataRegionCreditListener(); + LocalFileMapPartitionWriter partitionWriter = + createLocalFileMapPartitionWriter(creditListener); + + partitionWriter.startRegion(regionIndex, false); + partitionWriter.writeData(); + + BufferQueue buffers = new BufferQueue(new ArrayList<>()); + for (int i = 1; i < BaseDataPartitionWriter.MIN_CREDITS_TO_NOTIFY; ++i) { + buffers.add(ByteBuffer.allocateDirect(StorageTestUtils.DATA_BUFFER_SIZE)); + } + + 
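// Only MIN_CREDITS_TO_NOTIFY - 1 buffers are available, so no credits should be handed to the credit listener yet. + 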
partitionWriter.assignCredits(buffers, (ignored) -> {}); + assertEquals(BaseDataPartitionWriter.MIN_CREDITS_TO_NOTIFY - 1, buffers.size()); + assertNull(creditListener.take(1, regionIndex)); + + buffers.add(ByteBuffer.allocateDirect(StorageTestUtils.DATA_BUFFER_SIZE)); + + partitionWriter.assignCredits(buffers, (ignored) -> {}); + assertEquals(0, buffers.size()); + for (int i = 0; i < BaseDataPartitionWriter.MIN_CREDITS_TO_NOTIFY; ++i) { + assertNotNull(creditListener.take(0, regionIndex)); + } + + partitionWriter.finishRegion(); + partitionWriter.writeData(); + for (int i = 0; i < BaseDataPartitionWriter.MIN_CREDITS_TO_NOTIFY; ++i) { + buffers.add(ByteBuffer.allocateDirect(StorageTestUtils.DATA_BUFFER_SIZE)); + } + partitionWriter.assignCredits(buffers, (ignored) -> {}); + assertEquals(BaseDataPartitionWriter.MIN_CREDITS_TO_NOTIFY, buffers.size()); + assertNull(creditListener.take(1, regionIndex)); + } + + @Test + public void testRelease() throws Throwable { + TestFailureListener failureListener = new TestFailureListener(); + LocalFileMapPartitionWriter partitionWriter = + createLocalFileMapPartitionWriter(failureListener); + + partitionWriter.startRegion(10, false); + partitionWriter.addBuffer(new ReducePartitionID(0), createBuffer()); + partitionWriter.writeData(); + + partitionWriter.addBuffer(new ReducePartitionID(0), createBuffer()); + partitionWriter.finishRegion(); + assertEquals(2, partitionWriter.getNumPendingBuffers()); + assertFalse(failureListener.isFailed()); + + partitionWriter.release(new ShuffleException("Test.")); + assertEquals(0, partitionWriter.getNumPendingBuffers()); + assertTrue(failureListener.isFailed()); + + BufferQueue buffers = new BufferQueue(new ArrayList<>()); + buffers.add(ByteBuffer.allocateDirect(StorageTestUtils.DATA_BUFFER_SIZE)); + + partitionWriter.assignCredits(buffers, (ignored) -> {}); + assertEquals(1, buffers.size()); + } + + private LocalFileMapPartitionWriter createLocalFileMapPartitionWriter( + DataRegionCreditListener dataRegionCreditListener) throws IOException { + return createLocalFileMapPartitionWriter( + createBaseMapPartition(), + dataRegionCreditListener, + StorageTestUtils.NO_OP_FAILURE_LISTENER); + } + + private LocalFileMapPartitionWriter createLocalFileMapPartitionWriter( + FailureListener failureListener) throws IOException { + return createLocalFileMapPartitionWriter( + createBaseMapPartition(), StorageTestUtils.NO_OP_CREDIT_LISTENER, failureListener); + } + + private LocalFileMapPartitionWriter createLocalFileMapPartitionWriter( + BaseMapPartition dataPartition, FailureListener failureListener) throws IOException { + return createLocalFileMapPartitionWriter( + dataPartition, StorageTestUtils.NO_OP_CREDIT_LISTENER, failureListener); + } + + private LocalFileMapPartitionWriter createLocalFileMapPartitionWriter( + BaseMapPartition dataPartition) throws IOException { + return createLocalFileMapPartitionWriter( + dataPartition, + StorageTestUtils.NO_OP_CREDIT_LISTENER, + StorageTestUtils.NO_OP_FAILURE_LISTENER); + } + + private LocalFileMapPartitionWriter createLocalFileMapPartitionWriter( + BaseMapPartition dataPartition, + DataRegionCreditListener dataRegionCreditListener, + FailureListener failureListener) + throws IOException { + LocalMapPartitionFile partitionFile = + StorageTestUtils.createLocalMapPartitionFile( + temporaryFolder.getRoot().getAbsolutePath()); + return new LocalFileMapPartitionWriter( + dataChecksumEnabled, + StorageTestUtils.MAP_PARTITION_ID, + dataPartition, + dataRegionCreditListener, + failureListener, + 
partitionFile); + } + + private Buffer createBuffer() { + return new Buffer( + ByteBuffer.allocateDirect(StorageTestUtils.DATA_BUFFER_SIZE), + (ignored) -> {}, + StorageTestUtils.DATA_BUFFER_SIZE); + } + + private TestMapPartition createBaseMapPartition() { + return new TestMapPartition(StorageTestUtils.NO_OP_PARTITIONED_DATA_STORE); + } +} diff --git a/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/LocalMapPartitionFileMetaTest.java b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/LocalMapPartitionFileMetaTest.java new file mode 100644 index 00000000..2327a874 --- /dev/null +++ b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/LocalMapPartitionFileMetaTest.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.storage.utils.StorageTestUtils; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +/** Tests for {@link LocalMapPartitionFileMeta}. 
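Covers serialization round trips and constructor argument validation.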
*/ +public class LocalMapPartitionFileMetaTest { + + @Rule public final TemporaryFolder temporaryFolder = new TemporaryFolder(); + + @Test + public void testSerializeAndDeserialize() throws Exception { + LocalMapPartitionFileMeta fileMeta = StorageTestUtils.createLocalMapPartitionFileMeta(); + + File tmpFile = temporaryFolder.newFile(); + try (DataOutputStream output = new DataOutputStream(new FileOutputStream(tmpFile))) { + fileMeta.writeTo(output); + } + + LocalMapPartitionFileMeta recovered; + try (DataInputStream input = new DataInputStream(new FileInputStream(tmpFile))) { + recovered = LocalMapPartitionFileMeta.readFrom(input); + } + + assertEquals(fileMeta, recovered); + } + + @Test + public void testIllegalArgument() { + assertThrows( + IllegalArgumentException.class, + () -> + new LocalMapPartitionFileMeta( + null, 10, LocalMapPartitionFile.LATEST_STORAGE_VERSION)); + assertThrows( + IllegalArgumentException.class, + () -> + new LocalMapPartitionFileMeta( + "/tmp/test", 0, LocalMapPartitionFile.LATEST_STORAGE_VERSION)); + } +} diff --git a/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/LocalMapPartitionFileReaderTest.java b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/LocalMapPartitionFileReaderTest.java new file mode 100644 index 00000000..970360b1 --- /dev/null +++ b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/LocalMapPartitionFileReaderTest.java @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.storage.exception.FileCorruptedException; +import com.alibaba.flink.shuffle.storage.utils.StorageTestUtils; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.File; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.StandardOpenOption; +import java.util.ArrayDeque; +import java.util.Queue; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; + +/** Tests for {@link LocalMapPartitionFileReader}. 
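Covers normal, empty, broadcast-region, and multi-partition reads as well as the handling of corrupted index and data files.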
*/ +@RunWith(Parameterized.class) +public class LocalMapPartitionFileReaderTest { + + @Rule public final TemporaryFolder temporaryFolder = new TemporaryFolder(); + + private final boolean dataChecksumEnabled; + + @Parameterized.Parameters + public static Object[] data() { + return new Boolean[] {true, false}; + } + + public LocalMapPartitionFileReaderTest(boolean dataChecksumEnabled) { + this.dataChecksumEnabled = dataChecksumEnabled; + } + + @Test + public void testReadData() throws Exception { + int numRegions = 10; + int numBuffers = 100; + LocalMapPartitionFile partitionFile = createPartitionFile(); + StorageTestUtils.writeLocalMapPartitionFile( + partitionFile, + numRegions, + StorageTestUtils.NUM_REDUCE_PARTITIONS, + numBuffers, + false, + dataChecksumEnabled); + + int buffersRead = readData(partitionFile, 1); + assertEquals(numRegions * numBuffers * StorageTestUtils.NUM_REDUCE_PARTITIONS, buffersRead); + assertNull(partitionFile.getIndexReadingChannel()); + assertNull(partitionFile.getDataReadingChannel()); + } + + @Test + public void testReadEmptyData() throws Exception { + LocalMapPartitionFile partitionFile = createPartitionFile(); + StorageTestUtils.writeLocalMapPartitionFile( + partitionFile, 0, 1, 1, false, dataChecksumEnabled); + + LocalMapPartitionFileReader fileReader = + new LocalMapPartitionFileReader(dataChecksumEnabled, 0, 0, partitionFile); + assertFalse(fileReader.hasRemaining()); + fileReader.finishReading(); + } + + @Test + public void testReadWithEmptyReducePartitions() throws Exception { + int numRegions = 10; + int numBuffers = 100; + LocalMapPartitionFile partitionFile = createPartitionFile(); + StorageTestUtils.writeLocalMapPartitionFile( + partitionFile, + numRegions, + StorageTestUtils.NUM_REDUCE_PARTITIONS, + numBuffers, + true, + dataChecksumEnabled); + + int buffersRead = readData(partitionFile, 1); + assertEquals( + numRegions * numBuffers * StorageTestUtils.NUM_REDUCE_PARTITIONS / 2, buffersRead); + assertNull(partitionFile.getIndexReadingChannel()); + assertNull(partitionFile.getDataReadingChannel()); + } + + @Test + public void testReadMultipleReducePartitions() throws Exception { + int numRegions = 10; + int numBuffers = 100; + LocalMapPartitionFile partitionFile = createPartitionFile(); + StorageTestUtils.writeLocalMapPartitionFile( + partitionFile, + numRegions, + StorageTestUtils.NUM_REDUCE_PARTITIONS, + numBuffers, + false, + dataChecksumEnabled); + + int buffersRead = readData(partitionFile, 3); + assertEquals(numRegions * numBuffers * StorageTestUtils.NUM_REDUCE_PARTITIONS, buffersRead); + assertNull(partitionFile.getIndexReadingChannel()); + assertNull(partitionFile.getDataReadingChannel()); + } + + @Test + public void testReadMultipleReducePartitionsWithBroadcastRegion() throws Exception { + int numRegions = 10; + int numBuffers = 100; + LocalMapPartitionFile partitionFile = createPartitionFile(); + StorageTestUtils.writeLocalMapPartitionFile( + partitionFile, + numRegions, + StorageTestUtils.NUM_REDUCE_PARTITIONS, + numBuffers, + false, + 5, + dataChecksumEnabled); + + int buffersRead = readData(partitionFile, 3); + assertEquals(numRegions * numBuffers * StorageTestUtils.NUM_REDUCE_PARTITIONS, buffersRead); + assertNull(partitionFile.getIndexReadingChannel()); + assertNull(partitionFile.getDataReadingChannel()); + } + + @Test + public void testReadWithBroadcastRegion() throws Exception { + int numRegions = 10; + int numBuffers = 100; + LocalMapPartitionFile partitionFile = createPartitionFile(); + StorageTestUtils.writeLocalMapPartitionFile( 
+ partitionFile, + numRegions, + StorageTestUtils.NUM_REDUCE_PARTITIONS, + numBuffers, + false, + 5, + dataChecksumEnabled); + + int buffersRead = readData(partitionFile, 1); + assertEquals(numRegions * numBuffers * StorageTestUtils.NUM_REDUCE_PARTITIONS, buffersRead); + assertNull(partitionFile.getIndexReadingChannel()); + assertNull(partitionFile.getDataReadingChannel()); + } + + @Test + public void testReadMultipleReducePartitionsWithEmptyOnes() throws Exception { + int numRegions = 10; + int numBuffers = 100; + LocalMapPartitionFile partitionFile = createPartitionFile(); + StorageTestUtils.writeLocalMapPartitionFile( + partitionFile, + numRegions, + StorageTestUtils.NUM_REDUCE_PARTITIONS, + numBuffers, + true, + dataChecksumEnabled); + + int buffersRead = readData(partitionFile, 3); + assertEquals( + numRegions * numBuffers * StorageTestUtils.NUM_REDUCE_PARTITIONS / 2, buffersRead); + assertNull(partitionFile.getIndexReadingChannel()); + assertNull(partitionFile.getDataReadingChannel()); + } + + @Test + public void testIndexFileCorruptedWithIncompleteRegion() throws Exception { + LocalMapPartitionFile partitionFile = createPartitionFile(); + StorageTestUtils.writeLocalMapPartitionFile( + partitionFile, + 10, + StorageTestUtils.NUM_REDUCE_PARTITIONS, + 10, + false, + dataChecksumEnabled); + + for (File file : CommonUtils.checkNotNull(temporaryFolder.getRoot().listFiles())) { + if (file.getPath().contains(LocalMapPartitionFile.INDEX_FILE_SUFFIX)) { + try (FileChannel fileChannel = + FileChannel.open(file.toPath(), StandardOpenOption.WRITE)) { + fileChannel.truncate(10); + } + } + } + + LocalMapPartitionFileReader fileReader = + new LocalMapPartitionFileReader(dataChecksumEnabled, 0, 0, partitionFile); + assertThrows(FileCorruptedException.class, fileReader::open); + assertFalse(partitionFile.isConsumable()); + assertEquals(0, CommonUtils.checkNotNull(temporaryFolder.getRoot().list()).length); + } + + @Test + public void testIndexFileCorruptedWithWrongChecksum() throws Exception { + LocalMapPartitionFile partitionFile = createPartitionFile(); + StorageTestUtils.writeLocalMapPartitionFile( + partitionFile, + 10, + StorageTestUtils.NUM_REDUCE_PARTITIONS, + 10, + false, + dataChecksumEnabled); + + for (File file : CommonUtils.checkNotNull(temporaryFolder.getRoot().listFiles())) { + if (file.getPath().contains(LocalMapPartitionFile.INDEX_FILE_SUFFIX)) { + try (FileChannel fileChannel = + FileChannel.open(file.toPath(), StandardOpenOption.WRITE)) { + int indexRegionSize = + LocalMapPartitionFile.INDEX_ENTRY_SIZE + * StorageTestUtils.NUM_REDUCE_PARTITIONS; + fileChannel.truncate(fileChannel.size() - indexRegionSize); + } + } + } + + LocalMapPartitionFileReader fileReader = + new LocalMapPartitionFileReader(dataChecksumEnabled, 0, 0, partitionFile); + assertThrows(FileCorruptedException.class, fileReader::open); + assertFalse(partitionFile.isConsumable()); + assertEquals(0, CommonUtils.checkNotNull(temporaryFolder.getRoot().list()).length); + } + + @Test + public void testDataFileCorrupted() throws Exception { + LocalMapPartitionFile partitionFile = createPartitionFile(); + StorageTestUtils.writeLocalMapPartitionFile( + partitionFile, + 10, + StorageTestUtils.NUM_REDUCE_PARTITIONS, + 10, + false, + dataChecksumEnabled); + + for (File file : CommonUtils.checkNotNull(temporaryFolder.getRoot().listFiles())) { + if (file.getPath().contains(LocalMapPartitionFile.DATA_FILE_SUFFIX)) { + try (FileChannel fileChannel = + FileChannel.open(file.toPath(), StandardOpenOption.WRITE)) { + 
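// Truncating the data file leaves it inconsistent with its index, so opening a reader must fail with FileCorruptedException. + 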
fileChannel.truncate(10); + } + } + } + + LocalMapPartitionFileReader fileReader = + new LocalMapPartitionFileReader(dataChecksumEnabled, 0, 0, partitionFile); + assertThrows(FileCorruptedException.class, fileReader::open); + assertFalse(partitionFile.isConsumable()); + assertEquals(0, CommonUtils.checkNotNull(temporaryFolder.getRoot().list()).length); + } + + private LocalMapPartitionFile createPartitionFile() { + return StorageTestUtils.createLocalMapPartitionFile( + temporaryFolder.getRoot().getAbsolutePath()); + } + + private int readData(LocalMapPartitionFile partitionFile, int numPartitions) throws Exception { + Queue<LocalMapPartitionFileReader> fileReaders = new ArrayDeque<>(); + for (int partitionIndex = 0; partitionIndex < StorageTestUtils.NUM_REDUCE_PARTITIONS; ) { + LocalMapPartitionFileReader fileReader = + new LocalMapPartitionFileReader( + dataChecksumEnabled, + partitionIndex, + Math.min( + partitionIndex + numPartitions - 1, + StorageTestUtils.NUM_REDUCE_PARTITIONS - 1), + partitionFile); + fileReader.open(); + fileReaders.add(fileReader); + partitionIndex += numPartitions; + } + + int buffersRead = 0; + ByteBuffer buffer = ByteBuffer.allocate(StorageTestUtils.DATA_BUFFER_SIZE); + while (!fileReaders.isEmpty()) { + LocalMapPartitionFileReader fileReader = fileReaders.poll(); + if (!fileReader.hasRemaining()) { + fileReader.finishReading(); + continue; + } + fileReaders.add(fileReader); + + buffer.clear(); + fileReader.readBuffer(buffer); + ++buffersRead; + assertEquals(ByteBuffer.wrap(StorageTestUtils.DATA_BYTES), buffer); + } + return buffersRead; + } +} diff --git a/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/LocalMapPartitionFileTest.java b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/LocalMapPartitionFileTest.java new file mode 100644 index 00000000..b1fe10e5 --- /dev/null +++ b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/LocalMapPartitionFileTest.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.storage.exception.FileCorruptedException; +import com.alibaba.flink.shuffle.storage.utils.StorageTestUtils; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.IOException; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.FileChannel; +import java.nio.file.Files; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +/** Tests for {@link LocalMapPartitionFile}. */ +@RunWith(Parameterized.class) +public class LocalMapPartitionFileTest { + + @Rule public final TemporaryFolder temporaryFolder = new TemporaryFolder(); + + private final boolean dataChecksumEnabled; + + @Parameterized.Parameters + public static Object[] data() { + return new Boolean[] {true, false}; + } + + public LocalMapPartitionFileTest(boolean dataChecksumEnabled) { + this.dataChecksumEnabled = dataChecksumEnabled; + } + + @Test + public void testConsumable() throws Exception { + LocalMapPartitionFile partitionFile1 = createPartitionFile(); + assertTrue(partitionFile1.isConsumable()); + Files.delete(partitionFile1.getFileMeta().getDataFilePath()); + assertFalse(partitionFile1.isConsumable()); + partitionFile1.deleteFile(); + + LocalMapPartitionFile partitionFile2 = createPartitionFile(); + assertTrue(partitionFile2.isConsumable()); + Files.delete(partitionFile2.getFileMeta().getIndexFilePath()); + assertFalse(partitionFile2.isConsumable()); + partitionFile2.deleteFile(); + + LocalMapPartitionFile partitionFile3 = createPartitionFile(); + assertTrue(partitionFile3.isConsumable()); + partitionFile3.setConsumable(false); + assertFalse(partitionFile3.isConsumable()); + assertEquals(0, CommonUtils.checkNotNull(temporaryFolder.getRoot().list()).length); + } + + @Test + public void testOpenAndCloseFile() throws Exception { + LocalMapPartitionFile partitionFile = createPartitionFile(); + + Object reader1 = new Object(); + partitionFile.openFile(reader1); + + FileChannel dataChannel = CommonUtils.checkNotNull(partitionFile.getDataReadingChannel()); + FileChannel indexChannel = CommonUtils.checkNotNull(partitionFile.getIndexReadingChannel()); + assertTrue(dataChannel.isOpen()); + assertTrue(indexChannel.isOpen()); + + Object reader2 = new Object(); + partitionFile.openFile(reader2); + + assertTrue(dataChannel.isOpen()); + assertTrue(indexChannel.isOpen()); + + partitionFile.closeFile(reader1); + + assertTrue(dataChannel.isOpen()); + assertTrue(indexChannel.isOpen()); + + partitionFile.closeFile(reader2); + + assertFalse(dataChannel.isOpen()); + assertFalse(indexChannel.isOpen()); + assertNull(partitionFile.getDataReadingChannel()); + assertNull(partitionFile.getIndexReadingChannel()); + } + + @Test + public void testDeleteFile() throws Exception { + LocalMapPartitionFile partitionFile = createPartitionFile(); + + Object reader = new Object(); + partitionFile.openFile(reader); + + partitionFile.deleteFile(); + + assertNull(partitionFile.getDataReadingChannel()); + assertNull(partitionFile.getIndexReadingChannel()); + assertEquals(0, CommonUtils.checkNotNull(temporaryFolder.getRoot().list()).length); + } + + @Test + public void testOnError() throws Exception { + LocalMapPartitionFile partitionFile1 = 
createPartitionFile(); + partitionFile1.onError(new Exception("Test exception.")); + partitionFile1.onError(new Exception("Test exception.")); + partitionFile1.onError(new Exception("Test exception.")); + partitionFile1.onError(new Exception("Test exception.")); + assertTrue(partitionFile1.isConsumable()); + + partitionFile1.onError(new FileCorruptedException()); + assertFalse(partitionFile1.isConsumable()); + + LocalMapPartitionFile partitionFile2 = createPartitionFile(); + partitionFile2.onError(new ClosedChannelException()); + partitionFile2.onError(new ClosedChannelException()); + partitionFile2.onError(new ClosedChannelException()); + partitionFile2.onError(new ClosedChannelException()); + assertTrue(partitionFile2.isConsumable()); + + assertTrue(partitionFile2.isConsumable()); + partitionFile2.onError(new IOException("Test exception.")); + + assertTrue(partitionFile2.isConsumable()); + partitionFile2.onError(new IOException("Test exception.")); + + assertTrue(partitionFile2.isConsumable()); + partitionFile2.onError(new IOException("Test exception.")); + + assertTrue(partitionFile2.isConsumable()); + partitionFile2.onError(new IOException("Test exception.")); + assertFalse(partitionFile2.isConsumable()); + } + + private LocalMapPartitionFile createPartitionFile() throws Exception { + String baseDir = temporaryFolder.getRoot().getAbsolutePath(); + LocalMapPartitionFile partitionFile = StorageTestUtils.createLocalMapPartitionFile(baseDir); + StorageTestUtils.writeLocalMapPartitionFile( + partitionFile, 1, 1, 1, false, dataChecksumEnabled); + return partitionFile; + } +} diff --git a/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/LocalMapPartitionFileWriterTest.java b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/LocalMapPartitionFileWriterTest.java new file mode 100644 index 00000000..8a5f317a --- /dev/null +++ b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/LocalMapPartitionFileWriterTest.java @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.storage.utils.StorageTestUtils; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** Tests for {@link LocalMapPartitionFileWriter}. 
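Covers normal, empty, and empty-buffer writes, and verifies that illegal call sequences fail with IllegalStateException.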
*/ +@RunWith(Parameterized.class) +public class LocalMapPartitionFileWriterTest { + + @Rule public final TemporaryFolder temporaryFolder = new TemporaryFolder(); + + private final boolean dataChecksumEnabled; + + @Parameterized.Parameters + public static Object[] data() { + return new Boolean[] {true, false}; + } + + public LocalMapPartitionFileWriterTest(boolean dataChecksumEnabled) { + this.dataChecksumEnabled = dataChecksumEnabled; + } + + @Test + public void testWriteData() throws Exception { + LocalMapPartitionFile partitionFile = createPartitionFile(); + StorageTestUtils.writeLocalMapPartitionFile( + partitionFile, + 10, + StorageTestUtils.NUM_REDUCE_PARTITIONS, + 10, + false, + dataChecksumEnabled); + + assertTrue(partitionFile.isConsumable()); + assertEquals(2, CommonUtils.checkNotNull(temporaryFolder.getRoot().list()).length); + } + + @Test + public void testWriteEmptyData() throws Exception { + LocalMapPartitionFile partitionFile = createPartitionFile(); + StorageTestUtils.writeLocalMapPartitionFile( + partitionFile, + 0, + StorageTestUtils.NUM_REDUCE_PARTITIONS, + 0, + false, + dataChecksumEnabled); + + assertTrue(partitionFile.isConsumable()); + assertEquals(2, CommonUtils.checkNotNull(temporaryFolder.getRoot().list()).length); + + LocalMapPartitionFileMeta fileMeta = partitionFile.getFileMeta(); + assertEquals(0, new File(fileMeta.getDataFilePath().toString()).length()); + assertEquals( + LocalMapPartitionFile.INDEX_DATA_CHECKSUM_SIZE, + new File(fileMeta.getIndexFilePath().toString()).length()); + } + + @Test + public void testWriteWithEmptyReducePartition() throws Exception { + LocalMapPartitionFile partitionFile = createPartitionFile(); + StorageTestUtils.writeLocalMapPartitionFile( + partitionFile, + 10, + StorageTestUtils.NUM_REDUCE_PARTITIONS, + 10, + true, + dataChecksumEnabled); + + assertTrue(partitionFile.isConsumable()); + assertEquals(2, CommonUtils.checkNotNull(temporaryFolder.getRoot().list()).length); + } + + @Test + public void testWriteEmptyBuffer() throws Exception { + LocalMapPartitionFile partitionFile = createPartitionFile(); + LocalMapPartitionFileWriter fileWriter = + new LocalMapPartitionFileWriter(partitionFile, 2, dataChecksumEnabled); + fileWriter.open(); + + fileWriter.startRegion(false); + ByteBuffer data = ByteBuffer.allocateDirect(0); + fileWriter.writeBuffer(StorageTestUtils.createDataBuffer(data, 0)); + fileWriter.finishRegion(); + fileWriter.finishWriting(); + + assertTrue(partitionFile.isConsumable()); + assertEquals(2, CommonUtils.checkNotNull(temporaryFolder.getRoot().list()).length); + + LocalMapPartitionFileMeta fileMeta = partitionFile.getFileMeta(); + assertEquals(0, new File(fileMeta.getDataFilePath().toString()).length()); + assertEquals( + LocalMapPartitionFile.INDEX_DATA_CHECKSUM_SIZE, + new File(fileMeta.getIndexFilePath().toString()).length()); + } + + @Test(expected = IllegalStateException.class) + public void testStartRegionAfterClose() throws Exception { + LocalMapPartitionFileWriter fileWriter = + new LocalMapPartitionFileWriter(createPartitionFile(), 2, dataChecksumEnabled); + + fileWriter.close(); + fileWriter.startRegion(false); + } + + @Test(expected = IllegalStateException.class) + public void testFinishRegionAfterClose() throws Exception { + LocalMapPartitionFileWriter fileWriter = + new LocalMapPartitionFileWriter(createPartitionFile(), 2, dataChecksumEnabled); + + fileWriter.close(); + fileWriter.finishRegion(); + } + + @Test(expected = IllegalStateException.class) + public void testWriteBufferAfterClose() throws 
Exception { + LocalMapPartitionFileWriter fileWriter = + new LocalMapPartitionFileWriter(createPartitionFile(), 1, dataChecksumEnabled); + + fileWriter.close(); + fileWriter.writeBuffer( + StorageTestUtils.createDataBuffer(StorageTestUtils.createRandomData(), 0)); + } + + @Test(expected = IllegalStateException.class) + public void testFinishWritingAfterClose() throws Exception { + LocalMapPartitionFileWriter fileWriter = + new LocalMapPartitionFileWriter(createPartitionFile(), 2, dataChecksumEnabled); + + fileWriter.close(); + fileWriter.finishWriting(); + } + + @Test(expected = IllegalStateException.class) + public void testWriteNotInPartitionIndexOrder() throws IOException { + LocalMapPartitionFileWriter fileWriter = + new LocalMapPartitionFileWriter(createPartitionFile(), 2, dataChecksumEnabled); + + fileWriter.startRegion(false); + fileWriter.writeBuffer( + StorageTestUtils.createDataBuffer(StorageTestUtils.createRandomData(), 5)); + fileWriter.writeBuffer( + StorageTestUtils.createDataBuffer(StorageTestUtils.createRandomData(), 0)); + } + + @Test(expected = IllegalStateException.class) + public void testStartRegionBeforeFinish() throws IOException { + LocalMapPartitionFileWriter fileWriter = + new LocalMapPartitionFileWriter(createPartitionFile(), 1, dataChecksumEnabled); + + fileWriter.startRegion(false); + fileWriter.writeBuffer( + StorageTestUtils.createDataBuffer(StorageTestUtils.createRandomData(), 5)); + fileWriter.startRegion(false); + } + + @Test(expected = IllegalStateException.class) + public void testFinishWritingBeforeFinishRegion() throws Exception { + LocalMapPartitionFileWriter fileWriter = + new LocalMapPartitionFileWriter(createPartitionFile(), 1, dataChecksumEnabled); + + fileWriter.startRegion(false); + fileWriter.writeBuffer( + StorageTestUtils.createDataBuffer(StorageTestUtils.createRandomData(), 0)); + fileWriter.finishWriting(); + } + + private LocalMapPartitionFile createPartitionFile() { + return StorageTestUtils.createLocalMapPartitionFile( + temporaryFolder.getRoot().getAbsolutePath()); + } +} diff --git a/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/SSDOnlyLocalFileMapPartitionFactoryTest.java b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/SSDOnlyLocalFileMapPartitionFactoryTest.java new file mode 100644 index 00000000..a820bcbe --- /dev/null +++ b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/SSDOnlyLocalFileMapPartitionFactoryTest.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.exception.ConfigurationException; +import com.alibaba.flink.shuffle.core.config.StorageOptions; +import com.alibaba.flink.shuffle.core.storage.StorageMeta; +import com.alibaba.flink.shuffle.core.storage.StorageType; +import com.alibaba.flink.shuffle.storage.utils.StorageTestUtils; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Properties; + +import static org.junit.Assert.assertEquals; + +/** Tests for {@link SSDOnlyLocalFileMapPartitionFactory}. */ +public class SSDOnlyLocalFileMapPartitionFactoryTest { + + @Rule public final TemporaryFolder temporaryFolder1 = new TemporaryFolder(); + + @Rule public final TemporaryFolder temporaryFolder2 = new TemporaryFolder(); + + @Test(expected = ConfigurationException.class) + public void testSsdOnlyWithoutValidSsdDataDir() { + SSDOnlyLocalFileMapPartitionFactory partitionFactory = + new SSDOnlyLocalFileMapPartitionFactory(); + Properties properties = new Properties(); + properties.setProperty( + StorageOptions.STORAGE_LOCAL_DATA_DIRS.key(), + "[HDD]" + temporaryFolder1.getRoot().getAbsolutePath()); + partitionFactory.initialize(new Configuration(properties)); + } + + @Test + public void testSSDOnly() { + SSDOnlyLocalFileMapPartitionFactory partitionFactory = + new SSDOnlyLocalFileMapPartitionFactory(); + Properties properties = new Properties(); + properties.setProperty( + StorageOptions.STORAGE_LOCAL_DATA_DIRS.key(), + String.format( + "[SSD]%s,[HDD]%s", + temporaryFolder1.getRoot().getAbsolutePath(), + temporaryFolder2.getRoot().getAbsolutePath())); + partitionFactory.initialize(new Configuration(properties)); + + for (int i = 0; i < 100; ++i) { + StorageMeta storageMeta = partitionFactory.getNextDataStorageMeta(); + assertEquals( + StorageTestUtils.getStoragePath(temporaryFolder1), + storageMeta.getStoragePath()); + assertEquals(StorageType.SSD, storageMeta.getStorageType()); + } + } +} diff --git a/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/TestMapPartition.java b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/TestMapPartition.java new file mode 100644 index 00000000..02386549 --- /dev/null +++ b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/partition/TestMapPartition.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
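For context on the factory test above: storage directory selection is driven entirely by the `[SSD]`/`[HDD]` prefixes on `StorageOptions.STORAGE_LOCAL_DATA_DIRS`. The following is a minimal, hypothetical sketch of that convention — the class and method names are illustrative and not code from this PR; the real parsing happens inside the factory's `initialize` method:

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayList;
import java.util.List;
import java.util.Map.Entry;

// Hypothetical sketch of the "[SSD]path,[HDD]path" convention exercised above.
final class DataDirsSketch {
    static List<Entry<String, String>> parse(String value) {
        List<Entry<String, String>> dirs = new ArrayList<>();
        for (String entry : value.split(",")) {
            String path = entry.trim();
            String type = "HDD"; // assumed default disk type when no prefix is given
            if (path.startsWith("[SSD]")) {
                type = "SSD";
                path = path.substring("[SSD]".length());
            } else if (path.startsWith("[HDD]")) {
                path = path.substring("[HDD]".length());
            }
            dirs.add(new SimpleEntry<>(type, path));
        }
        return dirs;
    }
}
```

An SSD-only factory would then reject a configuration whose parsed list contains no SSD entry, which is exactly what the `ConfigurationException` test checks.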
+ */ + +package com.alibaba.flink.shuffle.storage.partition; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.listener.BacklogListener; +import com.alibaba.flink.shuffle.core.listener.DataListener; +import com.alibaba.flink.shuffle.core.listener.DataRegionCreditListener; +import com.alibaba.flink.shuffle.core.listener.FailureListener; +import com.alibaba.flink.shuffle.core.storage.DataPartition; +import com.alibaba.flink.shuffle.core.storage.DataPartitionReader; +import com.alibaba.flink.shuffle.core.storage.DataPartitionWriter; +import com.alibaba.flink.shuffle.core.storage.MapPartitionMeta; +import com.alibaba.flink.shuffle.core.storage.PartitionedDataStore; +import com.alibaba.flink.shuffle.storage.utils.StorageTestUtils; + +import javax.annotation.Nullable; + +import java.util.Properties; + +/** A fake {@link DataPartition} implementation for tests. */ +public class TestMapPartition extends BaseMapPartition { + + private final TestPartitionWritingTask writingTask; + + public TestMapPartition(PartitionedDataStore dataStore) { + super( + dataStore, + dataStore + .getExecutorPool(StorageTestUtils.getStorageMeta()) + .getSingleThreadExecutor()); + + this.writingTask = new TestPartitionWritingTask(new Configuration(new Properties())); + } + + @Override + public MapPartitionMeta getPartitionMeta() { + return null; + } + + @Override + public DataPartitionType getPartitionType() { + return null; + } + + @Override + public boolean isConsumable() { + return false; + } + + @Override + protected DataPartitionReader getDataPartitionReader( + int startPartitionIndex, + int endPartitionIndex, + DataListener dataListener, + BacklogListener backlogListener, + FailureListener failureListener) { + return null; + } + + @Override + protected DataPartitionWriter getDataPartitionWriter( + MapPartitionID mapPartitionID, + DataRegionCreditListener dataRegionCreditListener, + FailureListener failureListener) { + return null; + } + + @Override + public TestPartitionWritingTask getPartitionWritingTask() { + return writingTask; + } + + @Override + public MapPartitionReadingTask getPartitionReadingTask() { + return null; + } + + /** A fake {@link DataPartitionWritingTask} implementation for tests. */ + final class TestPartitionWritingTask extends BaseMapPartition.MapPartitionWritingTask { + + private int numWritingTriggers; + + protected TestPartitionWritingTask(Configuration configuration) { + super(configuration); + } + + @Override + public void allocateResources() {} + + @Override + public void triggerWriting() { + ++numWritingTriggers; + } + + @Override + public void release(@Nullable Throwable throwable) {} + + @Override + public void process() {} + + public int getNumWritingTriggers() { + return numWritingTriggers; + } + } +} diff --git a/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/utils/IOUtilsTest.java b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/utils/IOUtilsTest.java new file mode 100644 index 00000000..1d94c330 --- /dev/null +++ b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/utils/IOUtilsTest.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.utils; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.StandardOpenOption; +import java.util.Random; + +import static org.junit.Assert.assertEquals; + +/** Tests for {@link IOUtils}. */ +public class IOUtilsTest { + + @Rule public final TemporaryFolder temporaryFolder = new TemporaryFolder(); + + @Test + public void testWriteBuffers() throws Exception { + int numBuffers = 4000; + int bufferSize = 4096; + Random random = new Random(); + File file = temporaryFolder.newFile(); + + try (FileChannel fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.WRITE)) { + long totalBytes = 0; + ByteBuffer[] buffers = new ByteBuffer[numBuffers]; + for (int i = 0; i < numBuffers; ++i) { + ByteBuffer buffer = + CommonUtils.allocateDirectByteBuffer(random.nextInt(bufferSize) + 1); + buffer.put(StorageTestUtils.DATA_BYTES, 0, buffer.capacity()); + buffer.flip(); + buffers[i] = buffer; + totalBytes += buffer.capacity(); + } + + IOUtils.writeBuffers(fileChannel, buffers); + assertEquals(totalBytes, fileChannel.size()); + } + } +} diff --git a/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/utils/StorageTestUtils.java b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/utils/StorageTestUtils.java new file mode 100644 index 00000000..2d59fd7c --- /dev/null +++ b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/utils/StorageTestUtils.java @@ -0,0 +1,342 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
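The `IOUtilsTest` above only asserts the final file size, so it is worth spelling out what a gathering write has to guarantee. A minimal sketch, assuming `IOUtils.writeBuffers` loops over `FileChannel#write(ByteBuffer[])` until every buffer is drained — the actual implementation in this PR may differ:

```java
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;

final class GatheringWriteSketch {
    // A gathering write is not atomic: FileChannel#write(ByteBuffer[]) may
    // return after writing only part of the data, so keep calling it until
    // all buffers are fully consumed.
    static void writeBuffers(FileChannel channel, ByteBuffer[] buffers) throws IOException {
        long remaining = 0;
        for (ByteBuffer buffer : buffers) {
            remaining += buffer.remaining();
        }
        while (remaining > 0) {
            remaining -= channel.write(buffers);
        }
    }
}
```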
+ */ + +package com.alibaba.flink.shuffle.storage.utils; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.config.MemoryOptions; +import com.alibaba.flink.shuffle.core.config.StorageOptions; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.ids.ReducePartitionID; +import com.alibaba.flink.shuffle.core.listener.BacklogListener; +import com.alibaba.flink.shuffle.core.listener.DataCommitListener; +import com.alibaba.flink.shuffle.core.listener.DataListener; +import com.alibaba.flink.shuffle.core.listener.DataRegionCreditListener; +import com.alibaba.flink.shuffle.core.listener.FailureListener; +import com.alibaba.flink.shuffle.core.listener.PartitionStateListener; +import com.alibaba.flink.shuffle.core.memory.Buffer; +import com.alibaba.flink.shuffle.core.memory.BufferDispatcher; +import com.alibaba.flink.shuffle.core.memory.BufferRecycler; +import com.alibaba.flink.shuffle.core.storage.DataPartitionWritingView; +import com.alibaba.flink.shuffle.core.storage.PartitionedDataStore; +import com.alibaba.flink.shuffle.core.storage.ReadingViewContext; +import com.alibaba.flink.shuffle.core.storage.StorageMeta; +import com.alibaba.flink.shuffle.core.storage.StorageType; +import com.alibaba.flink.shuffle.core.storage.WritingViewContext; +import com.alibaba.flink.shuffle.storage.datastore.NoOpPartitionedDataStore; +import com.alibaba.flink.shuffle.storage.datastore.PartitionedDataStoreImpl; +import com.alibaba.flink.shuffle.storage.partition.BufferOrMarker; +import com.alibaba.flink.shuffle.storage.partition.LocalFileMapPartitionFactory; +import com.alibaba.flink.shuffle.storage.partition.LocalFileMapPartitionMeta; +import com.alibaba.flink.shuffle.storage.partition.LocalMapPartitionFile; +import com.alibaba.flink.shuffle.storage.partition.LocalMapPartitionFileMeta; +import com.alibaba.flink.shuffle.storage.partition.LocalMapPartitionFileWriter; + +import org.junit.rules.TemporaryFolder; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Random; +import java.util.Set; + +/** Utility methods commonly used by tests of storage layer. 
*/ +public class StorageTestUtils { + + public static final int NUM_REDUCE_PARTITIONS = 10; + + public static final int DATA_BUFFER_SIZE = 32 * 1024; + + public static final byte[] DATA_BYTES = CommonUtils.randomBytes(DATA_BUFFER_SIZE); + + public static final PartitionedDataStore NO_OP_PARTITIONED_DATA_STORE = + new NoOpPartitionedDataStore(); + + public static final JobID JOB_ID = new JobID(CommonUtils.randomBytes(16)); + + public static final DataSetID DATA_SET_ID = new DataSetID(CommonUtils.randomBytes(16)); + + public static final FailureListener NO_OP_FAILURE_LISTENER = (ignored) -> {}; + + public static final MapPartitionID MAP_PARTITION_ID = + new MapPartitionID(CommonUtils.randomBytes(16)); + + public static final DataRegionCreditListener NO_OP_CREDIT_LISTENER = (ignored1, ignored2) -> {}; + + public static final DataListener NO_OP_DATA_LISTENER = () -> {}; + + public static final BacklogListener NO_OP_BACKLOG_LISTENER = (ignored) -> {}; + + public static final BufferRecycler NO_OP_BUFFER_RECYCLER = (ignored) -> {}; + + public static final DataCommitListener NO_OP_DATA_COMMIT_LISTENER = () -> {}; + + public static final String LOCAL_FILE_MAP_PARTITION_FACTORY = + LocalFileMapPartitionFactory.class.getName(); + + public static LocalMapPartitionFileMeta createLocalMapPartitionFileMeta() { + String basePath = CommonUtils.randomHexString(128) + "/"; + String fileName = CommonUtils.randomHexString(32); + Random random = new Random(); + int numReducePartitions = random.nextInt(Integer.MAX_VALUE) + 1; + + return new LocalMapPartitionFileMeta( + basePath + fileName, + numReducePartitions, + LocalMapPartitionFile.LATEST_STORAGE_VERSION); + } + + public static LocalMapPartitionFile createLocalMapPartitionFile(String baseDir) { + String basePath = baseDir + "/"; + String fileName = CommonUtils.randomHexString(32); + LocalMapPartitionFileMeta fileMeta = + new LocalMapPartitionFileMeta( + basePath + fileName, + NUM_REDUCE_PARTITIONS, + LocalMapPartitionFile.LATEST_STORAGE_VERSION); + return new LocalMapPartitionFile(fileMeta, 3, false); + } + + public static void writeLocalMapPartitionFile( + LocalMapPartitionFile partitionFile, + int numRegions, + int numReducePartitions, + int numBuffers, + boolean withEmptyReducePartitions, + boolean dataChecksumEnabled) + throws Exception { + writeLocalMapPartitionFile( + partitionFile, + numRegions, + numReducePartitions, + numBuffers, + withEmptyReducePartitions, + -1, + dataChecksumEnabled); + } + + public static void writeLocalMapPartitionFile( + LocalMapPartitionFile partitionFile, + int numRegions, + int numReducePartitions, + int numBuffers, + boolean withEmptyReducePartitions, + int broadcastRegionIndex, + boolean dataChecksumEnabled) + throws Exception { + LocalMapPartitionFileWriter fileWriter = + new LocalMapPartitionFileWriter(partitionFile, 2, dataChecksumEnabled); + fileWriter.open(); + for (int regionIndex = 0; regionIndex < numRegions; ++regionIndex) { + if (regionIndex == broadcastRegionIndex) { + fileWriter.startRegion(true); + for (int bufferIndex = 0; bufferIndex < numBuffers; ++bufferIndex) { + fileWriter.writeBuffer(createDataBuffer(createRandomData(), 0)); + } + } else { + fileWriter.startRegion(false); + for (int partition = 0; partition < numReducePartitions; ++partition) { + if (withEmptyReducePartitions && partition % 2 == 0) { + continue; + } + for (int bufferIndex = 0; bufferIndex < numBuffers; ++bufferIndex) { + fileWriter.writeBuffer(createDataBuffer(createRandomData(), partition)); + } + } + } + fileWriter.finishRegion(); + 
} + fileWriter.finishWriting(); + } + + public static ByteBuffer createRandomData() { + ByteBuffer data = ByteBuffer.allocateDirect(DATA_BUFFER_SIZE); + data.put(DATA_BYTES); + data.flip(); + return data; + } + + public static BufferOrMarker.DataBuffer createDataBuffer(ByteBuffer data, int channelIndex) { + return new BufferOrMarker.DataBuffer( + StorageTestUtils.MAP_PARTITION_ID, + new ReducePartitionID(channelIndex), + new Buffer(data, StorageTestUtils.NO_OP_BUFFER_RECYCLER, data.remaining())); + } + + public static void assertNoBufferLeaking() throws Exception { + assertNoBufferLeaking(NO_OP_PARTITIONED_DATA_STORE); + } + + public static void assertNoBufferLeaking(PartitionedDataStore dataStore) throws Exception { + assertNoBufferLeaking(dataStore.getWritingBufferDispatcher()); + assertNoBufferLeaking(dataStore.getReadingBufferDispatcher()); + } + + public static void assertNoBufferLeaking(BufferDispatcher bufferDispatcher) throws Exception { + while (bufferDispatcher.numAvailableBuffers() != bufferDispatcher.numTotalBuffers()) { + Thread.sleep(100); + } + } + + public static PartitionedDataStoreImpl createPartitionedDataStore( + String storageDir, PartitionStateListener partitionStateListener) { + Properties properties = new Properties(); + properties.setProperty( + MemoryOptions.MEMORY_SIZE_FOR_DATA_READING.key(), + MemoryOptions.MIN_VALID_MEMORY_SIZE.getBytes() + "b"); + properties.setProperty( + MemoryOptions.MEMORY_SIZE_FOR_DATA_WRITING.key(), + MemoryOptions.MIN_VALID_MEMORY_SIZE.getBytes() + "b"); + properties.setProperty(StorageOptions.STORAGE_LOCAL_DATA_DIRS.key(), storageDir); + Configuration configuration = new Configuration(properties); + return new PartitionedDataStoreImpl(configuration, partitionStateListener); + } + + public static void createEmptyDataPartition(PartitionedDataStore dataStore) throws Exception { + DataPartitionWritingView writingView = + CommonUtils.checkNotNull(createDataPartitionWritingView(dataStore)); + + TestDataCommitListener commitListener = new TestDataCommitListener(); + writingView.finish(commitListener); + commitListener.waitForDataCommission(); + } + + public static DataPartitionWritingView createDataPartitionWritingView( + PartitionedDataStore dataStore) throws Exception { + return dataStore.createDataPartitionWritingView( + new WritingViewContext( + JOB_ID, + DATA_SET_ID, + MAP_PARTITION_ID, + MAP_PARTITION_ID, + NUM_REDUCE_PARTITIONS, + LOCAL_FILE_MAP_PARTITION_FACTORY, + NO_OP_CREDIT_LISTENER, + NO_OP_FAILURE_LISTENER)); + } + + public static DataPartitionWritingView createDataPartitionWritingView( + PartitionedDataStore dataStore, + DataRegionCreditListener dataRegionCreditListener, + FailureListener failureListener) + throws Exception { + return dataStore.createDataPartitionWritingView( + new WritingViewContext( + JOB_ID, + DATA_SET_ID, + MAP_PARTITION_ID, + MAP_PARTITION_ID, + NUM_REDUCE_PARTITIONS, + LOCAL_FILE_MAP_PARTITION_FACTORY, + dataRegionCreditListener, + failureListener)); + } + + public static DataPartitionWritingView createDataPartitionWritingView( + PartitionedDataStore dataStore, FailureListener failureListener) throws Exception { + return dataStore.createDataPartitionWritingView( + new WritingViewContext( + JOB_ID, + DATA_SET_ID, + MAP_PARTITION_ID, + MAP_PARTITION_ID, + NUM_REDUCE_PARTITIONS, + LOCAL_FILE_MAP_PARTITION_FACTORY, + NO_OP_CREDIT_LISTENER, + failureListener)); + } + + public static Map<JobID, Map<DataSetID, Set<DataPartitionID>>> getDefaultDataPartition() { + return getDataPartitions(Collections.singletonList(MAP_PARTITION_ID)); + } + + public static Map<JobID, Map<DataSetID, Set<DataPartitionID>>> getDataPartitions( + List<MapPartitionID> mapPartitionIDS) { + Map<JobID, Map<DataSetID, Set<DataPartitionID>>> dataPartitions = new HashMap<>(); + Map<DataSetID, Set<DataPartitionID>> dataSetPartitions = new HashMap<>(); + dataPartitions.put(JOB_ID, dataSetPartitions); + dataSetPartitions.put(DATA_SET_ID, new HashSet<>(mapPartitionIDS)); + return dataPartitions; + } + + public static void createDataPartitionReadingView( + PartitionedDataStore dataStore, MapPartitionID mapPartitionID) throws Exception { + dataStore.createDataPartitionReadingView( + new ReadingViewContext( + DATA_SET_ID, + mapPartitionID, + 0, + 0, + NO_OP_DATA_LISTENER, + NO_OP_BACKLOG_LISTENER, + NO_OP_FAILURE_LISTENER)); + } + + public static void createDataPartitionReadingView( + PartitionedDataStore dataStore, int reduceIndex) throws Exception { + dataStore.createDataPartitionReadingView( + new ReadingViewContext( + DATA_SET_ID, + MAP_PARTITION_ID, + reduceIndex, + reduceIndex, + NO_OP_DATA_LISTENER, + NO_OP_BACKLOG_LISTENER, + NO_OP_FAILURE_LISTENER)); + } + + public static LocalMapPartitionFileMeta createLocalMapPartitionFileMeta( + TemporaryFolder temporaryFolder, boolean createFile) throws IOException { + String fileName = CommonUtils.randomHexString(32); + if (createFile) { + temporaryFolder.newFile(fileName + LocalMapPartitionFile.DATA_FILE_SUFFIX); + temporaryFolder.newFile(fileName + LocalMapPartitionFile.INDEX_FILE_SUFFIX); + } + return new LocalMapPartitionFileMeta( + getStoragePath(temporaryFolder) + fileName, + NUM_REDUCE_PARTITIONS, + LocalMapPartitionFile.LATEST_STORAGE_VERSION); + } + + public static LocalFileMapPartitionMeta createLocalFileMapPartitionMeta( + LocalMapPartitionFileMeta fileMeta, StorageMeta storageMeta) { + return new LocalFileMapPartitionMeta( + JOB_ID, DATA_SET_ID, MAP_PARTITION_ID, fileMeta, storageMeta); + } + + public static StorageMeta getStorageMeta() { + return new StorageMeta("/tmp/", StorageType.SSD); + } + + public static StorageMeta getStorageMeta(TemporaryFolder temporaryFolder) { + return new StorageMeta(getStoragePath(temporaryFolder), StorageType.SSD); + } + + public static String getStoragePath(TemporaryFolder temporaryFolder) { + return temporaryFolder.getRoot().getAbsolutePath() + "/"; + } +}
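The helpers above are typically combined as in `createEmptyDataPartition`: open a writing view, finish it, and block until the store acknowledges the commit. A condensed usage sketch, using only the utilities defined above (the surrounding test class and its imports are assumed):

```java
// Usage sketch mirroring StorageTestUtils#createEmptyDataPartition: finish a
// writing view and block until the data store acknowledges the commit.
static void writeEmptyPartition(PartitionedDataStore dataStore) throws Exception {
    DataPartitionWritingView writingView =
            CommonUtils.checkNotNull(
                    StorageTestUtils.createDataPartitionWritingView(dataStore));
    TestDataCommitListener commitListener = new TestDataCommitListener();
    writingView.finish(commitListener);
    // Blocks on the listener's CountDownLatch until notifyDataCommitted() fires.
    commitListener.waitForDataCommission();
}
```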
diff --git a/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/utils/TestDataCommitListener.java b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/utils/TestDataCommitListener.java new file mode 100644 index 00000000..616c295b --- /dev/null +++ b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/utils/TestDataCommitListener.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.utils; + +import com.alibaba.flink.shuffle.core.listener.DataCommitListener; + +import java.util.concurrent.CountDownLatch; + +/** A {@link DataCommitListener} implementation for tests. */ +public class TestDataCommitListener implements DataCommitListener { + + private final CountDownLatch latch = new CountDownLatch(1); + + @Override + public void notifyDataCommitted() { + latch.countDown(); + } + + public void waitForDataCommission() throws InterruptedException { + latch.await(); + } +} diff --git a/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/utils/TestDataListener.java b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/utils/TestDataListener.java new file mode 100644 index 00000000..e055d1b4 --- /dev/null +++ b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/utils/TestDataListener.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.utils; + +import com.alibaba.flink.shuffle.core.listener.DataListener; + +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +/** A {@link DataListener} implementation for tests. */ +public class TestDataListener implements DataListener { + + private final BlockingQueue<Object> notifications = new LinkedBlockingQueue<>(); + + @Override + public void notifyDataAvailable() { + notifications.add(new Object()); + } + + public Object waitData(long timeout) throws InterruptedException { + return notifications.poll(timeout, TimeUnit.MILLISECONDS); + } +} diff --git a/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/utils/TestDataRegionCreditListener.java b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/utils/TestDataRegionCreditListener.java new file mode 100644 index 00000000..bca0cf1c --- /dev/null +++ b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/utils/TestDataRegionCreditListener.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.utils; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.listener.DataRegionCreditListener; + +import java.util.ArrayDeque; +import java.util.Queue; + +/** A {@link DataRegionCreditListener} implementation for tests. */ +public class TestDataRegionCreditListener implements DataRegionCreditListener { + + private final Queue<Object> buffers = new ArrayDeque<>(); + + private int currentDataRegion = -1; + + @Override + public void notifyCredits(int availableCredits, int dataRegionIndex) { + CommonUtils.checkState(availableCredits > 0, "Must be positive."); + synchronized (buffers) { + if (dataRegionIndex != currentDataRegion) { + currentDataRegion = dataRegionIndex; + buffers.clear(); + } + + while (availableCredits > 0) { + buffers.add(new Object()); + --availableCredits; + } + buffers.notify(); + } + } + + public Object take(long timeout, int dataRegionIndex) throws InterruptedException { + synchronized (buffers) { + while (buffers.isEmpty() || dataRegionIndex != currentDataRegion) { + buffers.wait(timeout); + if (timeout > 0) { + break; + } + } + return buffers.poll(); + } + } +} diff --git a/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/utils/TestFailureListener.java b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/utils/TestFailureListener.java new file mode 100644 index 00000000..5abe8a90 --- /dev/null +++ b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/utils/TestFailureListener.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.utils; + +import com.alibaba.flink.shuffle.core.listener.FailureListener; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +/** A {@link FailureListener} implementation for tests.
*/ +public class TestFailureListener implements FailureListener { + + private final CompletableFuture<Throwable> throwable = new CompletableFuture<>(); + + @Override + public void notifyFailure(Throwable throwable) { + this.throwable.complete(throwable); + } + + public boolean isFailed() { + return throwable.isDone(); + } + + public Throwable getFailure() throws ExecutionException, InterruptedException { + return throwable.get(); + } +} diff --git a/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/utils/TestPartitionStateListener.java b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/utils/TestPartitionStateListener.java new file mode 100644 index 00000000..59c67de2 --- /dev/null +++ b/shuffle-storage/src/test/java/com/alibaba/flink/shuffle/storage/utils/TestPartitionStateListener.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.storage.utils; + +import com.alibaba.flink.shuffle.core.listener.PartitionStateListener; +import com.alibaba.flink.shuffle.core.storage.DataPartitionMeta; + +import java.util.concurrent.atomic.AtomicInteger; + +/** A {@link PartitionStateListener} implementation for tests.
*/ +public class TestPartitionStateListener implements PartitionStateListener { + + private final AtomicInteger numCreated = new AtomicInteger(); + + private final AtomicInteger numRemoved = new AtomicInteger(); + + @Override + public void onPartitionCreated(DataPartitionMeta partitionMeta) { + numCreated.incrementAndGet(); + } + + @Override + public void onPartitionRemoved(DataPartitionMeta partitionMeta) { + numRemoved.incrementAndGet(); + } + + public int getNumCreated() { + return numCreated.get(); + } + + public int getNumRemoved() { + return numRemoved.get(); + } +} diff --git a/shuffle-storage/src/test/resources/data_for_storage_compatibility_test/E3F3CE037733E8C1D7344BF930535107.data b/shuffle-storage/src/test/resources/data_for_storage_compatibility_test/E3F3CE037733E8C1D7344BF930535107.data new file mode 100644 index 00000000..5024a9f8 Binary files /dev/null and b/shuffle-storage/src/test/resources/data_for_storage_compatibility_test/E3F3CE037733E8C1D7344BF930535107.data differ diff --git a/shuffle-storage/src/test/resources/data_for_storage_compatibility_test/E3F3CE037733E8C1D7344BF930535107.index b/shuffle-storage/src/test/resources/data_for_storage_compatibility_test/E3F3CE037733E8C1D7344BF930535107.index new file mode 100644 index 00000000..8e271fe8 Binary files /dev/null and b/shuffle-storage/src/test/resources/data_for_storage_compatibility_test/E3F3CE037733E8C1D7344BF930535107.index differ diff --git a/shuffle-storage/src/test/resources/data_for_storage_compatibility_test/storageV0.meta b/shuffle-storage/src/test/resources/data_for_storage_compatibility_test/storageV0.meta new file mode 100644 index 00000000..25a7185f Binary files /dev/null and b/shuffle-storage/src/test/resources/data_for_storage_compatibility_test/storageV0.meta differ diff --git a/shuffle-storage/src/test/resources/log4j2-test.properties b/shuffle-storage/src/test/resources/log4j2-test.properties new file mode 100644 index 00000000..d7fcb327 --- /dev/null +++ b/shuffle-storage/src/test/resources/log4j2-test.properties @@ -0,0 +1,26 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################ +# Set root logger level to OFF to not flood build logs +# set manually to INFO for debugging purposes +rootLogger.level=OFF +rootLogger.appenderRef.test.ref=TestLogger +appender.testlogger.name=TestLogger +appender.testlogger.type=CONSOLE +appender.testlogger.target=SYSTEM_ERR +appender.testlogger.layout.type=PatternLayout +appender.testlogger.layout.pattern=%-4r [%t] %-5p %c %x - %m%n diff --git a/shuffle-transfer/pom.xml b/shuffle-transfer/pom.xml new file mode 100644 index 00000000..24c0df41 --- /dev/null +++ b/shuffle-transfer/pom.xml @@ -0,0 +1,72 @@ +<?xml version="1.0" encoding="UTF-8"?> +<!-- +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--> +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <parent> + <groupId>com.alibaba.flink.shuffle</groupId> + <artifactId>flink-shuffle-parent</artifactId> + <version>1.0-SNAPSHOT</version> + </parent> + <modelVersion>4.0.0</modelVersion> + + <artifactId>shuffle-transfer</artifactId> + + <dependencies> + <dependency> + <groupId>com.alibaba.flink.shuffle</groupId> + <artifactId>shuffle-common</artifactId> + <version>${project.version}</version> + </dependency> + + <dependency> + <groupId>com.alibaba.flink.shuffle</groupId> + <artifactId>shuffle-core</artifactId> + <version>${project.version}</version> + </dependency> + + <dependency> + <groupId>com.alibaba.flink.shuffle</groupId> + <artifactId>shuffle-metrics</artifactId> + <version>${project.version}</version> + </dependency> + + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-shaded-netty</artifactId> + <version>4.1.49.Final-${flink.shaded.version}</version> + <scope>provided</scope> + </dependency> + + <dependency> + <groupId>org.apache.commons</groupId> + <artifactId>commons-lang3</artifactId> + <version>3.3.2</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.mockito</groupId> + <artifactId>mockito-core</artifactId> + <version>${mockito.version}</version> + <type>jar</type> + <scope>test</scope> + </dependency> + </dependencies> +</project> diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ChannelFutureListenerImpl.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ChannelFutureListenerImpl.java new file mode 100644 index 00000000..fbfbc696 --- /dev/null +++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ChannelFutureListenerImpl.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +import com.alibaba.flink.shuffle.common.functions.BiConsumerWithException; + +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFutureListener; + +/** A {@link ChannelFutureListener} implementation that allows customizing the processing logic.
*/ +public class ChannelFutureListenerImpl implements ChannelFutureListener { + + private final BiConsumerWithException<ChannelFuture, Throwable, Exception> errorHandler; + + public ChannelFutureListenerImpl( + BiConsumerWithException<ChannelFuture, Throwable, Exception> errorHandler) { + this.errorHandler = errorHandler; + } + + @Override + public void operationComplete(ChannelFuture channelFuture) throws Exception { + if (!channelFuture.isSuccess()) { + final Throwable cause; + if (channelFuture.cause() != null) { + cause = channelFuture.cause(); + } else { + cause = new IllegalStateException("Sending cancelled."); + } + errorHandler.accept(channelFuture, cause); + } + } +} diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/CloseChannelWhenFailure.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/CloseChannelWhenFailure.java new file mode 100644 index 00000000..b07acf96 --- /dev/null +++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/CloseChannelWhenFailure.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFutureListener; + +/** A {@link ChannelFutureListener} that closes the connection when sending fails. */ +public class CloseChannelWhenFailure implements ChannelFutureListener { + @Override + public void operationComplete(ChannelFuture channelFuture) throws Exception { + if (!channelFuture.isSuccess()) { + ChannelFutureListener.CLOSE.operationComplete(channelFuture); + } + } +}
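A short usage sketch for `ChannelFutureListenerImpl` above: attach it to a write so that a failed (or cancelled) send is funneled into a single error-handling path. `channel` and `message` are placeholders here, and the handler body is illustrative only — `DataSender`, later in this PR, shows the real pattern:

```java
// Illustrative only: a failed or cancelled writeAndFlush invokes the handler
// with the failure cause (or the synthetic "Sending cancelled." exception).
static void sendWithErrorHandling(
        org.apache.flink.shaded.netty4.io.netty.channel.Channel channel, Object message) {
    ChannelFutureListenerImpl errorListener =
            new ChannelFutureListenerImpl(
                    (channelFuture, cause) -> {
                        // Custom propagation/cleanup goes here; close as a fallback.
                        channelFuture.channel().close();
                    });
    channel.writeAndFlush(message).addListener(errorListener);
}
```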
diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/CommonTransferMessageDecoder.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/CommonTransferMessageDecoder.java new file mode 100644 index 00000000..43873027 --- /dev/null +++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/CommonTransferMessageDecoder.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +import com.alibaba.flink.shuffle.transfer.TransferMessage.BacklogAnnouncement; +import com.alibaba.flink.shuffle.transfer.TransferMessage.CloseChannel; +import com.alibaba.flink.shuffle.transfer.TransferMessage.CloseConnection; +import com.alibaba.flink.shuffle.transfer.TransferMessage.ErrorResponse; +import com.alibaba.flink.shuffle.transfer.TransferMessage.Heartbeat; +import com.alibaba.flink.shuffle.transfer.TransferMessage.ReadAddCredit; +import com.alibaba.flink.shuffle.transfer.TransferMessage.ReadHandshakeRequest; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteAddCredit; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteFinish; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteFinishCommit; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteHandshakeRequest; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteRegionFinish; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteRegionStart; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** A {@link TransferMessageDecoder} to decode messages which don't carry shuffle data. */ +public class CommonTransferMessageDecoder extends TransferMessageDecoder { + + private static final Logger LOG = LoggerFactory.getLogger(CommonTransferMessageDecoder.class); + + private ByteBuf messageBuffer; + + @Override + public void onNewMessageReceived(ChannelHandlerContext ctx, int msgId, int messageLength) { + super.onNewMessageReceived(ctx, msgId, messageLength); + messageBuffer = ctx.alloc().directBuffer(messageLength); + messageBuffer.clear(); + ensureBufferCapacity(); + } + + @Override + public TransferMessageDecoder.DecodingResult onChannelRead(ByteBuf data) throws Exception { + boolean accumulationFinished = + DecodingUtil.accumulate( + messageBuffer, data, messageLength, messageBuffer.readableBytes()); + if (!accumulationFinished) { + return TransferMessageDecoder.DecodingResult.NOT_FINISHED; + } + + switch (msgId) { + case ErrorResponse.ID: + return TransferMessageDecoder.DecodingResult.fullMessage( + ErrorResponse.readFrom(messageBuffer)); + case WriteHandshakeRequest.ID: + return TransferMessageDecoder.DecodingResult.fullMessage( + WriteHandshakeRequest.readFrom(messageBuffer)); + case WriteAddCredit.ID: + return TransferMessageDecoder.DecodingResult.fullMessage( + WriteAddCredit.readFrom(messageBuffer)); + case WriteRegionStart.ID: + return TransferMessageDecoder.DecodingResult.fullMessage( + WriteRegionStart.readFrom(messageBuffer)); + case WriteRegionFinish.ID: + return TransferMessageDecoder.DecodingResult.fullMessage( + WriteRegionFinish.readFrom(messageBuffer)); + case WriteFinish.ID: + return TransferMessageDecoder.DecodingResult.fullMessage( + WriteFinish.readFrom(messageBuffer)); + case WriteFinishCommit.ID: + return TransferMessageDecoder.DecodingResult.fullMessage( + WriteFinishCommit.readFrom(messageBuffer)); + case ReadHandshakeRequest.ID: + return TransferMessageDecoder.DecodingResult.fullMessage( + ReadHandshakeRequest.readFrom(messageBuffer)); + case ReadAddCredit.ID: + return TransferMessageDecoder.DecodingResult.fullMessage( + ReadAddCredit.readFrom(messageBuffer)); + case CloseChannel.ID: + return TransferMessageDecoder.DecodingResult.fullMessage( + CloseChannel.readFrom(messageBuffer)); + case CloseConnection.ID: + return TransferMessageDecoder.DecodingResult.fullMessage( + CloseConnection.readFrom(messageBuffer)); + case Heartbeat.ID: + return TransferMessageDecoder.DecodingResult.fullMessage( + Heartbeat.readFrom(messageBuffer)); + case BacklogAnnouncement.ID: + return TransferMessageDecoder.DecodingResult.fullMessage( + BacklogAnnouncement.readFrom(messageBuffer)); + default: + // do not throw any exception, to keep better compatibility + LOG.debug("Received unknown message from producer: {}", msgId); + return DecodingResult.UNKNOWN_MESSAGE; + } + } + + /** + * Ensures the message accumulation buffer has enough capacity for the current message. + */ + private void ensureBufferCapacity() { + if (messageBuffer.capacity() < messageLength) { + messageBuffer.capacity(messageLength); + } + } + + @Override + public void close() { + if (isClosed) { + return; + } + if (messageBuffer != null) { + messageBuffer.release(); + } + isClosed = true; + } +}
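The decoder above delegates byte accumulation to `DecodingUtil.accumulate`, which is not part of this excerpt. The following is only an assumed sketch of its contract — copy at most the missing bytes into the staging buffer and report whether a full frame is now available — and may not match the PR's actual helper:

```java
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;

final class AccumulateSketch {
    // Assumed contract (sketch): copy at most the missing number of bytes
    // from 'data' into 'target' and return true once 'target' holds a
    // complete message of 'targetLength' bytes.
    static boolean accumulate(ByteBuf target, ByteBuf data, int targetLength, int accumulated) {
        int copyLength = Math.min(targetLength - accumulated, data.readableBytes());
        if (copyLength > 0) {
            target.writeBytes(data, copyLength);
        }
        return target.readableBytes() == targetLength;
    }
}
```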
diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ConnectionManager.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ConnectionManager.java new file mode 100644 index 00000000..1b20e391 --- /dev/null +++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ConnectionManager.java @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +import com.alibaba.flink.shuffle.core.ids.ChannelID; +import com.alibaba.flink.shuffle.transfer.TransferMessage.CloseChannel; +import com.alibaba.flink.shuffle.transfer.TransferMessage.CloseConnection; + +import org.apache.flink.shaded.netty4.io.netty.channel.Channel; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFutureListener; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler; +import org.apache.flink.shaded.netty4.io.netty.handler.timeout.IdleStateHandler; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.time.Duration; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.currentProtocolVersion; +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyExtraMessage; +import static org.apache.flink.shaded.netty4.io.netty.channel.ChannelFutureListener.CLOSE_ON_FAILURE; + +/** Connection manager manages physical connections for remote input channels at runtime. */ +public class ConnectionManager { + + private static final Logger LOG = LoggerFactory.getLogger(ConnectionManager.class); + + /** The number of times to retry upon connection failure. */ + private final int connectionRetries; + + /** Time to wait between two consecutive connection retries. */ + private final Duration connectionRetryWait; + + /** Netty configuration. */ + private final NettyConfig nettyConfig; + + /** Netty client used to create connections to remote servers. */ + protected volatile NettyClient nettyClient; + + /** {@link Channel}s by remote shuffle address. */ + private final Map<InetSocketAddress, CompletableFuture<PhysicalChannel>> channelsByAddress = + new ConcurrentHashMap<>(); + + /** {@link ChannelHandler}s used when creating connections. */ + private final Supplier<ChannelHandler[]> channelHandlersSupplier; + + /** Whether the manager has been shut down. */ + private volatile boolean isShutdown; + + public ConnectionManager( + NettyConfig nettyConfig, + Supplier<ChannelHandler[]> channelHandlersSupplier, + int connectionRetries, + Duration connectionRetryWait) { + this.nettyConfig = nettyConfig; + this.channelHandlersSupplier = channelHandlersSupplier; + this.connectionRetries = connectionRetries; + this.connectionRetryWait = connectionRetryWait; + } + + /** Starts the internal components for network connections. */ + public void start() throws IOException { + if (nettyClient != null) { + return; + } + nettyClient = new NettyClient(nettyConfig); + nettyClient.init(channelHandlersSupplier); + } + + /** Shuts down the internal components for network connections. */ + public void shutdown() { + if (isShutdown) { + return; + } + nettyClient.shutdown(); + isShutdown = true; + } + + /** + * Asks for a physical connection. It returns an existing connection or creates a new one. + * + * @param channelID {@link ChannelID} related with the connection. + * @param address {@link InetSocketAddress} of the connection.
+ */ + public Channel getChannel(ChannelID channelID, InetSocketAddress address) + throws IOException, InterruptedException { + Channel ret = null; + while ((ret = getOrCreateChannel(channelID, address)) == null) { + // Retry: the cached physical channel was released concurrently before + // this channel ID could be registered on it. + continue; + } + return ret; + } + + private Channel getOrCreateChannel(ChannelID channelID, InetSocketAddress addr) + throws IOException, InterruptedException { + CompletableFuture<PhysicalChannel> newFuture = new CompletableFuture<>(); + CompletableFuture<PhysicalChannel> oldFuture = + channelsByAddress.putIfAbsent(addr, newFuture); + + if (oldFuture == null) { + try { + Channel channel = createChannel(addr); + PhysicalChannel physicalChannel = new PhysicalChannel(channel); + physicalChannel.register(channelID); + newFuture.complete(physicalChannel); + return channel; + } catch (Throwable t) { + newFuture.completeExceptionally(new IOException("Cannot create connection.", t)); + channelsByAddress.remove(addr); + throw t; + } + } else { + try { + PhysicalChannel physicalChannel = oldFuture.get(); + if (!physicalChannel.register(channelID)) { + return null; + } else { + return physicalChannel.nettyChannel; + } + } catch (ExecutionException t) { + throw new IOException("Cannot get a channel.", t); + } + } + } + + /** + * Releases a reference to a physical connection. A connection is closed when there are no + * references to it. + * + * @param address {@link InetSocketAddress} of the connection. + * @param channelID {@link ChannelID} related with the connection. + */ + public void releaseChannel(InetSocketAddress address, ChannelID channelID) throws IOException { + try { + CompletableFuture<PhysicalChannel> future = channelsByAddress.get(address); + if (future == null) { + return; + } + + PhysicalChannel pChannel = future.get(); + if (!pChannel.isRegistered(channelID)) { + return; + } + + CloseChannel closeChannel = + new CloseChannel(currentProtocolVersion(), channelID, emptyExtraMessage()); + LOG.debug("(remote: {}) Send {}.", pChannel.nettyChannel.remoteAddress(), closeChannel); + pChannel.nettyChannel.writeAndFlush(closeChannel).addListener(CLOSE_ON_FAILURE); + pChannel.unRegister(address, channelID); + } catch (Throwable t) { + throw new IOException("Failed to release channel.", t); + } + } + + private void closeConnection(InetSocketAddress address, Channel channel) { + LOG.debug("Close connection to {}.", address); + channel.writeAndFlush(new CloseConnection()).addListener(ChannelFutureListener.CLOSE); + channelsByAddress.remove(address); + } + + /** Number of physical connections. */ + public int numPhysicalConnections() { + return channelsByAddress.size(); + } + + private Channel createChannel(InetSocketAddress address) throws InterruptedException { + LOG.debug("Create connection to {}.", address); + + long retryWait = Math.max(1, connectionRetryWait.toMillis()); + for (int i = 0; i < connectionRetries; i++) { + try { + return nettyClient.connect(address).sync().channel(); + } catch (InterruptedException e) { + throw e; + } catch (Throwable throwable) { + LOG.warn( + "(remote: {}) Connection attempt failed {} time(s).", + address, + i + 1, + throwable); + if (i + 1 >= connectionRetries) { + throw throwable; + } + // Sleep for a while before the next retry so that the remote has some time to recover.
+ Thread.sleep(retryWait); + } + } + throw new IllegalStateException("Cannot arrive here."); + } + + private class PhysicalChannel { + final Channel nettyChannel; + final Set<ChannelID> channelIDs = new HashSet<>(); + boolean isReleased; + + PhysicalChannel(Channel nettyChannel) { + this.nettyChannel = nettyChannel; + } + + synchronized boolean register(ChannelID channelID) { + if (isReleased) { + return false; + } + channelIDs.add(channelID); + return true; + } + + synchronized boolean isRegistered(ChannelID channelID) { + return channelIDs.contains(channelID); + } + + synchronized void unRegister(InetSocketAddress address, ChannelID channelID) { + channelIDs.remove(channelID); + if (channelIDs.isEmpty()) { + closeConnection(address, nettyChannel); + isReleased = true; + } + } + } + + /** Create {@link ConnectionManager} for write-client. */ + public static ConnectionManager createWriteConnectionManager( + NettyConfig nettyConfig, boolean enableHeartbeat) { + return new ConnectionManager( + nettyConfig, + () -> + new ChannelHandler[] { + new TransferMessageEncoder(), + DecoderDelegate.writeClientDecoderDelegate(), + new IdleStateHandler( + nettyConfig.getHeartbeatTimeoutSeconds(), + 0, + 0, + TimeUnit.SECONDS), + new WriteClientHandler( + enableHeartbeat + ? nettyConfig.getHeartbeatIntervalSeconds() + : -1) + }, + nettyConfig.getConnectionRetries(), + nettyConfig.getConnectionRetryWait()); + } + + /** Create {@link ConnectionManager} for read-client. */ + public static ConnectionManager createReadConnectionManager( + NettyConfig nettyConfig, boolean enableHeartbeat) { + return new ConnectionManager( + nettyConfig, + () -> { + ReadClientHandler handler = + new ReadClientHandler( + enableHeartbeat + ? nettyConfig.getHeartbeatIntervalSeconds() + : -1); + return new ChannelHandler[] { + new TransferMessageEncoder(), + DecoderDelegate.readClientDecoderDelegate(handler.bufferSuppliers()), + new IdleStateHandler( + nettyConfig.getHeartbeatTimeoutSeconds(), 0, 0, TimeUnit.SECONDS), + handler + }; + }, + nettyConfig.getConnectionRetries(), + nettyConfig.getConnectionRetryWait()); + } +} diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/CreditListener.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/CreditListener.java new file mode 100644 index 00000000..280cf427 --- /dev/null +++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/CreditListener.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +/** Listener to be notified when there is any available credit.
*/ +public abstract class CreditListener { + + private int numCreditsNeeded; + + private boolean isRegistered; + + void increaseNumCreditsNeeded(int numCredits) { + numCreditsNeeded += numCredits; + } + + void decreaseNumCreditsNeeded(int numCredits) { + numCreditsNeeded -= numCredits; + } + + int getNumCreditsNeeded() { + return numCreditsNeeded; + } + + boolean isRegistered() { + return isRegistered; + } + + void setRegistered(boolean registered) { + isRegistered = registered; + } + + public abstract void notifyAvailableCredits(int numCredits); +} diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/DataSender.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/DataSender.java new file mode 100644 index 00000000..4d26221b --- /dev/null +++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/DataSender.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +import com.alibaba.flink.shuffle.core.ids.ChannelID; +import com.alibaba.flink.shuffle.core.memory.Buffer; +import com.alibaba.flink.shuffle.core.storage.BufferWithBacklog; + +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFutureListener; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.SocketAddress; + +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.currentProtocolVersion; +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyExtraMessage; +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyOffset; + +/** A {@link ChannelInboundHandlerAdapter} sending shuffle read data. 
*/ +public class DataSender extends ChannelInboundHandlerAdapter { + + private static final Logger LOG = LoggerFactory.getLogger(DataSender.class); + + private final ReadingService readingService; + + private final ChannelFutureListenerImpl channelFutureListener; + + public DataSender(ReadingService readingService) { + this.readingService = readingService; + this.channelFutureListener = + new ChannelFutureListenerImpl( + (channelFuture, cause) -> { + if (readingService.getNumServingChannels() > 0) { + readingService.releaseOnError(cause, null); + } + ChannelFutureListener.CLOSE.operationComplete(channelFuture); + }); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object msg) { + if (msg.getClass() == DataViewReader.class) { + DataViewReader viewReader = (DataViewReader) msg; + SocketAddress addr = ctx.channel().remoteAddress(); + ChannelID channelID = viewReader.getChannelID(); + try { + LOG.debug("({}) Received {}.", addr, viewReader); + + while (true) { + BufferWithBacklog bufferWithBacklog = viewReader.getNextBuffer(); + if (bufferWithBacklog == null) { + break; + } + + Buffer buffer = bufferWithBacklog.getBuffer(); + int backlog = (int) bufferWithBacklog.getBacklog(); + NetworkMetrics.numBytesWritingThroughput().mark(buffer.readableBytes()); + TransferMessage.ReadData readData = + new TransferMessage.ReadData( + currentProtocolVersion(), + channelID, + backlog, + buffer.readableBytes(), + emptyOffset(), + buffer, + emptyExtraMessage()); + writeAndFlush(ctx, readData); + } + + if (viewReader.isEOF()) { + LOG.info("{} finished.", viewReader); + readingService.readFinish(viewReader.getChannelID()); + } + } catch (Throwable t) { + ctx.pipeline() + .fireUserEventTriggered( + new ReadServerHandler.ReadingFailureEvent(t, channelID)); + } + + } else { + ctx.fireUserEventTriggered(msg); + } + } + + private void writeAndFlush(ChannelHandlerContext ctx, Object obj) { + ctx.writeAndFlush(obj).addListener(channelFutureListener); + } +} diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/DataViewReader.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/DataViewReader.java new file mode 100644 index 00000000..f50bbb30 --- /dev/null +++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/DataViewReader.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package com.alibaba.flink.shuffle.transfer;
+
+import com.alibaba.flink.shuffle.core.ids.ChannelID;
+import com.alibaba.flink.shuffle.core.storage.BufferWithBacklog;
+import com.alibaba.flink.shuffle.core.storage.DataPartitionReadingView;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.function.Consumer;
+
+/** A wrapper of {@link DataPartitionReadingView} providing credit-related functionality. */
+public class DataViewReader {
+
+    private static final Logger LOG = LoggerFactory.getLogger(DataViewReader.class);
+
+    private DataPartitionReadingView readingView;
+
+    private int credit;
+
+    private final ChannelID channelID;
+
+    private final String addressStr;
+
+    private final Consumer<DataViewReader> dataListener;
+
+    public DataViewReader(
+            ChannelID channelID, String addressStr, Consumer<DataViewReader> dataListener) {
+        this.channelID = channelID;
+        this.addressStr = addressStr;
+        this.dataListener = dataListener;
+    }
+
+    public void setReadingView(DataPartitionReadingView readingView) {
+        this.readingView = readingView;
+    }
+
+    public BufferWithBacklog getNextBuffer() throws Throwable {
+        if (credit > 0) {
+            BufferWithBacklog res = readingView.nextBuffer();
+            if (res != null) {
+                credit--;
+            }
+            return res;
+        }
+        return null;
+    }
+
+    public void addCredit(int credit) {
+        this.credit += credit;
+    }
+
+    public int getCredit() {
+        return credit;
+    }
+
+    public Consumer<DataViewReader> getDataListener() {
+        return dataListener;
+    }
+
+    public boolean isEOF() {
+        return readingView.isFinished();
+    }
+
+    public ChannelID getChannelID() {
+        return channelID;
+    }
+
+    public DataPartitionReadingView getReadingView() {
+        return readingView;
+    }
+
+    @Override
+    public String toString() {
+        return String.format("DataViewReader [channelID: %s, credit: %d]", channelID, credit);
+    }
+}
diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/DataViewWriter.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/DataViewWriter.java
new file mode 100644
index 00000000..61e941f2
--- /dev/null
+++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/DataViewWriter.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.transfer;
+
+import com.alibaba.flink.shuffle.common.utils.CommonUtils;
+import com.alibaba.flink.shuffle.core.storage.DataPartitionWritingView;
+
+/** A wrapper of {@link DataPartitionWritingView}.
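+ *
+ * <p>Besides the {@link DataPartitionWritingView} itself, it records the address string of the
+ * remote peer, presumably for logging and error reporting.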
*/ +public class DataViewWriter { + + private final DataPartitionWritingView writingView; + + private final String addressStr; + + public DataViewWriter(DataPartitionWritingView writingView, String addressStr) { + CommonUtils.checkArgument(writingView != null, "Must be not null."); + CommonUtils.checkArgument(addressStr != null, "Must be not null."); + + this.writingView = writingView; + this.addressStr = addressStr; + } + + public DataPartitionWritingView getWritingView() { + return writingView; + } +} diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/DecoderDelegate.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/DecoderDelegate.java new file mode 100644 index 00000000..4fb4b309 --- /dev/null +++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/DecoderDelegate.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +import com.alibaba.flink.shuffle.core.ids.ChannelID; +import com.alibaba.flink.shuffle.transfer.ReadClientHandler.ClientReadingFailureEvent; +import com.alibaba.flink.shuffle.transfer.TransferMessage.CloseChannel; +import com.alibaba.flink.shuffle.transfer.TransferMessage.CloseConnection; +import com.alibaba.flink.shuffle.transfer.TransferMessage.Heartbeat; +import com.alibaba.flink.shuffle.transfer.TransferMessage.ReadAddCredit; +import com.alibaba.flink.shuffle.transfer.TransferMessage.ReadData; +import com.alibaba.flink.shuffle.transfer.TransferMessage.ReadHandshakeRequest; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteData; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteFinish; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteHandshakeRequest; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteRegionFinish; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteRegionStart; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter; + +import java.util.function.Function; +import java.util.function.Supplier; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; +import static com.alibaba.flink.shuffle.transfer.TransferMessage.FRAME_HEADER_LENGTH; +import static com.alibaba.flink.shuffle.transfer.TransferMessage.MAGIC_NUMBER; + +/** + * A {@link ChannelInboundHandlerAdapter} wrapping {@link TransferMessageDecoder}s to decode + * different kinds of {@link TransferMessage}s. 
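+ *
+ * <p>Decoding happens in two phases: the fixed-length frame header (frame length, magic number,
+ * message ID) is accumulated first, then the {@link TransferMessageDecoder} selected by message
+ * ID consumes the frame body, possibly across several network reads.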
+ */
+public class DecoderDelegate extends ChannelInboundHandlerAdapter {
+
+    private TransferMessageDecoder currentDecoder;
+
+    private ByteBuf frameHeaderBuffer;
+
+    private final Function<Byte, TransferMessageDecoder> messageDecoders;
+
+    public DecoderDelegate(Function<Byte, TransferMessageDecoder> messageDecoders) {
+        this.messageDecoders = messageDecoders;
+    }
+
+    public static DecoderDelegate writeClientDecoderDelegate() {
+        Function<Byte, TransferMessageDecoder> decoderMap =
+                msgID -> new CommonTransferMessageDecoder();
+        return new DecoderDelegate(decoderMap);
+    }
+
+    public static DecoderDelegate readClientDecoderDelegate(
+            Function<ChannelID, Supplier<ByteBuf>> bufferSuppliers) {
+        Function<Byte, TransferMessageDecoder> decoderMap =
+                msgID ->
+                        msgID == ReadData.ID
+                                ? new ShuffleReadDataDecoder(bufferSuppliers)
+                                : new CommonTransferMessageDecoder();
+        return new DecoderDelegate(decoderMap);
+    }
+
+    public static DecoderDelegate serverDecoderDelegate(
+            Function<ChannelID, Supplier<ByteBuf>> bufferSuppliers) {
+        Function<Byte, TransferMessageDecoder> decoderMap =
+                msgID -> {
+                    switch (msgID) {
+                        case WriteHandshakeRequest.ID:
+                        case WriteRegionStart.ID:
+                        case WriteRegionFinish.ID:
+                        case WriteFinish.ID:
+                        case ReadHandshakeRequest.ID:
+                        case ReadAddCredit.ID:
+                        case CloseChannel.ID:
+                        case CloseConnection.ID:
+                        case Heartbeat.ID:
+                            return new CommonTransferMessageDecoder();
+                        case WriteData.ID:
+                            return new ShuffleWriteDataDecoder(bufferSuppliers);
+                        default:
+                            throw new RuntimeException(
+                                    "No decoder found for message ID: " + msgID);
+                    }
+                };
+        return new DecoderDelegate(decoderMap);
+    }
+
+    @Override
+    public void channelActive(ChannelHandlerContext ctx) throws Exception {
+        frameHeaderBuffer = ctx.alloc().directBuffer(FRAME_HEADER_LENGTH);
+        super.channelActive(ctx);
+    }
+
+    @Override
+    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+        if (currentDecoder != null) {
+            currentDecoder.close();
+        }
+        frameHeaderBuffer.release();
+        super.channelInactive(ctx);
+    }
+
+    @Override
+    public void channelRead(ChannelHandlerContext ctx, Object msg) {
+        if (!(msg instanceof ByteBuf)) {
+            ctx.fireChannelRead(msg);
+            return;
+        }
+        ByteBuf byteBuf = (ByteBuf) msg;
+        try {
+            while (byteBuf.isReadable()) {
+                if (currentDecoder != null) {
+                    TransferMessageDecoder.DecodingResult result =
+                            currentDecoder.onChannelRead(byteBuf);
+                    if (!result.isFinished()) {
+                        break;
+                    }
+
+                    ctx.fireChannelRead(result.getMessage());
+
+                    currentDecoder.close();
+                    currentDecoder = null;
+                    frameHeaderBuffer.clear();
+                }
+                decodeHeader(ctx, byteBuf);
+            }
+        } catch (Throwable e) {
+            if (currentDecoder != null) {
+                currentDecoder.close();
+            }
+            if (e instanceof WritingExceptionWithChannelID) {
+                WritingExceptionWithChannelID ec = (WritingExceptionWithChannelID) e;
+                WriteServerHandler.WritingFailureEvent evt =
+                        new WriteServerHandler.WritingFailureEvent(
+                                ec.getChannelID(), ec.getCause());
+                ctx.pipeline().fireUserEventTriggered(evt);
+
+            } else if (e instanceof ReadingExceptionWithChannelID) {
+                ReadingExceptionWithChannelID ec = (ReadingExceptionWithChannelID) e;
+                ClientReadingFailureEvent evt =
+                        new ClientReadingFailureEvent(ec.getChannelID(), ec.getCause());
+                ctx.pipeline().fireUserEventTriggered(evt);
+
+            } else {
+                ctx.fireExceptionCaught(e);
+            }
+        } finally {
+            byteBuf.release();
+        }
+    }
+
+    // For testing.
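+    // Allows tests to inject a decoder directly, bypassing decodeHeader().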
+ void setCurrentDecoder(TransferMessageDecoder decoder) { + currentDecoder = decoder; + } + + private void decodeHeader(ChannelHandlerContext ctx, ByteBuf data) { + boolean accumulated = + DecodingUtil.accumulate( + frameHeaderBuffer, + data, + FRAME_HEADER_LENGTH, + frameHeaderBuffer.readableBytes()); + if (!accumulated) { + return; + } + int frameLength = frameHeaderBuffer.readInt(); + int magicNumber = frameHeaderBuffer.readInt(); + if (magicNumber != MAGIC_NUMBER) { + throw new RuntimeException("BUG: magic number unexpected."); + } + byte msgID = frameHeaderBuffer.readByte(); + currentDecoder = checkNotNull(messageDecoders.apply(msgID)); + currentDecoder.onNewMessageReceived(ctx, msgID, frameLength - FRAME_HEADER_LENGTH); + } +} diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/DecodingUtil.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/DecodingUtil.java new file mode 100644 index 00000000..9e4b32a6 --- /dev/null +++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/DecodingUtil.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; + +/** Utils when decoding. */ +public class DecodingUtil { + + /** + * Method to accumulate data from network. + * + * @param target {@link ByteBuf} to accumulate data to. + * @param source {@link ByteBuf} from network. + * @param targetAccumulationSize Target data size for the accumulation. + * @param accumulatedSize Already accumulated size. + * @return Whether accumulation is done. + */ + public static boolean accumulate( + ByteBuf target, ByteBuf source, int targetAccumulationSize, int accumulatedSize) { + int copyLength = Math.min(source.readableBytes(), targetAccumulationSize - accumulatedSize); + if (copyLength > 0) { + target.writeBytes(source, copyLength); + } + return accumulatedSize + copyLength == targetAccumulationSize; + } +} diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/NettyClient.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/NettyClient.java new file mode 100644 index 00000000..5af9ed31 --- /dev/null +++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/NettyClient.java @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +import com.alibaba.flink.shuffle.core.executor.ExecutorThreadFactory; +import com.alibaba.flink.shuffle.core.utils.FatalExitExceptionHandler; + +import org.apache.flink.shaded.netty4.io.netty.bootstrap.Bootstrap; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelException; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInitializer; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelOption; +import org.apache.flink.shaded.netty4.io.netty.channel.epoll.Epoll; +import org.apache.flink.shaded.netty4.io.netty.channel.epoll.EpollEventLoopGroup; +import org.apache.flink.shaded.netty4.io.netty.channel.epoll.EpollSocketChannel; +import org.apache.flink.shaded.netty4.io.netty.channel.nio.NioEventLoopGroup; +import org.apache.flink.shaded.netty4.io.netty.channel.socket.SocketChannel; +import org.apache.flink.shaded.netty4.io.netty.channel.socket.nio.NioSocketChannel; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.function.Supplier; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; + +/** Netty client to create connections to remote. 
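+ *
+ * <p>{@link #init} must be called with a supplier of channel handlers before {@link #connect};
+ * the transport type (NIO vs. EPOLL) is chosen from the {@link NettyConfig}. A minimal,
+ * hypothetical usage sketch (handler set and address are placeholders):
+ *
+ * <pre>{@code
+ * NettyClient client = new NettyClient(nettyConfig);
+ * client.init(() -> new ChannelHandler[] {new TransferMessageEncoder()});
+ * ChannelFuture connectFuture = client.connect(new InetSocketAddress(host, dataPort));
+ * }</pre>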
*/ +public class NettyClient { + + private static final Logger LOG = LoggerFactory.getLogger(NettyClient.class); + + private static final ExecutorThreadFactory.Builder THREAD_FACTORY_BUILDER = + new ExecutorThreadFactory.Builder() + .setExceptionHandler(FatalExitExceptionHandler.INSTANCE); + + private final NettyConfig config; + + private Supplier channelHandlersSupplier; + + private Bootstrap bootstrap; + + public NettyClient(NettyConfig config) { + this.config = config; + } + + public void init(Supplier channelHandlersSupplier) throws IOException { + + checkState(bootstrap == null, "Netty client has already been initialized."); + + this.channelHandlersSupplier = channelHandlersSupplier; + + final long start = System.nanoTime(); + + bootstrap = new Bootstrap(); + + // -------------------------------------------------------------------- + // Transport-specific configuration + // -------------------------------------------------------------------- + + switch (config.getTransportType()) { + case NIO: + initNioBootstrap(); + break; + + case EPOLL: + initEpollBootstrap(); + break; + + case AUTO: + if (Epoll.isAvailable()) { + initEpollBootstrap(); + LOG.info("Transport type 'auto': using EPOLL."); + } else { + initNioBootstrap(); + LOG.info("Transport type 'auto': using NIO."); + } + } + + // -------------------------------------------------------------------- + // Configuration + // -------------------------------------------------------------------- + + bootstrap.option(ChannelOption.TCP_NODELAY, true); + bootstrap.option(ChannelOption.SO_KEEPALIVE, true); + + // Timeout for new connections + bootstrap.option( + ChannelOption.CONNECT_TIMEOUT_MILLIS, + config.getClientConnectTimeoutSeconds() * 1000); + + // Receive and send buffer size + int receiveAndSendBufferSize = config.getSendAndReceiveBufferSize(); + if (receiveAndSendBufferSize > 0) { + bootstrap.option(ChannelOption.SO_SNDBUF, receiveAndSendBufferSize); + bootstrap.option(ChannelOption.SO_RCVBUF, receiveAndSendBufferSize); + } + + final long duration = (System.nanoTime() - start) / 1_000_000; + LOG.info("Successful initialization (took {} ms).", duration); + } + + public void shutdown() { + final long start = System.nanoTime(); + + if (bootstrap != null) { + if (bootstrap.group() != null) { + bootstrap.group().shutdownGracefully(); + } + bootstrap = null; + } + + final long duration = (System.nanoTime() - start) / 1_000_000; + LOG.info("Successful shutdown (took {} ms).", duration); + } + + private void initNioBootstrap() { + // Add the server port number to the name in order to distinguish + // multiple clients running on the same host. + String name = NettyConfig.CLIENT_THREAD_GROUP_NAME + " (" + config.getServerPort() + ")"; + + NioEventLoopGroup nioGroup = + new NioEventLoopGroup( + config.getClientNumThreads(), + THREAD_FACTORY_BUILDER.setPoolName(name).build()); + bootstrap.group(nioGroup).channel(NioSocketChannel.class); + } + + private void initEpollBootstrap() { + // Add the server port number to the name in order to distinguish + // multiple clients running on the same host. 
+ String name = NettyConfig.CLIENT_THREAD_GROUP_NAME + " (" + config.getServerPort() + ")"; + + EpollEventLoopGroup epollGroup = + new EpollEventLoopGroup( + config.getClientNumThreads(), + THREAD_FACTORY_BUILDER.setPoolName(name).build()); + bootstrap.group(epollGroup).channel(EpollSocketChannel.class); + } + + // ------------------------------------------------------------------------ + // Client connections + // ------------------------------------------------------------------------ + + public ChannelFuture connect(final InetSocketAddress serverSocketAddress) { + checkState(bootstrap != null, "Client has not been initialized yet."); + + // -------------------------------------------------------------------- + // Child channel pipeline for accepted connections + // -------------------------------------------------------------------- + + bootstrap.handler( + new ChannelInitializer() { + @Override + public void initChannel(SocketChannel channel) { + channel.pipeline().addLast(channelHandlersSupplier.get()); + } + }); + + try { + return bootstrap.connect(serverSocketAddress); + } catch (ChannelException e) { + if ((e.getCause() instanceof java.net.SocketException + && e.getCause().getMessage().equals("Too many open files")) + || (e.getCause() instanceof ChannelException + && e.getCause().getCause() instanceof java.net.SocketException + && e.getCause() + .getCause() + .getMessage() + .equals("Too many open files"))) { + throw new ChannelException( + "" + + "The operating system does not offer enough file handles" + + " to open the network connection. Please increase the" + + " number of available file handles.", + e.getCause()); + } else { + throw e; + } + } + } +} diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/NettyConfig.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/NettyConfig.java new file mode 100644 index 00000000..018a0690 --- /dev/null +++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/NettyConfig.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.config.TransferOptions; +import com.alibaba.flink.shuffle.core.config.WorkerOptions; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.time.Duration; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; + +/** Wrapper for network configurations. 
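+ *
+ * <p>Values are read from the underlying {@link Configuration} on each access and validated
+ * there (valid port range, positive thread counts, positive heartbeat timeout and interval).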
*/ +public class NettyConfig { + + private static final Logger LOG = LoggerFactory.getLogger(NettyConfig.class); + + public static final String SERVER_THREAD_GROUP_NAME = "Remote Shuffle Netty Server"; + + public static final String CLIENT_THREAD_GROUP_NAME = "Remote Shuffle Netty Client"; + + private final Configuration config; // optional configuration + + private final int numPreferredClientThreads; + + public NettyConfig(Configuration config) { + this(config, 1); + } + + public NettyConfig(Configuration config, int numPreferredClientThreads) { + checkArgument(numPreferredClientThreads > 0, "Must be positive."); + + this.config = checkNotNull(config); + + this.numPreferredClientThreads = numPreferredClientThreads; + + LOG.info(this.toString()); + } + + public InetAddress getServerAddress() throws UnknownHostException { + String address = checkNotNull(config.getString(WorkerOptions.BIND_HOST)); + return InetAddress.getByName(address); + } + + public int getServerPort() { + int serverPort = config.getInteger(TransferOptions.SERVER_DATA_PORT); + checkArgument(CommonUtils.isValidHostPort(serverPort), "Invalid port number."); + return serverPort; + } + + public int getServerConnectBacklog() { + return config.getInteger(TransferOptions.CONNECT_BACKLOG); + } + + public int getServerNumThreads() { + int numThreads = config.getInteger(TransferOptions.NUM_THREADS_SERVER); + checkArgument(numThreads > 0, "Number of server thread must be positive."); + return numThreads; + } + + public int getClientNumThreads() { + int configValue = config.getInteger(TransferOptions.NUM_THREADS_CLIENT); + return configValue <= 0 ? numPreferredClientThreads : configValue; + } + + public int getConnectionRetries() { + int connectionRetries = config.getInteger(TransferOptions.CONNECTION_RETRIES); + return Math.max(1, connectionRetries); + } + + public Duration getConnectionRetryWait() { + return config.getDuration(TransferOptions.CONNECTION_RETRY_WAIT); + } + + public int getClientConnectTimeoutSeconds() { + return CommonUtils.checkedDownCast( + config.getDuration(TransferOptions.CLIENT_CONNECT_TIMEOUT).getSeconds()); + } + + public int getSendAndReceiveBufferSize() { + return CommonUtils.checkedDownCast( + config.getMemorySize(TransferOptions.SEND_RECEIVE_BUFFER_SIZE).getBytes()); + } + + public int getHeartbeatTimeoutSeconds() { + int heartbeatTimeout = + CommonUtils.checkedDownCast( + config.getDuration(TransferOptions.HEARTBEAT_TIMEOUT).getSeconds()); + checkArgument( + heartbeatTimeout > 0, + "Heartbeat timeout must be positive and no less than 1 second."); + return heartbeatTimeout; + } + + public int getHeartbeatIntervalSeconds() { + int heartbeatInterval = + CommonUtils.checkedDownCast( + config.getDuration(TransferOptions.HEARTBEAT_INTERVAL).getSeconds()); + checkArgument( + heartbeatInterval > 0, + "Heartbeat interval must be positive and no less than 1 second."); + return heartbeatInterval; + } + + public TransportType getTransportType() { + String transport = config.getString(TransferOptions.TRANSPORT_TYPE); + + switch (transport) { + case "nio": + return TransportType.NIO; + case "epoll": + return TransportType.EPOLL; + default: + return TransportType.AUTO; + } + } + + public Configuration getConfig() { + return config; + } + + @Override + public String toString() { + String format = + "NettyConfig [" + + "server port: %d, " + + "transport type: %s, " + + "number of server threads: %d, " + + "number of client threads: %d (%s), " + + "server connect backlog: %d (%s), " + + "client connect timeout (sec): 
%d, " + + "client connect retries: %d, " + + "client connect retry wait (sec): %d, " + + "send/receive buffer size (bytes): %d (%s), " + + "heartbeat timeout %d, " + + "heartbeat interval %d]"; + + String def = "use Netty's default"; + String man = "manual"; + + return String.format( + format, + getServerPort(), + getTransportType(), + getServerNumThreads(), + getClientNumThreads(), + getClientNumThreads() == 0 ? def : man, + getServerConnectBacklog(), + getServerConnectBacklog() == 0 ? def : man, + getClientConnectTimeoutSeconds(), + getConnectionRetries(), + getConnectionRetryWait().getSeconds(), + getSendAndReceiveBufferSize(), + getSendAndReceiveBufferSize() == 0 ? def : man, + getHeartbeatTimeoutSeconds(), + getHeartbeatIntervalSeconds()); + } + + /** Netty transportation types. */ + public enum TransportType { + NIO, + EPOLL, + AUTO + } +} diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/NettyServer.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/NettyServer.java new file mode 100644 index 00000000..571fa3ec --- /dev/null +++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/NettyServer.java @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.transfer; + +import com.alibaba.flink.shuffle.core.executor.ExecutorThreadFactory; +import com.alibaba.flink.shuffle.core.storage.PartitionedDataStore; +import com.alibaba.flink.shuffle.core.utils.FatalExitExceptionHandler; + +import org.apache.flink.shaded.netty4.io.netty.bootstrap.ServerBootstrap; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInitializer; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelOption; +import org.apache.flink.shaded.netty4.io.netty.channel.epoll.Epoll; +import org.apache.flink.shaded.netty4.io.netty.channel.epoll.EpollEventLoopGroup; +import org.apache.flink.shaded.netty4.io.netty.channel.epoll.EpollServerSocketChannel; +import org.apache.flink.shaded.netty4.io.netty.channel.nio.NioEventLoopGroup; +import org.apache.flink.shaded.netty4.io.netty.channel.socket.SocketChannel; +import org.apache.flink.shaded.netty4.io.netty.channel.socket.nio.NioServerSocketChannel; +import org.apache.flink.shaded.netty4.io.netty.handler.timeout.IdleStateHandler; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; + +/** Facility to provide shuffle service based on Netty framework. */ +public class NettyServer { + + /** Heartbeat timeout. */ + private final int heartbeatTimeout; + + /** Heartbeat interval. */ + private final int heartbeatInterval; + + private boolean heartbeatEnabled = true; + + private static final ExecutorThreadFactory.Builder THREAD_FACTORY_BUILDER = + new ExecutorThreadFactory.Builder() + .setExceptionHandler(FatalExitExceptionHandler.INSTANCE); + + private static final Logger LOG = LoggerFactory.getLogger(NettyServer.class); + + private final NettyConfig nettyConfig; + + private ServerBootstrap bootstrap; + + private ChannelFuture bindFuture; + + private final PartitionedDataStore dataStore; + + public NettyServer(PartitionedDataStore dataStore, NettyConfig nettyConfig) { + this.dataStore = dataStore; + this.nettyConfig = nettyConfig; + this.heartbeatTimeout = nettyConfig.getHeartbeatTimeoutSeconds(); + this.heartbeatInterval = nettyConfig.getHeartbeatIntervalSeconds(); + } + + public static ThreadFactory getNamedThreadFactory(String name) { + return THREAD_FACTORY_BUILDER.setPoolName(name).build(); + } + + public void start() throws IOException { + LOG.info("Starting NettyServer on port " + nettyConfig.getServerPort()); + init(() -> new ServerChannelInitializer(this::getServerHandlers)); + } + + public ChannelHandler[] getServerHandlers() { + WriteServerHandler writeServerHandler = + new WriteServerHandler(dataStore, heartbeatEnabled ? heartbeatInterval : -1); + ReadServerHandler readServerHandler = + new ReadServerHandler(dataStore, heartbeatEnabled ? 
heartbeatInterval : -1); + + WritingService writingService = writeServerHandler.getWritingService(); + ReadingService readingService = readServerHandler.getReadingService(); + + return new ChannelHandler[] { + new TransferMessageEncoder(), + DecoderDelegate.serverDecoderDelegate(writingService::getBufferSupplier), + new IdleStateHandler(heartbeatTimeout, 0, 0, TimeUnit.SECONDS), + writeServerHandler, + readServerHandler, + new DataSender(readingService), + }; + } + + public void disableHeartbeat() { + heartbeatEnabled = false; + } + + private void init(Supplier channelInitializer) throws IOException { + + checkState(bootstrap == null, "Netty server has already been initialized."); + + final long start = System.nanoTime(); + + bootstrap = new ServerBootstrap(); + + // -------------------------------------------------------------------- + // Transport-specific configuration + // -------------------------------------------------------------------- + + switch (nettyConfig.getTransportType()) { + case NIO: + initNioBootstrap(); + break; + + case EPOLL: + initEpollBootstrap(); + break; + + case AUTO: + if (Epoll.isAvailable()) { + initEpollBootstrap(); + LOG.info("Transport type 'auto': using EPOLL."); + } else { + initNioBootstrap(); + LOG.info("Transport type 'auto': using NIO."); + } + } + + // -------------------------------------------------------------------- + // Configuration + // -------------------------------------------------------------------- + + // Server bind address + bootstrap.localAddress(nettyConfig.getServerAddress(), nettyConfig.getServerPort()); + + int serverBacklog = nettyConfig.getServerConnectBacklog(); + if (serverBacklog > 0) { + bootstrap.option(ChannelOption.SO_BACKLOG, serverBacklog); + } + + // Receive and send buffer size + int receiveAndSendBufferSize = nettyConfig.getSendAndReceiveBufferSize(); + if (receiveAndSendBufferSize > 0) { + bootstrap.childOption(ChannelOption.SO_SNDBUF, receiveAndSendBufferSize); + bootstrap.childOption(ChannelOption.SO_RCVBUF, receiveAndSendBufferSize); + } + + // -------------------------------------------------------------------- + // Child channel pipeline for accepted connections + // -------------------------------------------------------------------- + + bootstrap.childHandler(channelInitializer.get()); + + // -------------------------------------------------------------------- + // Start Server + // -------------------------------------------------------------------- + + bindFuture = bootstrap.bind().syncUninterruptibly(); + InetSocketAddress localAddress = (InetSocketAddress) bindFuture.channel().localAddress(); + + final long duration = (System.nanoTime() - start) / 1_000_000; + LOG.info( + "Successful initialization (took {} ms). Listening on SocketAddress {}.", + duration, + localAddress); + } + + public void shutdown() { + final long start = System.nanoTime(); + if (bindFuture != null) { + bindFuture.channel().close().awaitUninterruptibly(); + bindFuture = null; + } + + if (bootstrap != null) { + if (bootstrap.group() != null) { + bootstrap.group().shutdownGracefully(); + } + bootstrap = null; + } + final long duration = (System.nanoTime() - start) / 1_000_000; + LOG.info( + "Successful shutdown on port {} (took {} ms).", + nettyConfig.getServerPort(), + duration); + } + + private void initNioBootstrap() { + // Add the server port number to the name in order to distinguish + // multiple servers running on the same host. 
+ String name = + NettyConfig.SERVER_THREAD_GROUP_NAME + " (" + nettyConfig.getServerPort() + ")"; + + NioEventLoopGroup bossGroup = new NioEventLoopGroup(1); + NioEventLoopGroup workerGroup = + new NioEventLoopGroup( + nettyConfig.getServerNumThreads(), getNamedThreadFactory(name)); + bootstrap.group(bossGroup, workerGroup).channel(NioServerSocketChannel.class); + } + + private void initEpollBootstrap() { + // Add the server port number to the name in order to distinguish + // multiple servers running on the same host. + String name = + NettyConfig.SERVER_THREAD_GROUP_NAME + " (" + nettyConfig.getServerPort() + ")"; + + EpollEventLoopGroup bossGroup = new EpollEventLoopGroup(1); + EpollEventLoopGroup workerGroup = + new EpollEventLoopGroup( + nettyConfig.getServerNumThreads(), getNamedThreadFactory(name)); + bootstrap.group(bossGroup, workerGroup).channel(EpollServerSocketChannel.class); + } + + private static class ServerChannelInitializer extends ChannelInitializer { + + private final Supplier serverHandlersProvider; + + public ServerChannelInitializer(Supplier serverHandlersProvider) { + this.serverHandlersProvider = serverHandlersProvider; + } + + @Override + public void initChannel(SocketChannel channel) { + channel.pipeline().addLast(serverHandlersProvider.get()); + } + } +} diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/NetworkMetrics.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/NetworkMetrics.java new file mode 100644 index 00000000..58e20116 --- /dev/null +++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/NetworkMetrics.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +import com.alibaba.flink.shuffle.metrics.entry.MetricUtils; + +import com.alibaba.metrics.Counter; +import com.alibaba.metrics.Meter; + +/** Constants and util methods of network metrics. */ +public class NetworkMetrics { + + // Group name + public static final String NETWORK = "remote-shuffle.network"; + + // Current number of tcp writing connections. + public static final String NUM_WRITING_CONNECTIONS = NETWORK + ".num_writing_connections"; + + // Current number of tcp reading connections. + public static final String NUM_READING_CONNECTIONS = NETWORK + ".num_reading_connections"; + + // Current number of writing flows. + public static final String NUM_WRITING_FLOWS = NETWORK + ".num_writing_flows"; + + // Current number of reading flows. + public static final String NUM_READING_FLOWS = NETWORK + ".num_reading_flows"; + + // Current writing throughput in bytes. + public static final String NUM_BYTES_WRITING_THROUGHPUT = NETWORK + ".writing_throughput_bytes"; + + // Current reading throughput in bytes. 
+ public static final String NUM_BYTES_READING_THROUGHPUT = NETWORK + ".reading_throughput_bytes"; + + public static Counter numWritingConnections() { + return MetricUtils.getCounter(NETWORK, NUM_WRITING_CONNECTIONS); + } + + public static Counter numReadingConnections() { + return MetricUtils.getCounter(NETWORK, NUM_READING_CONNECTIONS); + } + + public static Counter numWritingFlows() { + return MetricUtils.getCounter(NETWORK, NUM_WRITING_FLOWS); + } + + public static Counter numReadingFlows() { + return MetricUtils.getCounter(NETWORK, NUM_READING_FLOWS); + } + + public static Meter numBytesWritingThroughput() { + return MetricUtils.getMeter(NETWORK, NUM_BYTES_WRITING_THROUGHPUT); + } + + public static Meter numBytesReadingThroughput() { + return MetricUtils.getMeter(NETWORK, NUM_BYTES_READING_THROUGHPUT); + } +} diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ReadClientHandler.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ReadClientHandler.java new file mode 100644 index 00000000..86be242e --- /dev/null +++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ReadClientHandler.java @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.ids.ChannelID; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter; +import org.apache.flink.shaded.netty4.io.netty.handler.timeout.IdleStateEvent; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.SocketAddress; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import java.util.function.Supplier; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; + +/** + * A {@link ChannelInboundHandlerAdapter} shared by multiple {@link ShuffleReadClient} for shuffle + * read. + */ +public class ReadClientHandler extends ChannelInboundHandlerAdapter { + + private static final Logger LOG = LoggerFactory.getLogger(ReadClientHandler.class); + + /** Heartbeat interval. */ + private final int heartbeatInterval; + + /** The sharing {@link ShuffleReadClient}s by {@link ChannelID}s. 
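+     * Several read clients can multiplex one physical connection; this map routes decoded
+     * messages to the owning client by its {@link ChannelID}.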
+     */
+    private final Map<ChannelID, ShuffleReadClient> readClientsByChannelID =
+            new ConcurrentHashMap<>();
+
+    /** {@link ScheduledFuture} for heartbeat. */
+    private volatile ScheduledFuture<?> heartbeatFuture;
+
+    /** Whether the heartbeat future has been canceled. */
+    private boolean heartbeatFutureCanceled;
+
+    /** Channel handler context for user event processing. */
+    private volatile ChannelHandlerContext channelHandlerContext;
+
+    /**
+     * @param heartbeatInterval Heartbeat interval in seconds; client and server exchange
+     *     heartbeats to confirm each other's liveness.
+     */
+    public ReadClientHandler(int heartbeatInterval) {
+        this.heartbeatInterval = heartbeatInterval;
+    }
+
+    /** Register a {@link ShuffleReadClient}. */
+    public void register(ShuffleReadClient shuffleReadClient) {
+        readClientsByChannelID.put(shuffleReadClient.getChannelID(), shuffleReadClient);
+    }
+
+    /** Unregister a {@link ShuffleReadClient}. */
+    public void unregister(ShuffleReadClient shuffleReadClient) {
+        readClientsByChannelID.remove(shuffleReadClient.getChannelID());
+    }
+
+    public void notifyReadCredit(TransferMessage.ReadAddCredit addCredit) {
+        channelHandlerContext
+                .executor()
+                .execute(() -> channelHandlerContext.pipeline().fireUserEventTriggered(addCredit));
+    }
+
+    /** Get buffer suppliers to decode shuffle read data. */
+    public Function<ChannelID, Supplier<ByteBuf>> bufferSuppliers() {
+        return channelID ->
+                () -> {
+                    ShuffleReadClient shuffleReadClient = readClientsByChannelID.get(channelID);
+                    if (shuffleReadClient == null) {
+                        throw new ShuffleException(
+                                "Channel of " + channelID + " is already released.");
+                    }
+                    return shuffleReadClient.requestBuffer();
+                };
+    }
+
+    @Override
+    public void channelActive(ChannelHandlerContext ctx) throws Exception {
+        if (heartbeatInterval > 0) {
+            heartbeatFuture =
+                    ctx.executor()
+                            .scheduleAtFixedRate(
+                                    () -> ctx.writeAndFlush(new TransferMessage.Heartbeat()),
+                                    0,
+                                    heartbeatInterval,
+                                    TimeUnit.SECONDS);
+        }
+
+        if (channelHandlerContext == null) {
+            channelHandlerContext = ctx;
+        }
+
+        super.channelActive(ctx);
+    }
+
+    @Override
+    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+        LOG.debug("({}) Connection inactive.", ctx.channel().remoteAddress());
+        for (ShuffleReadClient shuffleReadClient : readClientsByChannelID.values()) {
+            shuffleReadClient.channelInactive();
+        }
+
+        if (heartbeatFuture != null && !heartbeatFutureCanceled) {
+            heartbeatFuture.cancel(true);
+            heartbeatFutureCanceled = true;
+        }
+        super.channelInactive(ctx);
+    }
+
+    @Override
+    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+        LOG.debug("({}) Connection exception caught.", ctx.channel().remoteAddress(), cause);
+        exceptionCaught(cause);
+    }
+
+    private void exceptionCaught(Throwable cause) {
+        for (ShuffleReadClient shuffleReadClient : readClientsByChannelID.values()) {
+            shuffleReadClient.exceptionCaught(cause);
+        }
+
+        if (heartbeatFuture != null && !heartbeatFutureCanceled) {
+            heartbeatFuture.cancel(true);
+            heartbeatFutureCanceled = true;
+        }
+    }
+
+    @Override
+    public void channelRead(ChannelHandlerContext ctx, Object msg) {
+
+        ChannelID currentChannelID = null;
+        try {
+            if (msg.getClass() == TransferMessage.ReadData.class) {
+                TransferMessage.ReadData readData = (TransferMessage.ReadData) msg;
+                SocketAddress address = ctx.channel().remoteAddress();
+                ChannelID channelID = readData.getChannelID();
+                currentChannelID = channelID;
+                LOG.trace("({}) Received {}.", address, readData);
+
+                ShuffleReadClient shuffleReadClient = readClientsByChannelID.get(channelID);
+                if (shuffleReadClient == null) {
readData.getBuffer().release(); + throw new IllegalStateException( + "Read channel has been unregistered -- " + channelID); + } else { + shuffleReadClient.dataReceived(readData); + } + } else if (msg.getClass() == TransferMessage.ErrorResponse.class) { + TransferMessage.ErrorResponse errorRsp = (TransferMessage.ErrorResponse) msg; + SocketAddress address = ctx.channel().remoteAddress(); + ChannelID channelID = errorRsp.getChannelID(); + currentChannelID = channelID; + LOG.debug("({}) Received {}.", address, errorRsp); + ShuffleReadClient shuffleReadClient = readClientsByChannelID.get(channelID); + assertChannelExists(channelID, shuffleReadClient); + String errorMsg = new String(errorRsp.getErrorMessageBytes()); + shuffleReadClient.exceptionCaught(new IOException(errorMsg)); + + } else if (msg.getClass() == TransferMessage.BacklogAnnouncement.class) { + TransferMessage.BacklogAnnouncement backlog = + (TransferMessage.BacklogAnnouncement) msg; + SocketAddress address = ctx.channel().remoteAddress(); + ChannelID channelID = backlog.getChannelID(); + currentChannelID = channelID; + LOG.trace("({}) Received {}.", address, backlog); + ShuffleReadClient shuffleReadClient = readClientsByChannelID.get(channelID); + assertChannelExists(channelID, shuffleReadClient); + shuffleReadClient.backlogReceived(backlog.getBacklog()); + } else { + ctx.fireChannelRead(msg); + } + } catch (Throwable t) { + SocketAddress address = ctx.channel().remoteAddress(); + LOG.debug("({}, ch: {}) Exception caught.", address, currentChannelID, t); + if (readClientsByChannelID.containsKey(currentChannelID)) { + readClientsByChannelID.get(currentChannelID).exceptionCaught(t); + } + } + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof IdleStateEvent) { + IdleStateEvent event = (IdleStateEvent) evt; + LOG.debug( + "({}) Remote seems lost and connection idle -- {}.", + ctx.channel().remoteAddress(), + event.state()); + if (heartbeatInterval <= 0) { + return; + } + + CommonUtils.runQuietly( + () -> + exceptionCaught( + ctx, + new Exception("Connection idle, state is " + event.state())), + true); + } else if (evt instanceof ClientReadingFailureEvent) { + ClientReadingFailureEvent event = (ClientReadingFailureEvent) evt; + LOG.debug("({}) Received {}.", ctx.channel().remoteAddress(), event); + ShuffleReadClient shuffleReadClient = readClientsByChannelID.get(event.channelID); + if (shuffleReadClient != null) { + shuffleReadClient.exceptionCaught(event.cause); + } + // Otherwise, the client is already released, thus no need to propagate. 
+ } else if (evt instanceof TransferMessage.ReadAddCredit) { + TransferMessage.ReadAddCredit addCredit = (TransferMessage.ReadAddCredit) evt; + if (LOG.isTraceEnabled()) { + LOG.trace( + "(remote: {}, channel: {}) Send {}.", + ctx.channel().remoteAddress(), + addCredit.getChannelID(), + addCredit); + } + ctx.channel() + .writeAndFlush(addCredit) + .addListener( + new ChannelFutureListenerImpl( + (future, throwable) -> exceptionCaught(throwable))); + } else { + ctx.fireUserEventTriggered(evt); + } + } + + private void assertChannelExists(ChannelID channelID, ShuffleReadClient shuffleReadClient) { + checkState( + shuffleReadClient != null, + () -> "Read channel has been unregistered -- " + channelID); + } + + static class ClientReadingFailureEvent { + + private final ChannelID channelID; + private final Throwable cause; + + ClientReadingFailureEvent(ChannelID channelID, Throwable cause) { + this.channelID = checkNotNull(channelID); + this.cause = checkNotNull(cause); + } + + @Override + public String toString() { + return String.format( + "ClientReadingFailureEvent [channelID: %s, cause: %s]", + channelID, cause.getMessage()); + } + } +} diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ReadServerHandler.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ReadServerHandler.java new file mode 100644 index 00000000..a3c90525 --- /dev/null +++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ReadServerHandler.java @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package com.alibaba.flink.shuffle.transfer;
+
+import com.alibaba.flink.shuffle.common.exception.ShuffleException;
+import com.alibaba.flink.shuffle.common.utils.CommonUtils;
+import com.alibaba.flink.shuffle.common.utils.ExceptionUtils;
+import com.alibaba.flink.shuffle.core.ids.ChannelID;
+import com.alibaba.flink.shuffle.core.ids.DataSetID;
+import com.alibaba.flink.shuffle.core.ids.MapPartitionID;
+import com.alibaba.flink.shuffle.core.storage.PartitionedDataStore;
+import com.alibaba.flink.shuffle.transfer.TransferMessage.BacklogAnnouncement;
+import com.alibaba.flink.shuffle.transfer.TransferMessage.CloseChannel;
+import com.alibaba.flink.shuffle.transfer.TransferMessage.CloseConnection;
+import com.alibaba.flink.shuffle.transfer.TransferMessage.ErrorResponse;
+import com.alibaba.flink.shuffle.transfer.TransferMessage.Heartbeat;
+import com.alibaba.flink.shuffle.transfer.TransferMessage.ReadAddCredit;
+import com.alibaba.flink.shuffle.transfer.TransferMessage.ReadHandshakeRequest;
+
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandler;
+import org.apache.flink.shaded.netty4.io.netty.channel.SimpleChannelInboundHandler;
+import org.apache.flink.shaded.netty4.io.netty.handler.timeout.IdleStateEvent;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.net.SocketAddress;
+import java.nio.channels.ClosedChannelException;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull;
+import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.currentProtocolVersion;
+import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyExtraMessage;
+
+/** A {@link ChannelInboundHandler} serving the shuffle read process on the server side. */
+public class ReadServerHandler extends SimpleChannelInboundHandler<TransferMessage> {
+
+    private static final Logger LOG = LoggerFactory.getLogger(ReadServerHandler.class);
+
+    /** Heartbeat interval. */
+    private final int heartbeatInterval;
+
+    /** Underlying service logic. */
+    private final ReadingService readingService;
+
+    /** Identifier of the channel currently being served. */
+    private ChannelID currentChannelID;
+
+    /** {@link ScheduledFuture} for heartbeat. */
+    private ScheduledFuture<?> heartbeatFuture;
+
+    /** Whether the connection is closed. */
+    private boolean connectionClosed;
+
+    /**
+     * @param dataStore Implementation of storage layer.
+     * @param heartbeatInterval Heartbeat interval in seconds.
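+     *     A non-positive value disables the heartbeat on this handler.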
+     */
+    public ReadServerHandler(PartitionedDataStore dataStore, int heartbeatInterval) {
+        this.readingService = new ReadingService(dataStore);
+        this.heartbeatInterval = heartbeatInterval;
+        this.connectionClosed = false;
+    }
+
+    @Override
+    public void channelActive(ChannelHandlerContext ctx) throws Exception {
+        if (heartbeatInterval > 0) {
+            heartbeatFuture =
+                    ctx.executor()
+                            .scheduleAtFixedRate(
+                                    () -> ctx.writeAndFlush(new Heartbeat()),
+                                    0,
+                                    heartbeatInterval,
+                                    TimeUnit.SECONDS);
+        }
+        super.channelActive(ctx);
+    }
+
+    @Override
+    protected void channelRead0(ChannelHandlerContext ctx, TransferMessage msg) {
+        try {
+            onMessage(ctx, msg);
+        } catch (Throwable e) {
+            CommonUtils.runQuietly(() -> onInternalFailure(currentChannelID, ctx, e), true);
+        }
+    }
+
+    @Override
+    public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
+        super.channelRegistered(ctx);
+        NetworkMetrics.numReadingConnections().inc();
+    }
+
+    @Override
+    public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
+        super.channelUnregistered(ctx);
+        NetworkMetrics.numReadingConnections().dec();
+    }
+
+    private void onMessage(ChannelHandlerContext ctx, TransferMessage msg) throws Throwable {
+
+        Class<?> msgClazz = msg.getClass();
+        if (msgClazz == ReadHandshakeRequest.class) {
+            ReadHandshakeRequest handshakeReq = (ReadHandshakeRequest) msg;
+            ChannelID channelID = handshakeReq.getChannelID();
+            SocketAddress address = ctx.channel().remoteAddress();
+            LOG.debug("({}) Received {}.", address, handshakeReq);
+
+            currentChannelID = channelID;
+            DataSetID dataSetID = handshakeReq.getDataSetID();
+            MapPartitionID mapID = handshakeReq.getMapID();
+            int startSubIdx = handshakeReq.getStartSubIdx();
+            int endSubIdx = handshakeReq.getEndSubIdx();
+            int initCredit = handshakeReq.getInitialCredit();
+            Consumer<DataViewReader> dataListener = ctx.pipeline()::fireUserEventTriggered;
+            Consumer<Integer> backlogListener = getBacklogListener(ctx, channelID);
+            Consumer<Throwable> failureListener = getFailureListener(ctx, channelID);
+            readingService.handshake(
+                    channelID,
+                    dataSetID,
+                    mapID,
+                    startSubIdx,
+                    endSubIdx,
+                    dataListener,
+                    backlogListener,
+                    failureListener,
+                    initCredit,
+                    address.toString());
+
+        } else if (msgClazz == ReadAddCredit.class) {
+            ReadAddCredit addCredit = (ReadAddCredit) msg;
+            LOG.trace("({}) Received {}.", ctx.channel().remoteAddress(), addCredit);
+
+            ChannelID channelID = addCredit.getChannelID();
+            currentChannelID = channelID;
+            int credit = addCredit.getCredit();
+            if (credit > 0) {
+                readingService.addCredit(channelID, credit);
+            }
+
+        } else if (msgClazz == CloseChannel.class) {
+            CloseChannel closeChannel = (CloseChannel) msg;
+            SocketAddress address = ctx.channel().remoteAddress();
+            LOG.debug("({}) Received {}.", address, closeChannel);
+
+            ChannelID channelID = closeChannel.getChannelID();
+            currentChannelID = channelID;
+            readingService.closeAbnormallyIfUnderServing(channelID);
+
+            // Ensure both WritingService and ReadingService could receive this message.
+            ctx.fireChannelRead(msg);
+
+        } else if (msgClazz == CloseConnection.class) {
+            SocketAddress address = ctx.channel().remoteAddress();
+            CloseConnection closeConnection = (CloseConnection) msg;
+            LOG.info("({}) Received {}.", address, closeConnection);
+            close(ctx, new ShuffleException("Received connection close from client."));
+
+            // Ensure both WritingService and ReadingService could receive this message.
+            ctx.fireChannelRead(msg);
+
+        } else {
+            ctx.fireChannelRead(msg);
+        }
+    }
+
+    @Override
+    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
+        Class<?> msgClazz = evt.getClass();
+        if (msgClazz == ReadingFailureEvent.class) {
+            ReadingFailureEvent errRspEvt = (ReadingFailureEvent) evt;
+            LOG.error("({}) Received {}.", ctx.channel().remoteAddress(), errRspEvt);
+            CommonUtils.runQuietly(
+                    () -> onInternalFailure(errRspEvt.channelID, ctx, errRspEvt.cause), true);
+
+        } else if (evt instanceof IdleStateEvent) {
+            IdleStateEvent event = (IdleStateEvent) evt;
+            LOG.debug(
+                    "({}) Connection idle, the remote peer seems lost -- {}.",
+                    ctx.channel().remoteAddress(),
+                    event.state());
+            if (heartbeatInterval <= 0) {
+                return;
+            }
+
+            CommonUtils.runQuietly(
+                    () -> close(ctx, new ShuffleException("Heartbeat timeout.")), true);
+        } else if (evt instanceof BacklogAnnouncement) {
+            BacklogAnnouncement backlog = (BacklogAnnouncement) evt;
+            ctx.writeAndFlush(backlog).addListener(new CloseChannelWhenFailure());
+            LOG.trace("({}) Announce backlog {}.", ctx.channel().remoteAddress(), backlog);
+        } else {
+            ctx.fireUserEventTriggered(evt);
+        }
+    }
+
+    @Override
+    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+        close(ctx, new ShuffleException("Channel inactive."));
+        super.channelInactive(ctx);
+    }
+
+    @Override
+    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+        close(ctx, cause);
+    }
+
+    public ReadingService getReadingService() {
+        return readingService;
+    }
+
+    private Consumer<Throwable> getFailureListener(
+            ChannelHandlerContext ctx, ChannelID channelID) {
+        return t -> {
+            if (t instanceof ClosedChannelException) {
+                return;
+            }
+            ctx.pipeline().fireUserEventTriggered(new ReadingFailureEvent(t, channelID));
+        };
+    }
+
+    private Consumer<Integer> getBacklogListener(ChannelHandlerContext ctx, ChannelID channelID) {
+        return (backlog) ->
+                ctx.pipeline()
+                        .fireUserEventTriggered(
+                                new TransferMessage.BacklogAnnouncement(channelID, backlog));
+    }
+
+    private void onInternalFailure(ChannelID channelID, ChannelHandlerContext ctx, Throwable t) {
+        checkNotNull(channelID);
+        LOG.error("(ch: {}) Internal shuffle read failure.", channelID, t);
+        byte[] errorMessageBytes = ExceptionUtils.summaryErrorMessageStack(t).getBytes();
+        ErrorResponse errRsp =
+                new ErrorResponse(
+                        currentProtocolVersion(),
+                        channelID,
+                        errorMessageBytes,
+                        emptyExtraMessage());
+        ctx.writeAndFlush(errRsp).addListener(new CloseChannelWhenFailure());
+        readingService.releaseOnError(t, channelID);
+    }
+
+    // This method is invoked when:
+    // 1. CloseConnection is received from the client;
+    // 2. A network error occurs --
+    //    a. failure while sending a message;
+    //    b. the connection becomes inactive;
+    //    c. a connection exception is caught.
+    //
+    // This method does the following things:
+    // 1. Triggers errors on the corresponding logical channels;
+    // 2. Stops the heartbeat to the client;
+    // 3. Closes the physical connection.
+    private void close(ChannelHandlerContext ctx, Throwable throwable) {
+        if (readingService.getNumServingChannels() > 0) {
+            readingService.releaseOnError(
+                    throwable != null ? throwable : new ClosedChannelException(), null);
+        }
+
+        if (heartbeatFuture != null) {
+            heartbeatFuture.cancel(true);
+        }
+        ctx.channel().close();
+        connectionClosed = true;
+    }
+
+    static class ReadingFailureEvent {
+
+        final Throwable cause;
+
+        final ChannelID channelID;
+
+        ReadingFailureEvent(Throwable cause, ChannelID channelID) {
+            this.cause = cause;
+            this.channelID = channelID;
+        }
+
+        @Override
+        public String toString() {
+            return String.format(
+                    "ReadingFailureEvent [channelID: %s, cause: %s]",
+                    channelID, cause.getMessage());
+        }
+    }
+}
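A minimal, self-contained sketch of the heartbeat contract used by ReadServerHandler above, reduced to plain JDK executors (illustrative only, not part of this patch; the class name and interval are invented). The real handler schedules the task on the Netty channel's executor in channelActive() and cancels it in close():

import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;

public final class HeartbeatSketch {
    public static void main(String[] args) throws Exception {
        ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
        int heartbeatIntervalSeconds = 1;
        // Server side: schedule the periodic heartbeat, as channelActive() does.
        ScheduledFuture<?> heartbeat =
                executor.scheduleAtFixedRate(
                        () -> System.out.println("-> Heartbeat"),
                        0,
                        heartbeatIntervalSeconds,
                        TimeUnit.SECONDS);
        Thread.sleep(3_000);
        // Tear-down path: cancel the heartbeat before closing, as close() does.
        heartbeat.cancel(true);
        executor.shutdownNow();
    }
}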
diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ReadingExceptionWithChannelID.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ReadingExceptionWithChannelID.java
new file mode 100644
index 00000000..8b80333f
--- /dev/null
+++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ReadingExceptionWithChannelID.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.transfer;
+
+import com.alibaba.flink.shuffle.common.exception.ShuffleException;
+import com.alibaba.flink.shuffle.core.ids.ChannelID;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull;
+
+/** A {@link Throwable} carrying the related {@link ChannelID} when reading on a shuffle worker. */
+public class ReadingExceptionWithChannelID extends ShuffleException {
+
+    private static final long serialVersionUID = -6153186511717646882L;
+
+    private final ChannelID channelID;
+
+    public ReadingExceptionWithChannelID(ChannelID channelID, Throwable t) {
+        super(t);
+        checkNotNull(channelID);
+        checkNotNull(t);
+        this.channelID = channelID;
+    }
+
+    public ChannelID getChannelID() {
+        return channelID;
+    }
+}
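A sketch of why an exception carries its ChannelID: several logical channels share one physical connection, so a failure must name the channel to release (illustrative only, not part of this patch; a String stands in for ChannelID):

public final class ChannelErrorSketch {

    /** Simplified analogue of ReadingExceptionWithChannelID. */
    static final class TaggedException extends RuntimeException {
        final String channelID;

        TaggedException(String channelID, Throwable cause) {
            super(cause);
            this.channelID = channelID;
        }
    }

    public static void main(String[] args) {
        try {
            throw new TaggedException("channel-42", new IllegalStateException("decode failed"));
        } catch (TaggedException e) {
            // Only the offending logical channel is released; the others keep serving.
            System.out.println("release " + e.channelID + " due to: " + e.getCause());
        }
    }
}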
diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ReadingService.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ReadingService.java
new file mode 100644
index 00000000..5b096dc1
--- /dev/null
+++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ReadingService.java
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.transfer;
+
+import com.alibaba.flink.shuffle.common.utils.CommonUtils;
+import com.alibaba.flink.shuffle.core.ids.ChannelID;
+import com.alibaba.flink.shuffle.core.ids.DataSetID;
+import com.alibaba.flink.shuffle.core.ids.MapPartitionID;
+import com.alibaba.flink.shuffle.core.storage.DataPartitionReadingView;
+import com.alibaba.flink.shuffle.core.storage.PartitionedDataStore;
+import com.alibaba.flink.shuffle.core.storage.ReadingViewContext;
+
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandler;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState;
+
+/**
+ * Harness used to read data from storage. It performs shuffle reads through a {@link
+ * PartitionedDataStore}. Its lifecycle is the same as that of a Netty {@link
+ * ChannelInboundHandler} instance.
+ */
+public class ReadingService {
+
+    private static final Logger LOG = LoggerFactory.getLogger(ReadingService.class);
+
+    private final PartitionedDataStore dataStore;
+
+    private final Map<ChannelID, DataViewReader> servingChannels;
+
+    public ReadingService(PartitionedDataStore dataStore) {
+        this.dataStore = dataStore;
+        this.servingChannels = new HashMap<>();
+    }
+
+    public void handshake(
+            ChannelID channelID,
+            DataSetID dataSetID,
+            MapPartitionID mapID,
+            int startSubIdx,
+            int endSubIdx,
+            Consumer<DataViewReader> dataListener,
+            Consumer<Integer> backlogListener,
+            Consumer<Throwable> failureHandler,
+            int initialCredit,
+            String addressStr)
+            throws Throwable {
+
+        checkState(
+                !servingChannels.containsKey(channelID),
+                () -> "Duplicate handshake for channel: " + channelID);
+        DataViewReader dataViewReader = new DataViewReader(channelID, addressStr, dataListener);
+        if (initialCredit > 0) {
+            dataViewReader.addCredit(initialCredit);
+        }
+        long startTime = System.nanoTime();
+        DataPartitionReadingView readingView =
+                dataStore.createDataPartitionReadingView(
+                        new ReadingViewContext(
+                                dataSetID,
+                                mapID,
+                                startSubIdx,
+                                endSubIdx,
+                                () -> dataListener.accept(dataViewReader),
+                                backlogListener::accept,
+                                failureHandler::accept));
+        LOG.debug(
+                "(channel: {}) Reading handshake cost {} ms.",
+                channelID,
+                (System.nanoTime() - startTime) / 1_000_000);
+        dataViewReader.setReadingView(readingView);
+        servingChannels.put(channelID, dataViewReader);
+        NetworkMetrics.numReadingFlows().inc();
+    }
+
+    public void addCredit(ChannelID channelID, int credit) {
+        DataViewReader dataViewReader = servingChannels.get(channelID);
+        if (dataViewReader == null) {
+            return;
+        }
+        int oldCredit = dataViewReader.getCredit();
+        dataViewReader.addCredit(credit);
+        if (oldCredit == 0) {
+            dataViewReader.getDataListener().accept(dataViewReader);
+        }
+    }
+
+    public void readFinish(ChannelID channelID) {
+        servingChannels.remove(channelID);
+        NetworkMetrics.numReadingFlows().dec();
+    }
+
+    public int getNumServingChannels() {
+        return servingChannels.size();
+    }
+
+    public void closeAbnormallyIfUnderServing(ChannelID channelID) {
+        DataViewReader dataViewReader = servingChannels.get(channelID);
+        if (dataViewReader != null) {
+            DataPartitionReadingView readingView = dataViewReader.getReadingView();
+            readingView.onError(
+                    new Exception(
+                            String.format("(channel: %s) Channel closed abnormally", channelID)));
+            servingChannels.remove(channelID);
+            NetworkMetrics.numReadingFlows().dec();
+        }
+    }
+
+    public void releaseOnError(Throwable cause, ChannelID channelID) {
+        if (channelID == null) {
+            Set<ChannelID> channelIDs = servingChannels.keySet();
+            LOG.error(
+                    "Release channels -- {} on error.",
+                    channelIDs.stream().map(ChannelID::toString).collect(Collectors.joining(", ")),
+                    cause);
+            for (DataViewReader dataViewReader : servingChannels.values()) {
+                CommonUtils.runQuietly(() -> dataViewReader.getReadingView().onError(cause), true);
+            }
+            NetworkMetrics.numReadingFlows().dec(getNumServingChannels());
+            servingChannels.clear();
+        } else if (servingChannels.containsKey(channelID)) {
+            LOG.error("Release channel -- {} on error.", channelID, cause);
+            CommonUtils.runQuietly(
+                    () -> servingChannels.get(channelID).getReadingView().onError(cause), true);
+            servingChannels.remove(channelID);
+            NetworkMetrics.numReadingFlows().dec();
+        }
+    }
+}
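A sketch of the credit bookkeeping rule behind ReadingService.addCredit above: a reader is re-notified only on the zero-to-positive edge, because a reader that still holds credit is already scheduled for sending (illustrative only, not part of this patch):

import java.util.function.IntConsumer;

public final class CreditNotifySketch {

    private int credit;
    private final IntConsumer dataListener;

    CreditNotifySketch(IntConsumer dataListener) {
        this.dataListener = dataListener;
    }

    void addCredit(int delta) {
        int oldCredit = credit;
        credit += delta;
        if (oldCredit == 0) {
            // Wake the sender only when credit rises from zero.
            dataListener.accept(credit);
        }
    }

    public static void main(String[] args) {
        CreditNotifySketch channel =
                new CreditNotifySketch(c -> System.out.println("notify, credit=" + c));
        channel.addCredit(2); // prints: the reader was starved
        channel.addCredit(3); // silent: the reader is already scheduled
    }
}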
diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ShuffleReadClient.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ShuffleReadClient.java
new file mode 100644
index 00000000..1d1dc9fb
--- /dev/null
+++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ShuffleReadClient.java
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.transfer;
+
+import com.alibaba.flink.shuffle.core.ids.ChannelID;
+import com.alibaba.flink.shuffle.core.ids.DataSetID;
+import com.alibaba.flink.shuffle.core.ids.MapPartitionID;
+import com.alibaba.flink.shuffle.transfer.TransferMessage.CloseChannel;
+import com.alibaba.flink.shuffle.transfer.TransferMessage.ReadAddCredit;
+import com.alibaba.flink.shuffle.transfer.TransferMessage.ReadData;
+import com.alibaba.flink.shuffle.transfer.TransferMessage.ReadHandshakeRequest;
+
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.apache.flink.shaded.netty4.io.netty.channel.Channel;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.nio.channels.ClosedChannelException;
+import java.util.function.Consumer;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument;
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.randomBytes;
+import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.currentProtocolVersion;
+import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyExtraMessage;
+import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyOffset;
+
+/**
+ * Reader client used to retrieve buffers from a remote shuffle worker to a Flink TM. It talks to
+ * a shuffle worker over a Netty connection in the language of {@link TransferMessage}. Flow
+ * control is guaranteed by a credit-based mechanism. The whole process of communication between
+ * {@link ShuffleReadClient} and a shuffle worker is as follows:
+ *
+ * <ul>
+ *   <li>1. Client opens a connection and sends {@link ReadHandshakeRequest}, which contains the
+ *       number of initial credits -- indicates how many buffers it can accept;
+ *   <li>2. Server sends {@link ReadData} according to the number of credits granted by the
+ *       client side. {@link ReadData} contains backlog information -- indicates how many more
+ *       buffers are to be sent;
+ *   <li>3. Client allocates more buffers and sends {@link ReadAddCredit} to announce more
+ *       credits;
+ *   <li>4. Repeat from step-2 to step-3;
+ *   <li>5. Client sends {@link CloseChannel} to the server.
+ * </ul>
+ */
+public class ShuffleReadClient extends CreditListener {
+
+    private static final Logger LOG = LoggerFactory.getLogger(ShuffleReadClient.class);
+
+    /** Address of the shuffle worker. */
+    private final InetSocketAddress address;
+
+    /** String representation of the remote shuffle address. */
+    private final String addressStr;
+
+    /** Used to set up connections. */
+    private final ConnectionManager connectionManager;
+
+    /** Pool to allocate buffers from. */
+    private final TransferBufferPool bufferPool;
+
+    /** Index of the first logical subpartition to be read (inclusive). */
+    private final int startSubIdx;
+
+    /** Index of the last logical subpartition to be read (inclusive). */
+    private final int endSubIdx;
+
+    /** Size of the buffers used to receive shuffle data. */
+    private final int bufferSize;
+
+    /** {@link DataSetID} of the reading. */
+    private final DataSetID dataSetID;
+
+    /** {@link MapPartitionID} of the reading. */
+    private final MapPartitionID mapID;
+
+    /** Identifier of the channel. */
+    private final ChannelID channelID;
+
+    /** String representation of the channelID. */
+    private final String channelIDStr;
+
+    /** {@link ReadClientHandler} backing this read client. */
+    private volatile ReadClientHandler readClientHandler;
+
+    /** Listener notified when data is received. */
+    private final Consumer<ByteBuf> dataListener;
+
+    /** Listener notified on failure. */
+    private final Consumer<Throwable> failureListener;
+
+    /** Netty channel. */
+    private volatile Channel nettyChannel;
+
+    /** {@link Throwable} set on failure. */
+    private Throwable cause;
+
+    /** Whether the channel is closed. */
+    private volatile boolean closed;
+
+    public ShuffleReadClient(
+            InetSocketAddress address,
+            DataSetID dataSetID,
+            MapPartitionID mapID,
+            int startSubIdx,
+            int endSubIdx,
+            int bufferSize,
+            TransferBufferPool bufferPool,
+            ConnectionManager connectionManager,
+            Consumer<ByteBuf> dataListener,
+            Consumer<Throwable> failureListener) {
+
+        checkArgument(address != null, "Must be not null.");
+        checkArgument(dataSetID != null, "Must be not null.");
+        checkArgument(mapID != null, "Must be not null.");
+        checkArgument(startSubIdx >= 0, "Must be non-negative.");
+        checkArgument(endSubIdx >= startSubIdx, "Must be equal to or larger than startSubIdx.");
+        checkArgument(bufferSize > 0, "Must be positive value.");
+        checkArgument(bufferPool != null, "Must be not null.");
+        checkArgument(connectionManager != null, "Must be not null.");
+        checkArgument(dataListener != null, "Must be not null.");
+        checkArgument(failureListener != null, "Must be not null.");
+
+        this.address = address;
+        this.addressStr = address.toString();
+        this.dataSetID = dataSetID;
+        this.mapID = mapID;
+        this.startSubIdx = startSubIdx;
+        this.endSubIdx = endSubIdx;
+        this.bufferSize = bufferSize;
+        this.bufferPool = bufferPool;
+        this.connectionManager = connectionManager;
+        this.dataListener = dataListener;
+        this.failureListener = failureListener;
+        this.channelID = new ChannelID(randomBytes(16));
+        this.channelIDStr = channelID.toString();
+    }
+
+    /** Creates the Netty connection to the remote. */
+    public void connect() throws IOException, InterruptedException {
+        LOG.debug("(remote: {}, channel: {}) Connect channel.", address, channelIDStr);
+        nettyChannel = connectionManager.getChannel(channelID, address);
+    }
+
+    /** Fires the handshake.
*/ + public void open() throws IOException { + readClientHandler = nettyChannel.pipeline().get(ReadClientHandler.class); + if (readClientHandler == null) { + throw new IOException( + "The network connection is already released for channelID: " + channelIDStr); + } + readClientHandler.register(this); + + ReadHandshakeRequest handshake = + new ReadHandshakeRequest( + currentProtocolVersion(), + channelID, + dataSetID, + mapID, + startSubIdx, + endSubIdx, + 0, + bufferSize, + emptyOffset(), + emptyExtraMessage()); + LOG.debug("(remote: {}) Send {}.", nettyChannel.remoteAddress(), handshake); + nettyChannel + .writeAndFlush(handshake) + .addListener( + new ChannelFutureListenerImpl( + (ignored, throwable) -> exceptionCaught(throwable))); + } + + public boolean isOpened() { + return readClientHandler != null; + } + + /** Get identifier of the channel. */ + public ChannelID getChannelID() { + return channelID; + } + + /** Called by Netty thread. */ + public void dataReceived(ReadData readData) { + LOG.trace("(remote: {}, channel: {}) Received {}.", address, channelIDStr, readData); + if (closed) { + readData.getBuffer().release(); + return; + } + dataListener.accept(readData.getBuffer()); + } + + /** Called by Netty thread. */ + public void backlogReceived(int backlog) { + bufferPool.reserveBuffers(this, backlog); + } + + /** Called by Netty thread. */ + public void channelInactive() { + if (closed || cause != null) { + return; + } + cause = + new IOException( + "Shuffle failure on connection to " + + address + + " for channel of " + + channelIDStr, + new ClosedChannelException()); + failureListener.accept(cause); + } + + /** Called by Netty thread. */ + public void exceptionCaught(Throwable t) { + if (cause != null) { + return; + } + if (t != null) { + cause = + new IOException( + "Shuffle failure on connection to " + + address + + " for channel of " + + channelIDStr + + ", cause: " + + t.getMessage(), + t); + } else { + cause = + new IOException( + "Shuffle failure on connection to " + + address + + " for channel of " + + channelIDStr); + } + failureListener.accept(cause); + } + + /** Called by Netty thread to request buffer to receive data. */ + public ByteBuf requestBuffer() { + return bufferPool.requestBuffer(); + } + + public Throwable getCause() { + return cause; + } + + /** Closes Netty connection -- called from task thread. */ + public void close() throws IOException { + closed = true; + LOG.debug( + "(remote: {}, channel: {}) Close for (dataSetID: {}, mapID: {}, startSubIdx: {}, endSubIdx: {}).", + address, + channelIDStr, + dataSetID, + mapID, + startSubIdx, + endSubIdx); + if (nettyChannel != null) { + connectionManager.releaseChannel(address, channelID); + } + + if (readClientHandler != null) { + readClientHandler.unregister(this); + } + } + + // Called from both task thread and netty thread. 
+    @Override
+    public void notifyAvailableCredits(int numCredits) {
+        ReadAddCredit addCredit =
+                new ReadAddCredit(
+                        currentProtocolVersion(), channelID, numCredits, emptyExtraMessage());
+        readClientHandler.notifyReadCredit(addCredit);
+    }
+}
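A sketch of the five-step read protocol from the ShuffleReadClient javadoc, simulated with a plain queue instead of Netty (illustrative only, not part of this patch; buffer counts are invented). The point is that the server never sends more ReadData than the client has granted credits for:

import java.util.ArrayDeque;
import java.util.Queue;

public final class ReadProtocolSketch {
    public static void main(String[] args) {
        Queue<Integer> toSend = new ArrayDeque<>();
        for (int i = 0; i < 5; i++) {
            toSend.add(i); // five buffers pending on the server
        }

        int credits = 2; // step 1: ReadHandshakeRequest announces 2 initial credits
        while (!toSend.isEmpty()) {
            while (credits > 0 && !toSend.isEmpty()) {
                // step 2: each ReadData carries the remaining backlog
                System.out.println(
                        "ReadData " + toSend.poll() + " (backlog=" + toSend.size() + ")");
                credits--;
            }
            if (!toSend.isEmpty()) {
                credits += 2; // step 3: client recycled buffers, announces ReadAddCredit
                System.out.println("ReadAddCredit 2");
            }
        } // step 4: repeat steps 2-3 until drained
        System.out.println("CloseChannel"); // step 5
    }
}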
diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ShuffleReadDataDecoder.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ShuffleReadDataDecoder.java
new file mode 100644
index 00000000..2e1f386b
--- /dev/null
+++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ShuffleReadDataDecoder.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.transfer;
+
+import com.alibaba.flink.shuffle.common.utils.CommonUtils;
+import com.alibaba.flink.shuffle.core.ids.ChannelID;
+
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
+
+import java.util.function.Function;
+import java.util.function.Supplier;
+
+/** {@link TransferMessageDecoder} for {@link TransferMessage.ReadData}. */
+public class ShuffleReadDataDecoder extends TransferMessageDecoder {
+
+    private ByteBuf headerByteBuf;
+
+    private boolean headerInitialized;
+
+    private final Function<ChannelID, Supplier<ByteBuf>> bufferSuppliers;
+
+    private ByteBuf body;
+
+    private TransferMessage.ReadData shuffleReadData;
+
+    /** @param bufferSuppliers Supplies buffers to accommodate the received network buffers. */
+    public ShuffleReadDataDecoder(Function<ChannelID, Supplier<ByteBuf>> bufferSuppliers) {
+        this.bufferSuppliers = bufferSuppliers;
+        this.isClosed = false;
+    }
+
+    @Override
+    public void onNewMessageReceived(ChannelHandlerContext ctx, int msgId, int messageLength) {
+        super.onNewMessageReceived(ctx, msgId, messageLength);
+        headerByteBuf = ctx.alloc().directBuffer(messageLength);
+    }
+
+    @Override
+    public DecodingResult onChannelRead(ByteBuf byteBuf) {
+        CommonUtils.checkState(!isClosed, "Decoder has been closed.");
+
+        if (!headerInitialized) {
+            boolean accumulationFinished =
+                    DecodingUtil.accumulate(
+                            headerByteBuf, byteBuf, messageLength, headerByteBuf.readableBytes());
+            if (!accumulationFinished) {
+                return DecodingResult.NOT_FINISHED;
+            }
+            shuffleReadData = TransferMessage.ReadData.initByHeader(headerByteBuf);
+            headerInitialized = true;
+        }
+        try {
+            if (body == null) {
+                body = bufferSuppliers.apply(shuffleReadData.getChannelID()).get();
+            }
+
+            if (body.capacity() < shuffleReadData.getBufferSize()) {
+                throw new IllegalArgumentException(
+                        String.format(
+                                "Buffer size of read data (%d) is bigger than what can be accepted (%d)",
+                                shuffleReadData.getBufferSize(), body.capacity()));
+            }
+
+            boolean accumulationFinished =
+                    DecodingUtil.accumulate(
+                            body, byteBuf, shuffleReadData.getBufferSize(), body.readableBytes());
+            if (accumulationFinished) {
+                shuffleReadData.setBuffer(body);
+                DecodingResult res = DecodingResult.fullMessage(shuffleReadData);
+                headerInitialized = false;
+                body = null;
+                return res;
+            } else {
+                return DecodingResult.NOT_FINISHED;
+            }
+        } catch (Throwable t) {
+            throw new ReadingExceptionWithChannelID(shuffleReadData.getChannelID(), t);
+        }
+    }
+
+    @Override
+    public void close() {
+        if (isClosed) {
+            return;
+        }
+        if (headerByteBuf != null) {
+            headerByteBuf.release();
+        }
+        if (body != null) {
+            body.release();
+        }
+        isClosed = true;
+    }
+}
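A sketch of the two-phase accumulation both decoders rely on: bytes arrive in arbitrary slices, so the decoder first fills a fixed-size header and only then the body (illustrative only, not part of this patch; java.nio.ByteBuffer stands in for the Netty ByteBuf and DecodingUtil.accumulate):

import java.nio.ByteBuffer;

public final class AccumulateSketch {

    /** Copies bytes until the target holds targetSize bytes; true when the target is full. */
    static boolean accumulate(ByteBuffer target, ByteBuffer source, int targetSize) {
        while (target.position() < targetSize && source.hasRemaining()) {
            target.put(source.get());
        }
        return target.position() == targetSize;
    }

    public static void main(String[] args) {
        ByteBuffer header = ByteBuffer.allocate(4);
        ByteBuffer slice1 = ByteBuffer.wrap(new byte[] {0, 0});     // partial header
        ByteBuffer slice2 = ByteBuffer.wrap(new byte[] {0, 7, 42}); // rest of header + 1 body byte
        System.out.println(accumulate(header, slice1, 4)); // false -> NOT_FINISHED
        System.out.println(accumulate(header, slice2, 4)); // true -> header complete
    }
}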
diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ShuffleWriteClient.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ShuffleWriteClient.java
new file mode 100644
index 00000000..ed3d660e
--- /dev/null
+++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ShuffleWriteClient.java
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.transfer;
+
+import com.alibaba.flink.shuffle.common.utils.ExceptionUtils;
+import com.alibaba.flink.shuffle.core.ids.ChannelID;
+import com.alibaba.flink.shuffle.core.ids.DataSetID;
+import com.alibaba.flink.shuffle.core.ids.JobID;
+import com.alibaba.flink.shuffle.core.ids.MapPartitionID;
+
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.apache.flink.shaded.netty4.io.netty.channel.Channel;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.concurrent.GuardedBy;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.nio.channels.ClosedChannelException;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument;
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState;
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.randomBytes;
+import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.currentProtocolVersion;
+import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyExtraMessage;
+
+/**
+ * Writer client used to send buffers to a remote shuffle worker. It talks to a shuffle worker
+ * over a Netty connection in the language of {@link TransferMessage}. Flow control is guaranteed
+ * by a credit-based mechanism. The whole process of communication between {@link
+ * ShuffleWriteClient} and a shuffle worker is as follows:
+ *
+ * <ul>
+ *   <li>1. Client opens a connection and sends {@link TransferMessage.WriteHandshakeRequest};
+ *   <li>2. Client sends {@link TransferMessage.WriteRegionStart}, which announces the start of a
+ *       writing region and also indicates the number of buffers inside the region, called the
+ *       'backlog';
+ *   <li>3. Server sends {@link TransferMessage.WriteAddCredit} to announce how many more buffers
+ *       it can accept;
+ *   <li>4. Client sends {@link TransferMessage.WriteData} based on the server side 'credit';
+ *   <li>5. Client sends {@link TransferMessage.WriteRegionFinish} to indicate the writing finish
+ *       of a region;
+ *   <li>6. Repeat from step-2 to step-5;
+ *   <li>7. Client sends {@link TransferMessage.WriteFinish} to indicate that writing is finished;
+ *   <li>8. Server sends {@link TransferMessage.WriteFinishCommit} to confirm the writing finish;
+ *   <li>9. Client sends {@link TransferMessage.CloseChannel} to the server.
+ * </ul>
+ */
+public class ShuffleWriteClient {
+
+    private static final Logger LOG = LoggerFactory.getLogger(ShuffleWriteClient.class);
+
+    /** Address of the shuffle worker. */
+    private final InetSocketAddress address;
+
+    /** String representation of the remote shuffle address. */
+    private final String addressStr;
+
+    /** {@link MapPartitionID} of the writing. */
+    private final MapPartitionID mapID;
+
+    /** {@link JobID} of the writing. */
+    private final JobID jobID;
+
+    /** {@link DataSetID} of the writing. */
+    private final DataSetID dataSetID;
+
+    /** Number of subpartitions of the writing. */
+    private final int numSubs;
+
+    /** Defines the buffer size used by the client. */
+    private final int bufferSize;
+
+    /** Target data partition type to write. */
+    private final String dataPartitionFactoryName;
+
+    /** Used to set up and release connections. */
+    private final ConnectionManager connectionManager;
+
+    /** Lock to protect {@link #currentCredit}. */
+    private final Object lock = new Object();
+
+    /** Netty channel. */
+    private Channel nettyChannel;
+
+    /**
+     * Current view of the sum of credits received from the remote shuffle worker for the
+     * channel.
+     */
+    @GuardedBy("lock")
+    private int currentCredit;
+
+    /** Current writing region index, used for outdating credits. */
+    @GuardedBy("lock")
+    private int currentRegionIdx;
+
+    /** Identifier of the channel. */
+    private final ChannelID channelID;
+
+    /** String representation of the channelID. */
+    private final String channelIDStr;
+
+    /** {@link WriteClientHandler} backing this write client. */
+    private WriteClientHandler writeClientHandler;
+
+    /** Whether the task thread is waiting for more credits for sending. */
+    private volatile boolean isWaitingForCredit;
+
+    /** Whether the task thread is waiting for {@link TransferMessage.WriteFinishCommit}. */
+    private volatile boolean isWaitingForFinishCommit;
+
+    /** Whether {@link TransferMessage.WriteFinishCommit} is already received. */
+    private volatile boolean finishCommitted;
+
+    /** {@link Throwable} set on writing failure. */
+    private volatile Throwable cause;
+
+    /** Whether this client has ever been closed. */
+    private volatile boolean closed;
+
+    /** Callback invoked when writing to the channel fails. */
+    private final ChannelFutureListenerImpl channelFutureListener =
+            new ChannelFutureListenerImpl((channelFuture, cause) -> handleFailure(cause));
+
+    /**
+     * @param address Address of the shuffle worker.
+     * @param jobID {@link JobID} of the writing.
+     * @param dataSetID {@link DataSetID} of the writing.
+     * @param mapID {@link MapPartitionID} of the writing.
+     * @param numSubs Number of subpartitions of the writing.
+     * @param bufferSize Defines the buffer size used by the client.
+     * @param dataPartitionFactoryName Factory name of the target data partition type.
+     * @param connectionManager Manages physical connections.
+ */ + public ShuffleWriteClient( + InetSocketAddress address, + JobID jobID, + DataSetID dataSetID, + MapPartitionID mapID, + int numSubs, + int bufferSize, + String dataPartitionFactoryName, + ConnectionManager connectionManager) { + + checkArgument(address != null, "Must be not null."); + checkArgument(jobID != null, "Must be not null."); + checkArgument(dataSetID != null, "Must be not null."); + checkArgument(mapID != null, "Must be not null."); + checkArgument(numSubs > 0, "Must be positive value."); + checkArgument(bufferSize > 0, "Must be positive value."); + checkArgument(dataPartitionFactoryName != null, "Must be not null."); + checkArgument(connectionManager != null, "Must be not null."); + + this.address = address; + this.addressStr = address.toString(); + this.mapID = mapID; + this.jobID = jobID; + this.dataSetID = dataSetID; + this.numSubs = numSubs; + this.bufferSize = bufferSize; + this.dataPartitionFactoryName = dataPartitionFactoryName; + this.connectionManager = connectionManager; + this.channelID = new ChannelID(randomBytes(16)); + this.channelIDStr = channelID.toString(); + } + + /** Initialize Netty connection and fire handshake. */ + public void open() throws IOException, InterruptedException { + LOG.debug("(remote: {}, channel: {}) Connect channel.", address, channelIDStr); + nettyChannel = connectionManager.getChannel(channelID, address); + writeClientHandler = nettyChannel.pipeline().get(WriteClientHandler.class); + if (writeClientHandler == null) { + throw new IOException( + "The network connection is already released for channelID: " + channelIDStr); + } + writeClientHandler.register(this); + + TransferMessage.WriteHandshakeRequest msg = + new TransferMessage.WriteHandshakeRequest( + currentProtocolVersion(), + channelID, + jobID, + dataSetID, + mapID, + numSubs, + bufferSize, + dataPartitionFactoryName, + emptyExtraMessage()); + LOG.debug("(remote: {}, channel: {}) Send {}.", address, channelIDStr, msg); + writeAndFlush(msg); + } + + /** Writes a piece of data to a subpartition. */ + public void write(ByteBuf byteBuf, int subIdx) throws InterruptedException { + synchronized (lock) { + try { + healthCheck(); + checkState( + currentCredit >= 0, + () -> + "BUG: credit smaller than 0: " + + currentCredit + + ", channelID=" + + channelIDStr); + if (currentCredit == 0) { + isWaitingForCredit = true; + while (currentCredit == 0 && cause == null && !closed) { + lock.wait(); + } + isWaitingForCredit = false; + healthCheck(); + checkState( + currentCredit > 0, + () -> + "BUG: credit should be positive, but got " + + currentCredit + + ", channelID=" + + channelIDStr); + } + } catch (Throwable t) { + byteBuf.release(); + throw t; + } + + int size = byteBuf.readableBytes(); + TransferMessage.WriteData writeData = + new TransferMessage.WriteData( + currentProtocolVersion(), + channelID, + byteBuf, + subIdx, + size, + false, + emptyExtraMessage()); + LOG.trace("(remote: {}, channel: {}) Send {}.", address, channelIDStr, writeData); + writeAndFlush(writeData); + currentCredit--; + } + } + + /** + * Indicates the start of a region. A region of buffers guarantees the records inside are + * completed. + * + * @param isBroadcast Whether it's a broadcast region. 
+ */ + public void regionStart(boolean isBroadcast) { + synchronized (lock) { + healthCheck(); + TransferMessage.WriteRegionStart writeRegionStart = + new TransferMessage.WriteRegionStart( + currentProtocolVersion(), + channelID, + currentRegionIdx, + isBroadcast, + emptyExtraMessage()); + LOG.debug( + "(remote: {}, channel: {}) Send {}.", address, channelIDStr, writeRegionStart); + writeAndFlush(writeRegionStart); + } + } + + /** + * Indicates the finish of a region. A region is always bounded by a pair of region-start and + * region-finish. + */ + public void regionFinish() { + synchronized (lock) { + healthCheck(); + TransferMessage.WriteRegionFinish writeRegionFinish = + new TransferMessage.WriteRegionFinish( + currentProtocolVersion(), channelID, emptyExtraMessage()); + LOG.debug( + "(remote: {}, channel: {}) Region({}) finished, send {}.", + address, + channelIDStr, + currentRegionIdx, + writeRegionFinish); + currentRegionIdx++; + currentCredit = 0; + writeAndFlush(writeRegionFinish); + } + } + + /** Indicates the writing is finished. */ + public void finish() throws InterruptedException { + synchronized (lock) { + healthCheck(); + TransferMessage.WriteFinish writeFinish = + new TransferMessage.WriteFinish( + currentProtocolVersion(), channelID, emptyExtraMessage()); + LOG.debug("(remote: {}, channel: {}) Send {}.", address, channelIDStr, writeFinish); + writeAndFlush(writeFinish); + if (!finishCommitted) { + isWaitingForFinishCommit = true; + while (!finishCommitted && cause == null && !closed) { + lock.wait(); + } + isWaitingForFinishCommit = false; + healthCheck(); + checkState(finishCommitted, "finishCommitted should be true."); + } + } + } + + /** Closes Netty connection. */ + public void close() throws IOException { + synchronized (lock) { + closed = true; + lock.notifyAll(); + } + + LOG.debug("(remote: {}) Close for (dataSetID: {}, mapID: {}).", address, dataSetID, mapID); + if (writeClientHandler != null) { + writeClientHandler.unregister(this); + } + + if (nettyChannel != null) { + connectionManager.releaseChannel(address, channelID); + } + } + + /** Whether task thread is waiting for more credits for sending. */ + public boolean isWaitingForCredit() { + return isWaitingForCredit; + } + + /** Whether task thread is waiting for {@link TransferMessage.WriteFinishCommit}. */ + public boolean isWaitingForFinishCommit() { + return isWaitingForFinishCommit; + } + + /** Identifier of the channel. */ + public ChannelID getChannelID() { + return channelID; + } + + /** Get {@link Throwable} when writing failure. */ + public Throwable getCause() { + return cause; + } + + /** Called by Netty thread. */ + public void writeFinishCommitReceived(TransferMessage.WriteFinishCommit commit) { + LOG.debug("(remote: {}, channel: {}) Received {}.", address, channelIDStr, commit); + synchronized (lock) { + finishCommitted = true; + if (isWaitingForFinishCommit) { + lock.notifyAll(); + } + } + } + + /** Called by Netty thread. */ + public void creditReceived(TransferMessage.WriteAddCredit addCredit) { + LOG.trace("(remote: {}, channel: {}) Received {}.", address, channelIDStr, addCredit); + synchronized (lock) { + if (addCredit.getCredit() > 0 && addCredit.getRegionIdx() == currentRegionIdx) { + currentCredit += addCredit.getCredit(); + if (isWaitingForCredit) { + lock.notifyAll(); + } + } + } + } + + /** Called by Netty thread. */ + public void channelInactive() { + synchronized (lock) { + if (!closed) { + handleFailure(new ClosedChannelException()); + } + } + } + + /** Called by Netty thread. 
+     */
+    public void exceptionCaught(Throwable t) {
+        synchronized (lock) {
+            handleFailure(t);
+        }
+    }
+
+    private void healthCheck() {
+        if (cause != null) {
+            ExceptionUtils.rethrowAsRuntimeException(cause);
+        }
+        if (closed) {
+            throw new IllegalStateException("Write client is already cancelled/closed.");
+        }
+    }
+
+    private void writeAndFlush(Object obj) {
+        nettyChannel.writeAndFlush(obj).addListener(channelFutureListener);
+    }
+
+    private void handleFailure(Throwable t) {
+        synchronized (lock) {
+            if (cause != null) {
+                return;
+            }
+            if (t != null) {
+                cause =
+                        new IOException(
+                                "Shuffle failure on connection to "
+                                        + address
+                                        + " for channel of "
+                                        + channelIDStr,
+                                t);
+            } else {
+                cause =
+                        new Exception(
+                                "Shuffle failure on connection to "
+                                        + address
+                                        + " for channel of "
+                                        + channelIDStr);
+            }
+            LOG.error("(remote: {}, channel: {}) Shuffle failure.", address, channelIDStr, cause);
+            lock.notifyAll();
+        }
+    }
+}
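A sketch of the blocking credit gate inside ShuffleWriteClient.write() above: the task thread parks on the lock whenever credit runs out, and the Netty thread tops credit up in creditReceived() and wakes it (illustrative only, not part of this patch; thread names and counts are invented):

public final class CreditGateSketch {

    private final Object lock = new Object();
    private int credit;

    void write(int record) throws InterruptedException {
        synchronized (lock) {
            while (credit == 0) {
                lock.wait(); // task thread blocks until credit arrives
            }
            credit--;
            System.out.println("send " + record + ", credit left=" + credit);
        }
    }

    void creditReceived(int more) {
        synchronized (lock) {
            credit += more;
            lock.notifyAll(); // wake the blocked task thread
        }
    }

    public static void main(String[] args) throws Exception {
        CreditGateSketch gate = new CreditGateSketch();
        Thread nettyThread = new Thread(() -> gate.creditReceived(3));
        nettyThread.start();
        for (int i = 0; i < 3; i++) {
            gate.write(i);
        }
        nettyThread.join();
    }
}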
diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ShuffleWriteDataDecoder.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ShuffleWriteDataDecoder.java
new file mode 100644
index 00000000..bac1028f
--- /dev/null
+++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/ShuffleWriteDataDecoder.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.transfer;
+
+import com.alibaba.flink.shuffle.common.utils.CommonUtils;
+import com.alibaba.flink.shuffle.core.ids.ChannelID;
+import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteData;
+
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
+
+import java.util.function.Function;
+import java.util.function.Supplier;
+
+/** {@link TransferMessageDecoder} for {@link TransferMessage.WriteData}. */
+public class ShuffleWriteDataDecoder extends TransferMessageDecoder {
+
+    private ByteBuf headerByteBuf;
+
+    private boolean headerInitialized;
+
+    private final Function<ChannelID, Supplier<ByteBuf>> bufferSuppliers;
+
+    private ByteBuf body;
+
+    private WriteData shuffleWriteData;
+
+    /** @param bufferSuppliers Supplies buffers to accommodate the received network buffers. */
+    public ShuffleWriteDataDecoder(Function<ChannelID, Supplier<ByteBuf>> bufferSuppliers) {
+        this.bufferSuppliers = bufferSuppliers;
+        this.isClosed = false;
+    }
+
+    @Override
+    public void onNewMessageReceived(ChannelHandlerContext ctx, int msgId, int messageLength) {
+        super.onNewMessageReceived(ctx, msgId, messageLength);
+        headerByteBuf = ctx.alloc().directBuffer(messageLength);
+    }
+
+    @Override
+    public DecodingResult onChannelRead(ByteBuf byteBuf) {
+        CommonUtils.checkState(!isClosed, "Decoder has been closed.");
+
+        if (!headerInitialized) {
+            boolean accumulationFinished =
+                    DecodingUtil.accumulate(
+                            headerByteBuf, byteBuf, messageLength, headerByteBuf.readableBytes());
+            if (!accumulationFinished) {
+                return DecodingResult.NOT_FINISHED;
+            }
+            shuffleWriteData = WriteData.initByHeader(headerByteBuf);
+            headerInitialized = true;
+        }
+
+        try {
+            if (body == null) {
+                body = bufferSuppliers.apply(shuffleWriteData.getChannelID()).get();
+            }
+
+            if (body.capacity() < shuffleWriteData.getBufferSize()) {
+                throw new IllegalArgumentException(
+                        String.format(
+                                "Buffer size of write data (%d) is bigger than what can be accepted (%d)",
+                                shuffleWriteData.getBufferSize(), body.capacity()));
+            }
+
+            boolean accumulationFinished =
+                    DecodingUtil.accumulate(
+                            body, byteBuf, shuffleWriteData.getBufferSize(), body.readableBytes());
+
+            if (accumulationFinished) {
+                shuffleWriteData.setBuffer(body);
+                DecodingResult res = DecodingResult.fullMessage(shuffleWriteData);
+                headerInitialized = false;
+                body = null;
+                return res;
+            } else {
+                return DecodingResult.NOT_FINISHED;
+            }
+        } catch (Throwable t) {
+            throw new WritingExceptionWithChannelID(shuffleWriteData.getChannelID(), t);
+        }
+    }
+
+    @Override
+    public void close() {
+        if (isClosed) {
+            return;
+        }
+        if (headerByteBuf != null) {
+            headerByteBuf.release();
+        }
+        if (body != null) {
+            body.release();
+        }
+        isClosed = true;
+    }
+}
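A sketch of the shape of the Function<ChannelID, Supplier<ByteBuf>> argument both decoders take: each logical channel resolves to a supplier of receive buffers (illustrative only, not part of this patch; plain io.netty Unpooled and a String channel id are used for brevity, while the patch itself uses the shaded Netty and TransferBufferPool):

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

import java.util.function.Function;
import java.util.function.Supplier;

public final class BufferSupplierSketch {
    public static void main(String[] args) {
        Function<String, Supplier<ByteBuf>> bufferSuppliers =
                channelID -> () -> Unpooled.buffer(32 * 1024); // one 32 KiB buffer per request
        ByteBuf body = bufferSuppliers.apply("channel-1").get();
        System.out.println("capacity=" + body.capacity());
        body.release();
    }
}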
diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/TransferBufferPool.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/TransferBufferPool.java
new file mode 100644
index 00000000..6be3fb4f
--- /dev/null
+++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/TransferBufferPool.java
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.transfer;
+
+import com.alibaba.flink.shuffle.common.utils.CommonUtils;
+import com.alibaba.flink.shuffle.core.memory.Buffer;
+import com.alibaba.flink.shuffle.core.memory.BufferRecycler;
+
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+
+import javax.annotation.concurrent.GuardedBy;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Queue;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState;
+
+/** A buffer pool which dispatches buffers to all registered {@link CreditListener}s. */
+public class TransferBufferPool implements BufferRecycler {
+
+    private static final int MIN_CREDITS_TO_NOTIFY = 2;
+
+    private final Object lock = new Object();
+
+    @GuardedBy("lock")
+    private final Queue<ByteBuf> buffers = new ArrayDeque<>();
+
+    @GuardedBy("lock")
+    private final Queue<CreditListener> listeners = new ArrayDeque<>();
+
+    @GuardedBy("lock")
+    private int numAvailableBuffers;
+
+    @GuardedBy("lock")
+    private boolean isDestroyed;
+
+    public TransferBufferPool(Collection<ByteBuf> initialBuffers) {
+        synchronized (lock) {
+            buffers.addAll(initialBuffers);
+            numAvailableBuffers += initialBuffers.size();
+        }
+    }
+
+    /** Requests a data transmission unit. */
+    public ByteBuf requestBuffer() {
+        synchronized (lock) {
+            checkState(!isDestroyed, "Buffer pool has been destroyed.");
+
+            return buffers.poll();
+        }
+    }
+
+    /** Adds the given available buffers to this buffer pool. */
+    public void addBuffers(List<ByteBuf> byteBufs) {
+        List<CreditAssignment> creditAssignments;
+        synchronized (lock) {
+            if (isDestroyed) {
+                byteBufs.forEach(ByteBuf::release);
+                return;
+            }
+
+            buffers.addAll(byteBufs);
+            numAvailableBuffers += byteBufs.size();
+            creditAssignments = dispatchReservedCredits();
+        }
+        for (CreditAssignment creditAssignment : creditAssignments) {
+            creditAssignment
+                    .getCreditListener()
+                    .notifyAvailableCredits(creditAssignment.getNumCredits());
+        }
+    }
+
+    /** Tries to reserve buffers for the target {@link CreditListener}. */
+    public void reserveBuffers(CreditListener creditListener, int numRequiredBuffers) {
+        int numCredits;
+        CreditListener listener = null;
+        synchronized (lock) {
+            if (isDestroyed) {
+                throw new IllegalStateException("Buffer pool has been destroyed.");
+            }
+
+            if (numRequiredBuffers > numAvailableBuffers) {
+                creditListener.increaseNumCreditsNeeded(numRequiredBuffers - numAvailableBuffers);
+            }
+
+            if (!creditListener.isRegistered() && creditListener.getNumCreditsNeeded() > 0) {
+                listeners.add(creditListener);
+                creditListener.setRegistered(true);
+            }
+
+            numCredits = Math.min(numAvailableBuffers, numRequiredBuffers);
+            if (numCredits > 0) {
+                numAvailableBuffers -= numCredits;
+                listener = creditListener;
+            }
+        }
+        if (listener != null) {
+            listener.notifyAvailableCredits(numCredits);
+        }
+    }
+
+    /** Returns the number of available buffers. */
+    public int numBuffers() {
+        synchronized (lock) {
+            return buffers.size();
+        }
+    }
+
+    /** Destroys this buffer pool. */
+    public void destroy() {
+        synchronized (lock) {
+            isDestroyed = true;
+            listeners.clear();
+            buffers.forEach(ByteBuf::release);
+            buffers.clear();
+        }
+    }
+
+    /** Returns true if this buffer pool has been destroyed.
+     */
+    public boolean isDestroyed() {
+        synchronized (lock) {
+            return isDestroyed;
+        }
+    }
+
+    @Override
+    public void recycle(ByteBuffer buffer) {
+        List<CreditAssignment> creditAssignments;
+        synchronized (lock) {
+            // Unmanaged memory does not need to be recycled; currently this path is used only
+            // by tests.
+            if (isDestroyed) {
+                return;
+            }
+
+            buffers.add(new Buffer(buffer, this, 0));
+            ++numAvailableBuffers;
+            creditAssignments = dispatchReservedCredits();
+        }
+        for (CreditAssignment creditAssignment : creditAssignments) {
+            creditAssignment
+                    .getCreditListener()
+                    .notifyAvailableCredits(creditAssignment.getNumCredits());
+        }
+    }
+
+    private int assignCredits(CreditListener creditListener) {
+        assert Thread.holdsLock(lock);
+
+        if (creditListener == null) {
+            return 0;
+        }
+
+        int numCredits = Math.min(creditListener.getNumCreditsNeeded(), numAvailableBuffers);
+        if (numCredits > 0) {
+            creditListener.decreaseNumCreditsNeeded(numCredits);
+            numAvailableBuffers -= numCredits;
+        }
+
+        if (creditListener.getNumCreditsNeeded() > 0) {
+            listeners.add(creditListener);
+        } else {
+            creditListener.setRegistered(false);
+        }
+        return numCredits;
+    }
+
+    private List<CreditAssignment> dispatchReservedCredits() {
+        assert Thread.holdsLock(lock);
+
+        if (numAvailableBuffers < MIN_CREDITS_TO_NOTIFY || listeners.isEmpty()) {
+            return Collections.emptyList();
+        }
+
+        List<CreditAssignment> creditAssignments = new ArrayList<>();
+        while (numAvailableBuffers > 0 && !listeners.isEmpty()) {
+            CreditListener creditListener = listeners.poll();
+            int numCredits = assignCredits(creditListener);
+            if (numCredits > 0) {
+                creditAssignments.add(new CreditAssignment(numCredits, creditListener));
+            }
+        }
+        return creditAssignments;
+    }
+
+    private static class CreditAssignment {
+
+        private final int numCredits;
+        private final CreditListener creditListener;
+
+        CreditAssignment(int numCredits, CreditListener creditListener) {
+            CommonUtils.checkArgument(numCredits > 0, "Must be positive.");
+            CommonUtils.checkArgument(creditListener != null, "Must be not null.");
+
+            this.numCredits = numCredits;
+            this.creditListener = creditListener;
+        }
+
+        public int getNumCredits() {
+            return numCredits;
+        }
+
+        public CreditListener getCreditListener() {
+            return creditListener;
+        }
+    }
+}
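A sketch of the reservation contract of TransferBufferPool.reserveBuffers above: whatever is available is granted immediately, and the shortfall is remembered per listener and granted as buffers are recycled back (illustrative only, not part of this patch; plain counters replace the listener queue for brevity):

public final class ReserveSketch {

    private int available = 3;
    private int pendingNeed;

    void reserve(int required) {
        int granted = Math.min(required, available);
        available -= granted;
        pendingNeed += required - granted;
        System.out.println("granted now: " + granted + ", still owed: " + pendingNeed);
    }

    void recycle(int returned) {
        available += returned;
        int granted = Math.min(pendingNeed, available);
        pendingNeed -= granted;
        available -= granted;
        if (granted > 0) {
            System.out.println("granted later: " + granted);
        }
    }

    public static void main(String[] args) {
        ReserveSketch pool = new ReserveSketch();
        pool.reserve(5); // granted now: 3, still owed: 2
        pool.recycle(2); // granted later: 2
    }
}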
diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/TransferMessage.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/TransferMessage.java
new file mode 100644
index 00000000..86c0d57d
--- /dev/null
+++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/TransferMessage.java
@@ -0,0 +1,1398 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.transfer;
+
+import com.alibaba.flink.shuffle.core.ids.ChannelID;
+import com.alibaba.flink.shuffle.core.ids.DataSetID;
+import com.alibaba.flink.shuffle.core.ids.JobID;
+import com.alibaba.flink.shuffle.core.ids.MapPartitionID;
+
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBufAllocator;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelOutboundInvoker;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelPromise;
+
+import java.util.Arrays;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument;
+import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.currentProtocolVersion;
+import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyExtraMessageBytes;
+import static com.alibaba.flink.shuffle.common.utils.StringUtils.bytesToString;
+import static com.alibaba.flink.shuffle.common.utils.StringUtils.stringToBytes;
+
+/**
+ * Communication protocol between a shuffle worker and a client. Each extension defines its
+ * specific fields and how they are serialized and deserialized. A message is wrapped in a
+ * 'frame' whose header defines the frame length, a magic number and the message type.
+ */
+public abstract class TransferMessage {
+
+    /** Length of the frame header -- frame length (4) + magic number (4) + msg ID (1). */
+    public static final int FRAME_HEADER_LENGTH = 4 + 4 + 1;
+
+    public static final int MAGIC_NUMBER = 0xBADC0FEF;
+
+    private static ByteBuf allocateBuffer(
+            ByteBufAllocator allocator, byte messageID, int contentLength) {
+
+        checkArgument(contentLength <= Integer.MAX_VALUE - FRAME_HEADER_LENGTH);
+
+        ByteBuf buffer = allocator.directBuffer(FRAME_HEADER_LENGTH + contentLength);
+        buffer.writeInt(FRAME_HEADER_LENGTH + contentLength);
+        buffer.writeInt(MAGIC_NUMBER);
+        buffer.writeByte(messageID);
+
+        return buffer;
+    }
+
+    /** Content length of the frame, excluding the header. */
+    public abstract int getContentLength();
+
+    /** Defines how the message is serialized. */
+    public abstract void write(
+            ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator);
+
+    /** Feedback sent when an error occurs on the server.
+     */
+    public static class ErrorResponse extends TransferMessage {
+
+        public static final byte ID = -1;
+
+        private final int version;
+
+        private final ChannelID channelID;
+
+        private final byte[] errorMessageBytes;
+
+        private final byte[] extraInfo;
+
+        public ErrorResponse(
+                int version, ChannelID channelID, byte[] errorMessageBytes, String extraInfo) {
+            this.version = version;
+            this.channelID = channelID;
+            this.errorMessageBytes = errorMessageBytes;
+            this.extraInfo = stringToBytes(extraInfo);
+        }
+
+        public static ErrorResponse readFrom(ByteBuf byteBuf) {
+            int version = byteBuf.readInt();
+            ChannelID channelID = ChannelID.readFrom(byteBuf);
+            int errorMessageLength = byteBuf.readInt();
+            byte[] errorMessageBytes = new byte[errorMessageLength];
+            byteBuf.readBytes(errorMessageBytes);
+            int extraInfoLen = byteBuf.readInt();
+            byte[] extraInfoBytes = new byte[extraInfoLen];
+            byteBuf.readBytes(extraInfoBytes);
+            String extraInfo = bytesToString(extraInfoBytes);
+            return new ErrorResponse(version, channelID, errorMessageBytes, extraInfo);
+        }
+
+        @Override
+        public int getContentLength() {
+            return 4
+                    + channelID.getFootprint()
+                    + 4
+                    + errorMessageBytes.length
+                    + 4
+                    + extraInfo.length;
+        }
+
+        @Override
+        public void write(
+                ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator) {
+
+            ByteBuf byteBuf = allocateBuffer(allocator, ID, getContentLength());
+            byteBuf.writeInt(version);
+            channelID.writeTo(byteBuf);
+            byteBuf.writeInt(errorMessageBytes.length);
+            byteBuf.writeBytes(errorMessageBytes);
+            byteBuf.writeInt(extraInfo.length);
+            byteBuf.writeBytes(extraInfo);
+            out.write(byteBuf, promise);
+        }
+
+        public int getVersion() {
+            return version;
+        }
+
+        public ChannelID getChannelID() {
+            return channelID;
+        }
+
+        public byte[] getErrorMessageBytes() {
+            return errorMessageBytes;
+        }
+
+        public String getExtraInfo() {
+            return bytesToString(extraInfo);
+        }
+
+        @Override
+        public String toString() {
+            return String.format(
+                    "ErrorResponse{%d, %s, msg=%s, extraInfo=%s}",
+                    version, channelID, new String(errorMessageBytes), bytesToString(extraInfo));
+        }
+    }
+
+    /** Handshake message sent by the client when it starts a shuffle write.
*/ + public static class WriteHandshakeRequest extends TransferMessage { + + public static final byte ID = 0; + + private final int version; + + private final ChannelID channelID; + + private final JobID jobID; + + private final DataSetID dataSetID; + + private final MapPartitionID mapID; + + private final int numSubs; + + private final int bufferSize; + + // Specify the factory name of data partition type + private final byte[] dataPartitionType; + + private final byte[] extraInfo; + + public WriteHandshakeRequest( + int version, + ChannelID channelID, + JobID jobID, + DataSetID dataSetID, + MapPartitionID mapPartitionID, + int numSubs, + int bufferSize, + String dataPartitionType, + String extraInfo) { + + this.version = version; + this.channelID = channelID; + this.jobID = jobID; + this.dataSetID = dataSetID; + this.mapID = mapPartitionID; + this.numSubs = numSubs; + this.bufferSize = bufferSize; + this.dataPartitionType = stringToBytes(dataPartitionType); + this.extraInfo = stringToBytes(extraInfo); + } + + public static WriteHandshakeRequest readFrom(ByteBuf byteBuf) { + int version = byteBuf.readInt(); + ChannelID channelID = ChannelID.readFrom(byteBuf); + JobID jobID = JobID.readFrom(byteBuf); + DataSetID dataSetID = DataSetID.readFrom(byteBuf); + MapPartitionID mapID = MapPartitionID.readFrom(byteBuf); + int numSubs = byteBuf.readInt(); + int bufferSize = byteBuf.readInt(); + int partitionTypeLen = byteBuf.readInt(); + byte[] dataPartitionTypeBytes = new byte[partitionTypeLen]; + byteBuf.readBytes(dataPartitionTypeBytes); + String dataPartitionType = bytesToString(dataPartitionTypeBytes); + int extraInfoLen = byteBuf.readInt(); + byte[] extraInfoBytes = new byte[extraInfoLen]; + byteBuf.readBytes(extraInfoBytes); + String extraInfo = bytesToString(extraInfoBytes); + return new WriteHandshakeRequest( + version, + channelID, + jobID, + dataSetID, + mapID, + numSubs, + bufferSize, + dataPartitionType, + extraInfo); + } + + @Override + public int getContentLength() { + return 4 + + channelID.getFootprint() + + jobID.getFootprint() + + dataSetID.getFootprint() + + mapID.getFootprint() + + 4 + + 4 + + 4 + + dataPartitionType.length + + 4 + + extraInfo.length; + } + + @Override + public void write( + ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator) { + + ByteBuf byteBuf = allocateBuffer(allocator, ID, getContentLength()); + byteBuf.writeInt(version); + channelID.writeTo(byteBuf); + jobID.writeTo(byteBuf); + dataSetID.writeTo(byteBuf); + mapID.writeTo(byteBuf); + byteBuf.writeInt(numSubs); + byteBuf.writeInt(bufferSize); + byteBuf.writeInt(dataPartitionType.length); + byteBuf.writeBytes(dataPartitionType); + byteBuf.writeInt(extraInfo.length); + byteBuf.writeBytes(extraInfo); + out.write(byteBuf, promise); + } + + public int getVersion() { + return version; + } + + public ChannelID getChannelID() { + return channelID; + } + + public JobID getJobID() { + return jobID; + } + + public DataSetID getDataSetID() { + return dataSetID; + } + + public MapPartitionID getMapID() { + return mapID; + } + + public int getNumSubs() { + return numSubs; + } + + public int getBufferSize() { + return bufferSize; + } + + public String getDataPartitionType() { + return bytesToString(dataPartitionType); + } + + public String getExtraInfo() { + return bytesToString(extraInfo); + } + + @Override + public String toString() { + return String.format( + "WriteHandshakeRequest{%d, %s, %s, %s, %s, numSubs=%d, " + + "bufferSize=%d, dataPartitionType=%s, extraInfo=%s}", + version, + 
channelID,
+                    jobID,
+                    dataSetID,
+                    mapID,
+                    numSubs,
+                    bufferSize,
+                    bytesToString(dataPartitionType),
+                    bytesToString(extraInfo));
+        }
+    }
+
+    /** Message sent by the server to announce more credits so that more shuffle write data can be accepted. */
+    public static class WriteAddCredit extends TransferMessage {
+
+        public static final byte ID = 1;
+
+        private final int version;
+
+        private final ChannelID channelID;
+
+        private final int credit;
+
+        private final int regionIdx;
+
+        private final byte[] extraInfo;
+
+        public WriteAddCredit(
+                int version, ChannelID channelID, int credit, int regionIdx, String extraInfo) {
+            this.version = version;
+            this.channelID = channelID;
+            this.credit = credit;
+            this.regionIdx = regionIdx;
+            this.extraInfo = stringToBytes(extraInfo);
+        }
+
+        public static WriteAddCredit readFrom(ByteBuf byteBuf) {
+            int version = byteBuf.readInt();
+            ChannelID channelID = ChannelID.readFrom(byteBuf);
+            int credit = byteBuf.readInt();
+            int regionIdx = byteBuf.readInt();
+            int extraInfoLen = byteBuf.readInt();
+            byte[] extraInfoBytes = new byte[extraInfoLen];
+            byteBuf.readBytes(extraInfoBytes);
+            String extraInfo = bytesToString(extraInfoBytes);
+            return new WriteAddCredit(version, channelID, credit, regionIdx, extraInfo);
+        }
+
+        @Override
+        public int getContentLength() {
+            return 4 + channelID.getFootprint() + 4 + 4 + 4 + extraInfo.length;
+        }
+
+        @Override
+        public void write(
+                ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator) {
+
+            ByteBuf byteBuf = allocateBuffer(allocator, ID, getContentLength());
+            byteBuf.writeInt(version);
+            channelID.writeTo(byteBuf);
+            byteBuf.writeInt(credit);
+            byteBuf.writeInt(regionIdx);
+            byteBuf.writeInt(extraInfo.length);
+            byteBuf.writeBytes(extraInfo);
+            out.write(byteBuf, promise);
+        }
+
+        public int getVersion() {
+            return version;
+        }
+
+        public ChannelID getChannelID() {
+            return channelID;
+        }
+
+        public int getCredit() {
+            return credit;
+        }
+
+        public int getRegionIdx() {
+            return regionIdx;
+        }
+
+        public String getExtraInfo() {
+            return bytesToString(extraInfo);
+        }
+
+        @Override
+        public String toString() {
+            return String.format(
+                    "WriteAddCredit{%d, %s, credit=%d, regionIdx=%d, extraInfo=%s}",
+                    version, channelID, credit, regionIdx, bytesToString(extraInfo));
+        }
+    }
+
+    /** Shuffle write data sent by the client.
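+     *
+     * <p>The message is flushed as two buffers: a header carrying {@code version, channelID,
+     * subIdx, bufferSize, isRegionFinish, extraInfo}, followed by the data {@link ByteBuf}
+     * itself (see {@link WriteData#write}). {@code initByHeader} therefore returns an instance
+     * whose body is {@code null} until the data buffer arrives and is attached via
+     * {@code setBuffer}.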
*/
+    public static class WriteData extends TransferMessage {
+
+        public static final byte ID = 2;
+
+        private final int version;
+
+        private final ChannelID channelID;
+
+        private final int subIdx;
+
+        private final int bufferSize;
+
+        private final boolean isRegionFinish;
+
+        private ByteBuf body;
+
+        private final byte[] extraInfo;
+
+        public WriteData(
+                int version,
+                ChannelID channelID,
+                ByteBuf body,
+                int subIdx,
+                int bufferSize,
+                boolean isRegionFinish,
+                String extraInfo) {
+
+            this.version = version;
+            this.channelID = channelID;
+            this.body = body;
+            this.subIdx = subIdx;
+            this.bufferSize = bufferSize;
+            this.isRegionFinish = isRegionFinish;
+            this.extraInfo = stringToBytes(extraInfo);
+        }
+
+        public static WriteData initByHeader(ByteBuf byteBuf) {
+            int version = byteBuf.readInt();
+            ChannelID channelID = ChannelID.readFrom(byteBuf);
+            int subIdx = byteBuf.readInt();
+            int bufferSize = byteBuf.readInt();
+            boolean isRegionFinish = byteBuf.readBoolean();
+            int extraInfoLen = byteBuf.readInt();
+            byte[] extraInfoBytes = new byte[extraInfoLen];
+            byteBuf.readBytes(extraInfoBytes);
+            String extraInfo = bytesToString(extraInfoBytes);
+            return new WriteData(
+                    version, channelID, null, subIdx, bufferSize, isRegionFinish, extraInfo);
+        }
+
+        @Override
+        public int getContentLength() {
+            throw new UnsupportedOperationException("Unsupported and should not be called.");
+        }
+
+        private int footprintExceptData() {
+            return 4 + channelID.getFootprint() + 4 + 4 + 1 + 4 + extraInfo.length;
+        }
+
+        @Override
+        public void write(
+                ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator) {
+
+            ByteBuf byteBuf = allocateBuffer(allocator, ID, footprintExceptData());
+            byteBuf.writeInt(version);
+            channelID.writeTo(byteBuf);
+            byteBuf.writeInt(subIdx);
+            byteBuf.writeInt(bufferSize);
+            // Write the actual flag; writing a constant false here would silently drop the
+            // region-finish marker that initByHeader reads back on the receiving side.
+            byteBuf.writeBoolean(isRegionFinish);
+            byteBuf.writeInt(extraInfo.length);
+            byteBuf.writeBytes(extraInfo);
+            out.write(byteBuf);
+            out.write(body, promise);
+        }
+
+        public int getVersion() {
+            return version;
+        }
+
+        public ChannelID getChannelID() {
+            return channelID;
+        }
+
+        public int getSubIdx() {
+            return subIdx;
+        }
+
+        public ByteBuf getBuffer() {
+            return body;
+        }
+
+        public void setBuffer(ByteBuf body) {
+            this.body = body;
+        }
+
+        public int getBufferSize() {
+            return bufferSize;
+        }
+
+        public String getExtraInfo() {
+            return bytesToString(extraInfo);
+        }
+
+        public boolean isRegionFinish() {
+            return isRegionFinish;
+        }
+
+        @Override
+        public String toString() {
+            return String.format(
+                    "WriteData{%d, %s, bufferSize=%d, subIdx=%d, isRegionFinish=%s, extraInfo=%s}",
+                    version,
+                    channelID,
+                    bufferSize,
+                    subIdx,
+                    isRegionFinish,
+                    bytesToString(extraInfo));
+        }
+    }
+
+    /** Message to indicate the start of a region, which guarantees records inside are complete.
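+     *
+     * <p>Taken together with the other messages in this file, a single region write is roughly
+     * the following exchange (a sketch, not code from this change; {@code channel} and
+     * {@code extra} are placeholders):
+     *
+     * <pre>{@code
+     * channel.writeAndFlush(new WriteRegionStart(version, channelID, regionIdx, false, extra));
+     * // one or more WriteData messages, paced by WriteAddCredit grants from the server
+     * channel.writeAndFlush(new WriteRegionFinish(version, channelID, extra));
+     * }</pre>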
*/ + public static class WriteRegionStart extends TransferMessage { + + public static final byte ID = 3; + + private final int version; + + private final ChannelID channelID; + + private final int regionIdx; + + private final boolean isBroadcast; + + private final byte[] extraInfo; + + public WriteRegionStart( + int version, + ChannelID channelID, + int regionIdx, + boolean isBroadcast, + String extraInfo) { + this.version = version; + this.channelID = channelID; + this.regionIdx = regionIdx; + this.isBroadcast = isBroadcast; + this.extraInfo = stringToBytes(extraInfo); + } + + public static WriteRegionStart readFrom(ByteBuf byteBuf) { + int version = byteBuf.readInt(); + ChannelID channelID = ChannelID.readFrom(byteBuf); + int regionIdx = byteBuf.readInt(); + boolean isBroadcast = byteBuf.readBoolean(); + int extraInfoLen = byteBuf.readInt(); + byte[] extraInfoBytes = new byte[extraInfoLen]; + byteBuf.readBytes(extraInfoBytes); + String extraInfo = bytesToString(extraInfoBytes); + return new WriteRegionStart(version, channelID, regionIdx, isBroadcast, extraInfo); + } + + @Override + public int getContentLength() { + return 4 + channelID.getFootprint() + 4 + 1 + 4 + extraInfo.length; + } + + @Override + public void write( + ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator) { + + ByteBuf byteBuf = allocateBuffer(allocator, ID, getContentLength()); + byteBuf.writeInt(version); + channelID.writeTo(byteBuf); + byteBuf.writeInt(regionIdx); + byteBuf.writeBoolean(isBroadcast); + byteBuf.writeInt(extraInfo.length); + byteBuf.writeBytes(extraInfo); + out.write(byteBuf, promise); + } + + public int getVersion() { + return version; + } + + public ChannelID getChannelID() { + return channelID; + } + + public int getRegionIdx() { + return regionIdx; + } + + public boolean isBroadcast() { + return isBroadcast; + } + + public String getExtraInfo() { + return bytesToString(extraInfo); + } + + @Override + public String toString() { + return String.format( + "WriteRegionStart{%d, %s, regionIdx=%d, broadcast=%b, extraInfo=%s}", + version, channelID, regionIdx, isBroadcast, bytesToString(extraInfo)); + } + } + + /** Message to indicate the finish of a region, which guarantees records inside are complete. 
*/
+    public static class WriteRegionFinish extends TransferMessage {
+
+        public static final byte ID = 4;
+
+        private final int version;
+
+        private final ChannelID channelID;
+
+        private final byte[] extraInfo;
+
+        public WriteRegionFinish(int version, ChannelID channelID, String extraInfo) {
+            this.version = version;
+            this.channelID = channelID;
+            this.extraInfo = stringToBytes(extraInfo);
+        }
+
+        public static WriteRegionFinish readFrom(ByteBuf byteBuf) {
+            int version = byteBuf.readInt();
+            ChannelID channelID = ChannelID.readFrom(byteBuf);
+            int extraInfoLen = byteBuf.readInt();
+            byte[] extraInfoBytes = new byte[extraInfoLen];
+            byteBuf.readBytes(extraInfoBytes);
+            String extraInfo = bytesToString(extraInfoBytes);
+            return new WriteRegionFinish(version, channelID, extraInfo);
+        }
+
+        @Override
+        public int getContentLength() {
+            return 4 + channelID.getFootprint() + 4 + extraInfo.length;
+        }
+
+        @Override
+        public void write(
+                ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator) {
+
+            ByteBuf byteBuf = allocateBuffer(allocator, ID, getContentLength());
+            byteBuf.writeInt(version);
+            channelID.writeTo(byteBuf);
+            byteBuf.writeInt(extraInfo.length);
+            byteBuf.writeBytes(extraInfo);
+            out.write(byteBuf, promise);
+        }
+
+        public int getVersion() {
+            return version;
+        }
+
+        public ChannelID getChannelID() {
+            return channelID;
+        }
+
+        public String getExtraInfo() {
+            return bytesToString(extraInfo);
+        }
+
+        @Override
+        public String toString() {
+            return String.format(
+                    "WriteRegionFinish{%d, %s, extraInfo=%s}",
+                    version, channelID, bytesToString(extraInfo));
+        }
+    }
+
+    /** Message indicating the finish of a shuffle write. */
+    public static class WriteFinish extends TransferMessage {
+
+        public static final byte ID = 5;
+
+        private final int version;
+
+        private final ChannelID channelID;
+
+        private final byte[] extraInfo;
+
+        public WriteFinish(int version, ChannelID channelID, String extraInfo) {
+            this.version = version;
+            this.channelID = channelID;
+            this.extraInfo = stringToBytes(extraInfo);
+        }
+
+        public static WriteFinish readFrom(ByteBuf byteBuf) {
+            int version = byteBuf.readInt();
+            ChannelID channelID = ChannelID.readFrom(byteBuf);
+            int extraInfoLen = byteBuf.readInt();
+            byte[] extraInfoBytes = new byte[extraInfoLen];
+            byteBuf.readBytes(extraInfoBytes);
+            String extraInfo = bytesToString(extraInfoBytes);
+            return new WriteFinish(version, channelID, extraInfo);
+        }
+
+        @Override
+        public int getContentLength() {
+            return 4 + channelID.getFootprint() + 4 + extraInfo.length;
+        }
+
+        @Override
+        public void write(
+                ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator) {
+
+            ByteBuf byteBuf = allocateBuffer(allocator, ID, getContentLength());
+            byteBuf.writeInt(version);
+            channelID.writeTo(byteBuf);
+            byteBuf.writeInt(extraInfo.length);
+            byteBuf.writeBytes(extraInfo);
+            out.write(byteBuf, promise);
+        }
+
+        public int getVersion() {
+            return version;
+        }
+
+        public ChannelID getChannelID() {
+            return channelID;
+        }
+
+        public String getExtraInfo() {
+            return bytesToString(extraInfo);
+        }
+
+        @Override
+        public String toString() {
+            return String.format(
+                    "WriteFinish{%d, %s, extraInfo=%s}",
+                    version, channelID, bytesToString(extraInfo));
+        }
+    }
+
+    /**
+     * Message sent from the server to confirm {@link WriteFinish}. A shuffle write process can be
+     * regarded as successful only when this message is received.
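+     *
+     * <p>A client-side sketch of the finishing handshake; the future-based wiring is an
+     * assumption for illustration, not part of this change:
+     *
+     * <pre>{@code
+     * CompletableFuture<Void> committed = new CompletableFuture<>();
+     * channel.writeAndFlush(new WriteFinish(version, channelID, extra));
+     * // the client handler completes the future when WriteFinishCommit for this channel arrives
+     * committed.get(30, TimeUnit.SECONDS);
+     * }</pre>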
+ */ + public static class WriteFinishCommit extends TransferMessage { + + public static final byte ID = 6; + + private final int version; + + private final ChannelID channelID; + + private final byte[] extraInfo; + + public WriteFinishCommit(int version, ChannelID channelID, String extraInfo) { + this.version = version; + this.channelID = channelID; + this.extraInfo = stringToBytes(extraInfo); + } + + public static WriteFinishCommit readFrom(ByteBuf byteBuf) { + int version = byteBuf.readInt(); + ChannelID channelID = ChannelID.readFrom(byteBuf); + int extraInfoLen = byteBuf.readInt(); + byte[] extraInfoBytes = new byte[extraInfoLen]; + byteBuf.readBytes(extraInfoBytes); + String extraInfo = bytesToString(extraInfoBytes); + return new WriteFinishCommit(version, channelID, extraInfo); + } + + @Override + public int getContentLength() { + return 4 + channelID.getFootprint() + 4 + extraInfo.length; + } + + @Override + public void write( + ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator) { + + ByteBuf byteBuf = allocateBuffer(allocator, ID, getContentLength()); + byteBuf.writeInt(version); + channelID.writeTo(byteBuf); + byteBuf.writeInt(extraInfo.length); + byteBuf.writeBytes(extraInfo); + out.write(byteBuf, promise); + } + + public int getVersion() { + return version; + } + + public ChannelID getChannelID() { + return channelID; + } + + public String getExtraInfo() { + return bytesToString(extraInfo); + } + + @Override + public String toString() { + return String.format( + "WriteFinishCommit{%d, %s, extraInfo=%s}", + version, channelID, bytesToString(extraInfo)); + } + } + + /** Handshake message send from client when start shuffle read. */ + public static class ReadHandshakeRequest extends TransferMessage { + + public static final byte ID = 7; + + private final int version; + + private final ChannelID channelID; + + private final DataSetID dataSetID; + + private final MapPartitionID mapID; + + private final int startSubIdx; + + private final int endSubIdx; + + private final int initialCredit; + + private final int bufferSize; + + private final long offset; + + private final byte[] extraInfo; + + public ReadHandshakeRequest( + int version, + ChannelID channelID, + DataSetID dataSetID, + MapPartitionID mapID, + int startSubIdx, + int endSubIdx, + int initialCredit, + int bufferSize, + long offset, + String extraInfo) { + + this.version = version; + this.channelID = channelID; + this.dataSetID = dataSetID; + this.mapID = mapID; + this.startSubIdx = startSubIdx; + this.endSubIdx = endSubIdx; + this.initialCredit = initialCredit; + this.bufferSize = bufferSize; + this.offset = offset; + this.extraInfo = stringToBytes(extraInfo); + } + + public static ReadHandshakeRequest readFrom(ByteBuf byteBuf) { + int version = byteBuf.readInt(); + ChannelID channelID = ChannelID.readFrom(byteBuf); + DataSetID dataSetID = DataSetID.readFrom(byteBuf); + MapPartitionID mapID = MapPartitionID.readFrom(byteBuf); + int startSubIdx = byteBuf.readInt(); + int endSubIdx = byteBuf.readInt(); + int initialCredit = byteBuf.readInt(); + int bufferSize = byteBuf.readInt(); + long offset = byteBuf.readLong(); + int extraInfoLen = byteBuf.readInt(); + byte[] extraInfoBytes = new byte[extraInfoLen]; + byteBuf.readBytes(extraInfoBytes); + String extraInfo = bytesToString(extraInfoBytes); + return new ReadHandshakeRequest( + version, + channelID, + dataSetID, + mapID, + startSubIdx, + endSubIdx, + initialCredit, + bufferSize, + offset, + extraInfo); + } + + @Override + public int 
getContentLength() {
+            return 4
+                    + channelID.getFootprint()
+                    + dataSetID.getFootprint()
+                    + mapID.getFootprint()
+                    + 4
+                    + 4
+                    + 4
+                    + 4
+                    + 8
+                    + 4
+                    + extraInfo.length;
+        }
+
+        @Override
+        public void write(
+                ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator) {
+
+            ByteBuf byteBuf = allocateBuffer(allocator, ID, getContentLength());
+            byteBuf.writeInt(version);
+            channelID.writeTo(byteBuf);
+            dataSetID.writeTo(byteBuf);
+            mapID.writeTo(byteBuf);
+            byteBuf.writeInt(startSubIdx);
+            byteBuf.writeInt(endSubIdx);
+            byteBuf.writeInt(initialCredit);
+            byteBuf.writeInt(bufferSize);
+            byteBuf.writeLong(offset);
+            byteBuf.writeInt(extraInfo.length);
+            byteBuf.writeBytes(extraInfo);
+            out.write(byteBuf, promise);
+        }
+
+        public int getVersion() {
+            return version;
+        }
+
+        public ChannelID getChannelID() {
+            return channelID;
+        }
+
+        public DataSetID getDataSetID() {
+            return dataSetID;
+        }
+
+        public MapPartitionID getMapID() {
+            return mapID;
+        }
+
+        public int getStartSubIdx() {
+            return startSubIdx;
+        }
+
+        public int getEndSubIdx() {
+            return endSubIdx;
+        }
+
+        public int getInitialCredit() {
+            return initialCredit;
+        }
+
+        public int getBufferSize() {
+            return bufferSize;
+        }
+
+        public long getOffset() {
+            return offset;
+        }
+
+        public String getExtraInfo() {
+            return bytesToString(extraInfo);
+        }
+
+        @Override
+        public String toString() {
+            return String.format(
+                    "ReadHandshakeRequest{%d, %s, %s, %s, "
+                            + "startSubIdx=%d, endSubIdx=%d, initialCredit=%d, bufferSize=%d, "
+                            + "offset=%d, extraInfo=%s}",
+                    version,
+                    channelID,
+                    dataSetID,
+                    mapID,
+                    startSubIdx,
+                    endSubIdx,
+                    initialCredit,
+                    bufferSize,
+                    offset,
+                    bytesToString(extraInfo));
+        }
+    }
+
+    /** Message sent from the client to announce more credits so that more data can be accepted.
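+     *
+     * <p>This is the read-side half of the credit-based flow control used throughout this
+     * protocol: the sender may only ship as many buffers as the receiver has granted. A minimal
+     * sketch of granting one credit whenever a receive buffer is recycled; the surrounding
+     * handler wiring is assumed for illustration:
+     *
+     * <pre>{@code
+     * void onBufferRecycled() {
+     *     channel.writeAndFlush(new ReadAddCredit(currentProtocolVersion(), channelID, 1, extra));
+     * }
+     * }</pre>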
*/
+    public static class ReadAddCredit extends TransferMessage {
+
+        public static final byte ID = 8;
+
+        private final int version;
+
+        private final ChannelID channelID;
+
+        private final int credit;
+
+        private final byte[] extraInfo;
+
+        public ReadAddCredit(int version, ChannelID channelID, int credit, String extraInfo) {
+            this.version = version;
+            this.channelID = channelID;
+            this.credit = credit;
+            this.extraInfo = stringToBytes(extraInfo);
+        }
+
+        public static ReadAddCredit readFrom(ByteBuf byteBuf) {
+            int version = byteBuf.readInt();
+            ChannelID channelID = ChannelID.readFrom(byteBuf);
+            int credit = byteBuf.readInt();
+            int extraInfoLen = byteBuf.readInt();
+            byte[] extraInfoBytes = new byte[extraInfoLen];
+            byteBuf.readBytes(extraInfoBytes);
+            String extraInfo = bytesToString(extraInfoBytes);
+            return new ReadAddCredit(version, channelID, credit, extraInfo);
+        }
+
+        @Override
+        public int getContentLength() {
+            return 4 + channelID.getFootprint() + 4 + 4 + extraInfo.length;
+        }
+
+        @Override
+        public void write(
+                ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator) {
+
+            ByteBuf byteBuf = allocateBuffer(allocator, ID, getContentLength());
+            byteBuf.writeInt(version);
+            channelID.writeTo(byteBuf);
+            byteBuf.writeInt(credit);
+            byteBuf.writeInt(extraInfo.length);
+            byteBuf.writeBytes(extraInfo);
+            out.write(byteBuf, promise);
+        }
+
+        public int getVersion() {
+            return version;
+        }
+
+        public ChannelID getChannelID() {
+            return channelID;
+        }
+
+        public int getCredit() {
+            return credit;
+        }
+
+        public String getExtraInfo() {
+            return bytesToString(extraInfo);
+        }
+
+        @Override
+        public String toString() {
+            return String.format(
+                    "ReadAddCredit{%d, %s, credit=%d, extraInfo=%s}",
+                    version, channelID, credit, bytesToString(extraInfo));
+        }
+    }
+
+    /** Shuffle read data sent from the server.
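+     *
+     * <p>Like {@code WriteData}, it travels as a header buffer plus a separate data buffer;
+     * {@code backlog} advertises how many more buffers are queued on the server so the client
+     * can size its next credit announcement. Decoding mirrors that split (a sketch):
+     *
+     * <pre>{@code
+     * ReadData readData = ReadData.initByHeader(headerBuf); // body is still null here
+     * readData.setBuffer(dataBuf);                          // attach the data buffer
+     * }</pre>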
*/
+    public static class ReadData extends TransferMessage {
+
+        public static final byte ID = 9;
+
+        private final int version;
+
+        private final ChannelID channelID;
+
+        private final int backlog;
+
+        private final int bufferSize;
+
+        private final long offset;
+
+        private ByteBuf body;
+
+        private final byte[] extraInfo;
+
+        public ReadData(
+                int version,
+                ChannelID channelID,
+                int backlog,
+                int bufferSize,
+                long offset,
+                ByteBuf body,
+                String extraInfo) {
+            this.version = version;
+            this.channelID = channelID;
+            this.backlog = backlog;
+            this.bufferSize = bufferSize;
+            this.offset = offset;
+            this.body = body;
+            this.extraInfo = stringToBytes(extraInfo);
+        }
+
+        public static ReadData initByHeader(ByteBuf byteBuf) {
+            int version = byteBuf.readInt();
+            ChannelID channelID = ChannelID.readFrom(byteBuf);
+            int backlog = byteBuf.readInt();
+            int bufferSize = byteBuf.readInt();
+            long offset = byteBuf.readLong();
+            int extraInfoLen = byteBuf.readInt();
+            byte[] extraInfoBytes = new byte[extraInfoLen];
+            byteBuf.readBytes(extraInfoBytes);
+            String extraInfo = bytesToString(extraInfoBytes);
+            return new ReadData(version, channelID, backlog, bufferSize, offset, null, extraInfo);
+        }
+
+        @Override
+        public int getContentLength() {
+            throw new UnsupportedOperationException("Unsupported and should not be called.");
+        }
+
+        private int footprintExceptData() {
+            return 4 + channelID.getFootprint() + 4 + 4 + 8 + 4 + extraInfo.length;
+        }
+
+        @Override
+        public void write(
+                ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator) {
+
+            ByteBuf byteBuf = allocateBuffer(allocator, ID, footprintExceptData());
+            byteBuf.writeInt(version);
+            channelID.writeTo(byteBuf);
+            byteBuf.writeInt(backlog);
+            byteBuf.writeInt(bufferSize);
+            byteBuf.writeLong(offset);
+            byteBuf.writeInt(extraInfo.length);
+            byteBuf.writeBytes(extraInfo);
+            out.write(byteBuf);
+            out.write(body, promise);
+        }
+
+        public int getVersion() {
+            return version;
+        }
+
+        public ChannelID getChannelID() {
+            return channelID;
+        }
+
+        public int getBacklog() {
+            return backlog;
+        }
+
+        public ByteBuf getBuffer() {
+            return body;
+        }
+
+        public void setBuffer(ByteBuf body) {
+            this.body = body;
+        }
+
+        public int getBufferSize() {
+            return bufferSize;
+        }
+
+        public long getOffset() {
+            return offset;
+        }
+
+        public String getExtraInfo() {
+            return bytesToString(extraInfo);
+        }
+
+        @Override
+        public String toString() {
+            return String.format(
+                    "ReadData{%d, %s, size=%d, offset=%d, backlog=%d, extraInfo=%s}",
+                    version, channelID, bufferSize, offset, backlog, bytesToString(extraInfo));
+        }
+    }
+
+    /** Message sent from the client to ask the server to close a logical channel.
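+     *
+     * <p>Note the contrast with {@link CloseConnection}: this message tears down one logical
+     * channel while the underlying TCP connection, which may multiplex many channels, stays
+     * open.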
*/
+    public static class CloseChannel extends TransferMessage {
+
+        public static final byte ID = 10;
+
+        private final int version;
+
+        private final ChannelID channelID;
+
+        private final byte[] extraInfo;
+
+        public CloseChannel(int version, ChannelID channelID, String extraInfo) {
+            this.version = version;
+            this.channelID = channelID;
+            this.extraInfo = stringToBytes(extraInfo);
+        }
+
+        public static CloseChannel readFrom(ByteBuf byteBuf) {
+            int version = byteBuf.readInt();
+            ChannelID channelID = ChannelID.readFrom(byteBuf);
+            int extraInfoLen = byteBuf.readInt();
+            byte[] extraInfoBytes = new byte[extraInfoLen];
+            byteBuf.readBytes(extraInfoBytes);
+            String extraInfo = bytesToString(extraInfoBytes);
+            return new CloseChannel(version, channelID, extraInfo);
+        }
+
+        @Override
+        public int getContentLength() {
+            return 4 + channelID.getFootprint() + 4 + extraInfo.length;
+        }
+
+        @Override
+        public void write(
+                ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator) {
+
+            ByteBuf byteBuf = allocateBuffer(allocator, ID, getContentLength());
+            byteBuf.writeInt(version);
+            channelID.writeTo(byteBuf);
+            byteBuf.writeInt(extraInfo.length);
+            byteBuf.writeBytes(extraInfo);
+            out.write(byteBuf, promise);
+        }
+
+        public int getVersion() {
+            return version;
+        }
+
+        public ChannelID getChannelID() {
+            return channelID;
+        }
+
+        public String getExtraInfo() {
+            return bytesToString(extraInfo);
+        }
+
+        @Override
+        public String toString() {
+            return String.format(
+                    "CloseChannel{%d, %s, extraInfo=%s}",
+                    version, channelID, bytesToString(extraInfo));
+        }
+    }
+
+    /** Message sent from the client to ask the server to close the physical connection. */
+    public static class CloseConnection extends TransferMessage {
+
+        public static final byte ID = 11;
+
+        private final int version;
+
+        private final byte[] extraInfo;
+
+        public CloseConnection() {
+            this.version = currentProtocolVersion();
+            this.extraInfo = emptyExtraMessageBytes();
+        }
+
+        public static CloseConnection readFrom(ByteBuf byteBuf) {
+            byteBuf.readInt();
+            int extraInfoLen = byteBuf.readInt();
+            byte[] extraInfoBytes = new byte[extraInfoLen];
+            byteBuf.readBytes(extraInfoBytes);
+            return new CloseConnection();
+        }
+
+        @Override
+        public int getContentLength() {
+            return 4 + 4 + extraInfo.length;
+        }
+
+        @Override
+        public void write(
+                ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator) {
+
+            ByteBuf byteBuf = allocateBuffer(allocator, ID, getContentLength());
+            byteBuf.writeInt(version);
+            byteBuf.writeInt(extraInfo.length);
+            byteBuf.writeBytes(extraInfo);
+            out.write(byteBuf, promise);
+        }
+
+        public int getVersion() {
+            return version;
+        }
+
+        public String getExtraInfo() {
+            return bytesToString(extraInfo);
+        }
+
+        @Override
+        public String toString() {
+            return "CloseConnection";
+        }
+    }
+
+    /** Heartbeat message sent by both server and client to keep the connection alive.
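+     *
+     * <p>Both sides schedule it on the channel's executor, matching the handlers later in this
+     * change:
+     *
+     * <pre>{@code
+     * ctx.executor()
+     *         .scheduleAtFixedRate(
+     *                 () -> ctx.writeAndFlush(new Heartbeat()),
+     *                 0,
+     *                 heartbeatInterval,
+     *                 TimeUnit.SECONDS);
+     * }</pre>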
*/ + public static class Heartbeat extends TransferMessage { + + public static final byte ID = 12; + + private final int version; + + private final byte[] extraInfo; + + public Heartbeat() { + this.version = currentProtocolVersion(); + this.extraInfo = emptyExtraMessageBytes(); + } + + public static Heartbeat readFrom(ByteBuf byteBuf) { + byteBuf.readInt(); + int extraInfoLen = byteBuf.readInt(); + byte[] extraInfoBytes = new byte[extraInfoLen]; + byteBuf.readBytes(extraInfoBytes); + return new Heartbeat(); + } + + @Override + public int getContentLength() { + return 4 + 4 + extraInfo.length; + } + + @Override + public void write( + ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator) { + + ByteBuf byteBuf = allocateBuffer(allocator, ID, getContentLength()); + byteBuf.writeInt(version); + byteBuf.writeInt(extraInfo.length); + byteBuf.writeBytes(extraInfo); + out.write(byteBuf, promise); + } + + public int getVersion() { + return version; + } + + public String getExtraInfo() { + return bytesToString(extraInfo); + } + + @Override + public String toString() { + return String.format("Heartbeat"); + } + } + + /** Backlog announcement for the data sender to the data receiver. */ + public static class BacklogAnnouncement extends TransferMessage { + + public static final byte ID = 13; + + private final ChannelID channelID; + + private final int backlog; + + private final int version; + + private final byte[] extraInfo; + + public BacklogAnnouncement(ChannelID channelID, int backlog) { + checkArgument(channelID != null, "Must be not null."); + checkArgument(backlog > 0, "Must be positive."); + + this.channelID = channelID; + this.backlog = backlog; + this.version = currentProtocolVersion(); + this.extraInfo = emptyExtraMessageBytes(); + } + + @Override + public int getContentLength() { + return 4 + channelID.getFootprint() + 4 + 4 + extraInfo.length; + } + + @Override + public void write( + ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator) { + + ByteBuf byteBuf = allocateBuffer(allocator, ID, getContentLength()); + byteBuf.writeInt(version); + channelID.writeTo(byteBuf); + byteBuf.writeInt(backlog); + byteBuf.writeInt(extraInfo.length); + byteBuf.writeBytes(extraInfo); + out.write(byteBuf, promise); + } + + public static BacklogAnnouncement readFrom(ByteBuf byteBuf) { + byteBuf.readInt(); + ChannelID channelID = ChannelID.readFrom(byteBuf); + int backlog = byteBuf.readInt(); + int extraInfoLen = byteBuf.readInt(); + byte[] extraInfoBytes = new byte[extraInfoLen]; + byteBuf.readBytes(extraInfoBytes); + return new BacklogAnnouncement(channelID, backlog); + } + + public ChannelID getChannelID() { + return channelID; + } + + public int getBacklog() { + return backlog; + } + + public int getVersion() { + return version; + } + + public byte[] getExtraInfo() { + return extraInfo; + } + + @Override + public String toString() { + return String.format( + "BacklogAnnouncement{channelID=%s, backlog=%d, version=%d, extraInfo=%s}", + channelID, backlog, version, Arrays.toString(extraInfo)); + } + } +} diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/TransferMessageDecoder.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/TransferMessageDecoder.java new file mode 100644 index 00000000..45824660 --- /dev/null +++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/TransferMessageDecoder.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more 
contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; + +import javax.annotation.Nullable; + +/** Decoding harness under {@link DecoderDelegate}. */ +public abstract class TransferMessageDecoder { + + /** ID of the message under decoding. */ + protected int msgId; + + /** Length of the message under decoding. */ + protected int messageLength; + + /** Whether the decoder is closed ever. */ + protected boolean isClosed; + + /** + * Notifies that a new message is to be decoded. + * + * @param ctx Channel context. + * @param msgId The type of the message to be decoded. + * @param messageLength The length of the message to be decoded. + */ + public void onNewMessageReceived(ChannelHandlerContext ctx, int msgId, int messageLength) { + this.msgId = msgId; + this.messageLength = messageLength; + } + + /** + * Notifies that more data is received to continue decoding. + * + * @param data The data received. + * @return The result of decoding received data. + */ + public abstract DecodingResult onChannelRead(ByteBuf data) throws Exception; + + /** Close this decoder and release relevant resource. */ + public abstract void close(); + + /** The result of decoding one {@link ByteBuf}. */ + public static class DecodingResult { + + public static final DecodingResult NOT_FINISHED = new DecodingResult(false, null); + + public static final DecodingResult UNKNOWN_MESSAGE = new DecodingResult(true, null); + + private final boolean finished; + + @Nullable private final TransferMessage message; + + private DecodingResult(boolean finished, @Nullable TransferMessage message) { + this.finished = finished; + this.message = message; + } + + public static DecodingResult fullMessage(TransferMessage message) { + return new DecodingResult(true, message); + } + + public boolean isFinished() { + return finished; + } + + @Nullable + public TransferMessage getMessage() { + return message; + } + } +} diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/TransferMessageEncoder.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/TransferMessageEncoder.java new file mode 100644 index 00000000..24caab73 --- /dev/null +++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/TransferMessageEncoder.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelOutboundHandlerAdapter; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelPromise; + +/** Message encoder -- the specific encoding logic is defined in {@link TransferMessage#write}. */ +public class TransferMessageEncoder extends ChannelOutboundHandlerAdapter { + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + if (!(msg instanceof TransferMessage)) { + throw new RuntimeException("Unexpected Netty message: " + msg.getClass().getName()); + } + ((TransferMessage) msg).write(ctx, promise, ctx.alloc()); + } +} diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/WriteClientHandler.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/WriteClientHandler.java new file mode 100644 index 00000000..4d22d9c5 --- /dev/null +++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/WriteClientHandler.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package com.alibaba.flink.shuffle.transfer;
+
+import com.alibaba.flink.shuffle.common.utils.CommonUtils;
+import com.alibaba.flink.shuffle.core.ids.ChannelID;
+import com.alibaba.flink.shuffle.transfer.TransferMessage.ErrorResponse;
+import com.alibaba.flink.shuffle.transfer.TransferMessage.Heartbeat;
+import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteAddCredit;
+import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteFinishCommit;
+
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter;
+import org.apache.flink.shaded.netty4.io.netty.handler.timeout.IdleStateEvent;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.net.SocketAddress;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.TimeUnit;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState;
+
+/**
+ * A {@link ChannelInboundHandlerAdapter} shared by multiple {@link ShuffleWriteClient}s for
+ * shuffle write.
+ */
+public class WriteClientHandler extends ChannelInboundHandlerAdapter {
+
+    private static final Logger LOG = LoggerFactory.getLogger(WriteClientHandler.class);
+
+    /** Heartbeat interval. */
+    private final int heartbeatInterval;
+
+    /** The registered {@link ShuffleWriteClient}s, indexed by {@link ChannelID}. */
+    private final Map<ChannelID, ShuffleWriteClient> writeClientsByChannelID;
+
+    /** {@link ScheduledFuture} for heartbeat. */
+    private ScheduledFuture<?> heartbeatFuture;
+
+    /** Whether the heartbeat future has been canceled. */
+    private boolean heartbeatFutureCanceled;
+
+    /**
+     * @param heartbeatInterval Heartbeat interval -- client and server send heartbeats to each
+     *     other to confirm liveness.
+     */
+    public WriteClientHandler(int heartbeatInterval) {
+        this.heartbeatInterval = heartbeatInterval;
+        this.writeClientsByChannelID = new ConcurrentHashMap<>();
+        this.heartbeatFutureCanceled = false;
+    }
+
+    /** Register a {@link ShuffleWriteClient}. */
+    public void register(ShuffleWriteClient shuffleWriteClient) {
+        writeClientsByChannelID.put(shuffleWriteClient.getChannelID(), shuffleWriteClient);
+    }
+
+    /** Unregister a {@link ShuffleWriteClient}. */
+    public void unregister(ShuffleWriteClient shuffleWriteClient) {
+        writeClientsByChannelID.remove(shuffleWriteClient.getChannelID());
+    }
+
+    /** Whether a {@link ChannelID} is currently registered.
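+     *
+     * <p>Registration is the routing mechanism of this handler: a {@link ShuffleWriteClient}
+     * registers itself before its first message and unregisters when done, so responses can be
+     * dispatched back by {@link ChannelID}. A usage sketch:
+     *
+     * <pre>{@code
+     * writeClientHandler.register(shuffleWriteClient);
+     * try {
+     *     // handshake, write, finish ...
+     * } finally {
+     *     writeClientHandler.unregister(shuffleWriteClient);
+     * }
+     * }</pre>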
*/ + public boolean isRegistered(ChannelID channelID) { + return writeClientsByChannelID.containsKey(channelID); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + if (heartbeatInterval > 0) { + heartbeatFuture = + ctx.executor() + .scheduleAtFixedRate( + () -> ctx.writeAndFlush(new Heartbeat()), + 0, + heartbeatInterval, + TimeUnit.SECONDS); + } + + super.channelActive(ctx); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + LOG.debug("({}) Connection inactive.", ctx.channel().remoteAddress()); + writeClientsByChannelID.values().forEach(ShuffleWriteClient::channelInactive); + + if (heartbeatFuture != null && !heartbeatFutureCanceled) { + heartbeatFuture.cancel(true); + heartbeatFutureCanceled = true; + } + super.channelInactive(ctx); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + LOG.debug("({}) Connection exception caught.", ctx.channel().remoteAddress(), cause); + writeClientsByChannelID.values().forEach(writeClient -> writeClient.exceptionCaught(cause)); + + if (heartbeatFuture != null && !heartbeatFutureCanceled) { + heartbeatFuture.cancel(true); + heartbeatFutureCanceled = true; + } + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + + ChannelID currentChannelID = null; + + try { + if (msg.getClass() == WriteAddCredit.class) { + WriteAddCredit addCredit = (WriteAddCredit) msg; + ChannelID channelID = addCredit.getChannelID(); + currentChannelID = channelID; + assertChannelExists(channelID); + writeClientsByChannelID.get(channelID).creditReceived(addCredit); + + } else if (msg.getClass() == WriteFinishCommit.class) { + WriteFinishCommit finishCommit = (WriteFinishCommit) msg; + ChannelID channelID = finishCommit.getChannelID(); + currentChannelID = channelID; + assertChannelExists(channelID); + writeClientsByChannelID.get(channelID).writeFinishCommitReceived(finishCommit); + + } else if (msg.getClass() == ErrorResponse.class) { + ErrorResponse errorRsp = (ErrorResponse) msg; + SocketAddress addr = ctx.channel().remoteAddress(); + LOG.debug("({}) Received {}.", addr, errorRsp); + ChannelID channelID = errorRsp.getChannelID(); + currentChannelID = channelID; + assertChannelExists(channelID); + String errorMsg = new String(errorRsp.getErrorMessageBytes()); + writeClientsByChannelID.get(channelID).exceptionCaught(new IOException(errorMsg)); + + } else { + ctx.fireChannelRead(msg); + } + } catch (Throwable t) { + SocketAddress address = ctx.channel().remoteAddress(); + LOG.debug("({}, ch: {}) Exception caught.", address, currentChannelID, t); + if (writeClientsByChannelID.containsKey(currentChannelID)) { + writeClientsByChannelID.get(currentChannelID).exceptionCaught(t); + } + } + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof IdleStateEvent) { + IdleStateEvent event = (IdleStateEvent) evt; + LOG.debug( + "({}) Remote seems lost and connection idle -- {}.", + ctx.channel().remoteAddress(), + event.state()); + if (heartbeatInterval <= 0) { + return; + } + + CommonUtils.runQuietly( + () -> + exceptionCaught( + ctx, + new Exception("Connection idle, state is " + event.state())), + true); + } else { + ctx.fireUserEventTriggered(evt); + } + } + + private void assertChannelExists(ChannelID channelID) { + checkState( + writeClientsByChannelID.containsKey(channelID), + "Unexpected (might already unregistered) channelID -- " + channelID); + } +} diff 
--git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/WriteServerHandler.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/WriteServerHandler.java new file mode 100644 index 00000000..f3eea552 --- /dev/null +++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/WriteServerHandler.java @@ -0,0 +1,404 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.common.functions.BiConsumerWithException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.common.utils.ExceptionUtils; +import com.alibaba.flink.shuffle.core.ids.ChannelID; +import com.alibaba.flink.shuffle.core.storage.PartitionedDataStore; +import com.alibaba.flink.shuffle.transfer.TransferMessage.CloseChannel; +import com.alibaba.flink.shuffle.transfer.TransferMessage.CloseConnection; +import com.alibaba.flink.shuffle.transfer.TransferMessage.ErrorResponse; +import com.alibaba.flink.shuffle.transfer.TransferMessage.Heartbeat; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteAddCredit; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteData; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteFinish; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteFinishCommit; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteHandshakeRequest; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteRegionFinish; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteRegionStart; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandler; +import org.apache.flink.shaded.netty4.io.netty.channel.SimpleChannelInboundHandler; +import org.apache.flink.shaded.netty4.io.netty.handler.timeout.IdleStateEvent; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.SocketAddress; +import java.nio.channels.ClosedChannelException; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.currentProtocolVersion; +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyExtraMessage; + +/** A {@link ChannelInboundHandler} serves shuffle write process on server 
side. */
+public class WriteServerHandler extends SimpleChannelInboundHandler<TransferMessage> {
+
+    private static final Logger LOG = LoggerFactory.getLogger(WriteServerHandler.class);
+
+    /** The underlying service logic. */
+    private final WritingService writingService;
+
+    /** Identifier of the current channel being served. */
+    private ChannelID currentChannelID;
+
+    /** Heartbeat interval. */
+    private final int heartbeatInterval;
+
+    /** {@link ScheduledFuture} for heartbeat. */
+    private ScheduledFuture<?> heartbeatFuture;
+
+    /** Whether the connection has been closed. */
+    private boolean connectionClosed;
+
+    /**
+     * @param dataStore Implementation of storage layer.
+     * @param heartbeatInterval Heartbeat interval in seconds.
+     */
+    public WriteServerHandler(PartitionedDataStore dataStore, int heartbeatInterval) {
+        this.writingService = new WritingService(dataStore);
+        this.currentChannelID = null;
+        this.heartbeatInterval = heartbeatInterval;
+        this.connectionClosed = false;
+    }
+
+    @Override
+    public void channelActive(ChannelHandlerContext ctx) throws Exception {
+        if (heartbeatInterval > 0) {
+            heartbeatFuture =
+                    ctx.executor()
+                            .scheduleAtFixedRate(
+                                    () -> ctx.writeAndFlush(new Heartbeat()),
+                                    0,
+                                    heartbeatInterval,
+                                    TimeUnit.SECONDS);
+        }
+        super.channelActive(ctx);
+    }
+
+    @Override
+    public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
+        super.channelRegistered(ctx);
+        NetworkMetrics.numWritingConnections().inc();
+    }
+
+    @Override
+    public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
+        super.channelUnregistered(ctx);
+        NetworkMetrics.numWritingConnections().dec();
+    }
+
+    @Override
+    protected void channelRead0(ChannelHandlerContext ctx, TransferMessage msg) {
+        try {
+            onMessage(ctx, msg);
+        } catch (Throwable t) {
+            CommonUtils.runQuietly(() -> onInternalFailure(currentChannelID, ctx, t), true);
+        }
+    }
+
+    private void onMessage(ChannelHandlerContext ctx, TransferMessage msg) throws Throwable {
+        Class<?> msgClazz = msg.getClass();
+        if (msgClazz == WriteHandshakeRequest.class) {
+            WriteHandshakeRequest handshakeReq = (WriteHandshakeRequest) msg;
+            SocketAddress address = ctx.channel().remoteAddress();
+            ChannelID channelID = handshakeReq.getChannelID();
+            LOG.debug("({}) Received {}.", address, msg);
+
+            currentChannelID = channelID;
+            BiConsumer<Integer, Integer> creditListener = getCreditListener(ctx, channelID);
+            Consumer<Throwable> failureListener = getFailureListener(ctx, channelID);
+            writingService.handshake(
+                    channelID,
+                    handshakeReq.getJobID(),
+                    handshakeReq.getDataSetID(),
+                    handshakeReq.getMapID(),
+                    handshakeReq.getNumSubs(),
+                    handshakeReq.getDataPartitionType(),
+                    creditListener,
+                    failureListener,
+                    address.toString());
+
+        } else if (msgClazz == WriteData.class) {
+            WriteData writeData = (WriteData) msg;
+            SocketAddress address = ctx.channel().remoteAddress();
+            LOG.trace("({}) Received {}.", address, writeData);
+
+            ChannelID channelID = writeData.getChannelID();
+            currentChannelID = channelID;
+            int subpartitionIndex = writeData.getSubIdx();
+            ByteBuf buffer = writeData.getBuffer();
+            writingService.write(channelID, subpartitionIndex, buffer);
+
+        } else if (msgClazz == WriteRegionStart.class) {
+            WriteRegionStart regionStart = (WriteRegionStart) msg;
+            SocketAddress address = ctx.channel().remoteAddress();
+            LOG.debug("({}) Received {}.", address, regionStart);
+
+            ChannelID channelID = regionStart.getChannelID();
+            currentChannelID = channelID;
+            writingService.regionStart(
+                    channelID, regionStart.getRegionIdx(), regionStart.isBroadcast());
+
+        } else if
(msgClazz == WriteRegionFinish.class) {
+            WriteRegionFinish regionFinish = (WriteRegionFinish) msg;
+            SocketAddress address = ctx.channel().remoteAddress();
+            LOG.debug("({}) Received {}.", address, regionFinish);
+
+            ChannelID channelID = regionFinish.getChannelID();
+            currentChannelID = channelID;
+            writingService.regionFinish(channelID);
+
+        } else if (msgClazz == WriteFinish.class) {
+            WriteFinish writeFinish = (WriteFinish) msg;
+            SocketAddress address = ctx.channel().remoteAddress();
+            LOG.debug("({}) Received {}.", address, writeFinish);
+
+            ChannelID channelID = writeFinish.getChannelID();
+            currentChannelID = channelID;
+            writingService.writeFinish(channelID, getFinishCommitListener(ctx, channelID));
+
+        } else if (msgClazz == CloseChannel.class) {
+            CloseChannel closeChannel = (CloseChannel) msg;
+            SocketAddress address = ctx.channel().remoteAddress();
+            LOG.debug("({}) Received {}.", address, closeChannel);
+
+            ChannelID channelID = closeChannel.getChannelID();
+            currentChannelID = channelID;
+            writingService.closeAbnormallyIfUnderServing(channelID);
+
+            // Ensure both WritingService and ReadingService can receive this message.
+            ctx.fireChannelRead(msg);
+
+        } else if (msgClazz == CloseConnection.class) {
+            SocketAddress address = ctx.channel().remoteAddress();
+            CloseConnection closeConnection = (CloseConnection) msg;
+            LOG.info("({}) Received {}.", address, closeConnection);
+            close(ctx, new ShuffleException("Receive connection close from client."));
+
+            // Ensure both WritingService and ReadingService can receive this message.
+            ctx.fireChannelRead(msg);
+
+        } else {
+            ctx.fireChannelRead(msg);
+        }
+    }
+
+    private BiConsumer<Integer, Integer> getCreditListener(
+            ChannelHandlerContext ctx, ChannelID channelID) {
+        return (credit, regionIdx) ->
+                ctx.pipeline()
+                        .fireUserEventTriggered(new AddCreditEvent(channelID, credit, regionIdx));
+    }
+
+    private Runnable getFinishCommitListener(ChannelHandlerContext ctx, ChannelID channelID) {
+        return () -> ctx.pipeline().fireUserEventTriggered(new WriteFinishCommitEvent(channelID));
+    }
+
+    private Consumer<Throwable> getFailureListener(
+            ChannelHandlerContext ctx, ChannelID channelID) {
+        return t -> {
+            if (t instanceof ClosedChannelException) {
+                return;
+            }
+            ctx.pipeline().fireUserEventTriggered(new WritingFailureEvent(channelID, t));
+        };
+    }
+
+    @Override
+    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
+        Class<?> msgClazz = evt.getClass();
+        if (msgClazz == AddCreditEvent.class) {
+            AddCreditEvent addCreditEvt = (AddCreditEvent) evt;
+            SocketAddress address = ctx.channel().remoteAddress();
+
+            ChannelID channelID = addCreditEvt.channelID;
+            WriteAddCredit addCredit =
+                    new WriteAddCredit(
+                            currentProtocolVersion(),
+                            channelID,
+                            addCreditEvt.credit,
+                            addCreditEvt.regionIdx,
+                            emptyExtraMessage());
+            LOG.trace("({}) Send {}.", address, addCredit);
+            writeAndFlush(ctx, addCredit);
+
+        } else if (msgClazz == WriteFinishCommitEvent.class) {
+            WriteFinishCommitEvent finishCommitEvt = (WriteFinishCommitEvent) evt;
+            SocketAddress address = ctx.channel().remoteAddress();
+
+            ChannelID channelID = finishCommitEvt.channelID;
+            WriteFinishCommit finishCommit =
+                    new WriteFinishCommit(currentProtocolVersion(), channelID, emptyExtraMessage());
+            LOG.debug("({}) Send {}.", address, finishCommit);
+            writeAndFlush(ctx, finishCommit);
+
+        } else if (msgClazz == WritingFailureEvent.class) {
+            WritingFailureEvent errRspEvt = (WritingFailureEvent) evt;
+            SocketAddress address = ctx.channel().remoteAddress();
+            LOG.error("({})
Received {}.", address, errRspEvt);
+            ChannelID channelID = errRspEvt.channelID;
+            CommonUtils.runQuietly(() -> onInternalFailure(channelID, ctx, errRspEvt.cause), true);
+
+        } else if (evt instanceof IdleStateEvent) {
+            IdleStateEvent event = (IdleStateEvent) evt;
+            LOG.debug(
+                    "({}) Remote seems lost and connection idle -- {}.",
+                    ctx.channel().remoteAddress(),
+                    event.state());
+            if (heartbeatInterval <= 0) {
+                return;
+            }
+
+            CommonUtils.runQuietly(
+                    () -> close(ctx, new ShuffleException("Heartbeat timeout.")), true);
+        } else {
+            ctx.fireUserEventTriggered(evt);
+        }
+    }
+
+    @Override
+    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+        close(ctx, new ShuffleException("Channel inactive."));
+        super.channelInactive(ctx);
+    }
+
+    @Override
+    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+        close(ctx, cause);
+    }
+
+    public WritingService getWritingService() {
+        return writingService;
+    }
+
+    private void writeAndFlush(ChannelHandlerContext ctx, Object obj) {
+        BiConsumerWithException<ChannelFuture, Throwable, Exception> errorHandler =
+                (channelFuture, cause) -> {
+                    LOG.error(
+                            "Shuffle write Netty failure -- failed to send {}, cause: {}.",
+                            obj,
+                            cause.getClass().getSimpleName());
+                    close(ctx, cause);
+                };
+        ChannelFutureListenerImpl listener = new ChannelFutureListenerImpl(errorHandler);
+        ctx.writeAndFlush(obj).addListener(listener);
+    }
+
+    private void onInternalFailure(ChannelID channelID, ChannelHandlerContext ctx, Throwable t) {
+        checkNotNull(channelID);
+        byte[] errorMessageBytes = ExceptionUtils.summaryErrorMessageStack(t).getBytes();
+        ErrorResponse errRsp =
+                new ErrorResponse(
+                        currentProtocolVersion(),
+                        channelID,
+                        errorMessageBytes,
+                        emptyExtraMessage());
+        ctx.writeAndFlush(errRsp).addListener(new CloseChannelWhenFailure());
+        writingService.releaseOnError(t, channelID);
+    }
+
+    // This method is invoked when:
+    // 1. CloseConnection is received from the client;
+    // 2. A network error occurs --
+    //    a. sending a message fails;
+    //    b. the connection goes inactive;
+    //    c. a connection exception is caught.
+    //
+    // It then:
+    // 1. Triggers errors on the corresponding logical channels;
+    // 2. Stops the heartbeat to the client;
+    // 3. Closes the physical connection.
+    private void close(ChannelHandlerContext ctx, Throwable throwable) {
+        if (writingService.getNumServingChannels() > 0) {
+            writingService.releaseOnError(
+                    throwable != null ?
throwable : new ClosedChannelException(), null); + } + + if (heartbeatFuture != null) { + heartbeatFuture.cancel(true); + } + ctx.channel().close(); + connectionClosed = true; + } + + private static class AddCreditEvent { + + ChannelID channelID; + + int credit; + + int regionIdx; + + AddCreditEvent(ChannelID channelID, int credit, int regionIdx) { + this.channelID = channelID; + this.credit = credit; + this.regionIdx = regionIdx; + } + + @Override + public String toString() { + return String.format( + "AddCreditEvent [credit: %d, channelID: %s, regionIdx: %d]", + credit, channelID, regionIdx); + } + } + + private static class WriteFinishCommitEvent { + + ChannelID channelID; + + WriteFinishCommitEvent(ChannelID channelID) { + this.channelID = channelID; + } + + @Override + public String toString() { + return String.format("WriteFinishCommitEvent [channelID: %s]", channelID); + } + } + + static class WritingFailureEvent { + + ChannelID channelID; + + Throwable cause; + + WritingFailureEvent(ChannelID channelID, Throwable cause) { + this.cause = cause; + this.channelID = channelID; + } + + @Override + public String toString() { + return String.format( + "WritingFailureEvent [channelID: %s, cause: %s]", + channelID, cause.getMessage()); + } + } +} diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/WritingExceptionWithChannelID.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/WritingExceptionWithChannelID.java new file mode 100644 index 00000000..ccaa5dff --- /dev/null +++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/WritingExceptionWithChannelID.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.core.ids.ChannelID; + +/** A {@link Throwable} with a related {@link ChannelID} while writing on shuffle worker. 
*/
+public class WritingExceptionWithChannelID extends ShuffleException {
+
+    private static final long serialVersionUID = 1567641801077683004L;
+
+    private final ChannelID channelID;
+
+    public WritingExceptionWithChannelID(ChannelID channelID, Throwable t) {
+        super(t);
+        this.channelID = channelID;
+    }
+
+    public ChannelID getChannelID() {
+        return channelID;
+    }
+}
diff --git a/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/WritingService.java b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/WritingService.java
new file mode 100644
index 00000000..3676d011
--- /dev/null
+++ b/shuffle-transfer/src/main/java/com/alibaba/flink/shuffle/transfer/WritingService.java
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.flink.shuffle.transfer;
+
+import com.alibaba.flink.shuffle.common.utils.CommonUtils;
+import com.alibaba.flink.shuffle.core.ids.ChannelID;
+import com.alibaba.flink.shuffle.core.ids.DataSetID;
+import com.alibaba.flink.shuffle.core.ids.JobID;
+import com.alibaba.flink.shuffle.core.ids.MapPartitionID;
+import com.alibaba.flink.shuffle.core.ids.ReducePartitionID;
+import com.alibaba.flink.shuffle.core.memory.Buffer;
+import com.alibaba.flink.shuffle.core.storage.DataPartitionWritingView;
+import com.alibaba.flink.shuffle.core.storage.PartitionedDataStore;
+import com.alibaba.flink.shuffle.core.storage.WritingViewContext;
+
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandler;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.BiConsumer;
+import java.util.function.Consumer;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+
+import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState;
+
+/**
+ * Harness used to write data to storage. It performs shuffle write via the
+ * {@link PartitionedDataStore}. Its lifecycle is the same as that of a Netty
+ * {@link ChannelInboundHandler} instance.
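+ *
+ * <p>A typical per-channel sequence, mirroring the server handler above (a sketch; the listener
+ * arguments are whatever the caller wires in):
+ *
+ * <pre>{@code
+ * writingService.handshake(channelID, jobID, dataSetID, mapID, numSubs,
+ *         factoryName, creditListener, failureListener, addressString);
+ * writingService.regionStart(channelID, 0, false);
+ * writingService.write(channelID, subIdx, byteBuf); // repeated per buffer
+ * writingService.regionFinish(channelID);
+ * writingService.writeFinish(channelID, committedListener);
+ * }</pre>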
+ */ +public class WritingService { + + private static final Logger LOG = LoggerFactory.getLogger(WritingService.class); + + private final PartitionedDataStore dataStore; + + private final Map<ChannelID, DataViewWriter> servingChannels; + + public WritingService(PartitionedDataStore dataStore) { + this.dataStore = dataStore; + this.servingChannels = new HashMap<>(); + } + + public void handshake( + ChannelID channelID, + JobID jobID, + DataSetID dataSetID, + MapPartitionID mapID, + int numSubs, + String dataPartitionFactory, + BiConsumer<Integer, Integer> creditListener, + Consumer<Throwable> failureListener, + String addressStr) + throws Throwable { + + checkState( + !servingChannels.containsKey(channelID), + () -> "Duplicate handshake for channel: " + channelID); + long startTime = System.nanoTime(); + DataPartitionWritingView writingView = + dataStore.createDataPartitionWritingView( + new WritingViewContext( + jobID, + dataSetID, + mapID, + mapID, + numSubs, + dataPartitionFactory, + creditListener::accept, + failureListener::accept)); + LOG.debug( + "(channel: {}) Writing handshake cost {} ms.", + channelID, + (System.nanoTime() - startTime) / 1_000_000); + servingChannels.put(channelID, new DataViewWriter(writingView, addressStr)); + NetworkMetrics.numWritingFlows().inc(); + } + + public void write(ChannelID channelID, int subIdx, ByteBuf byteBuf) { + DataViewWriter dataViewWriter = servingChannels.get(channelID); + if (dataViewWriter == null) { + byteBuf.release(); + throw new IllegalStateException("Writing channel has been released -- " + channelID); + } + ReducePartitionID reduceID = new ReducePartitionID(subIdx); + NetworkMetrics.numBytesWritingThroughput().mark(byteBuf.readableBytes()); + dataViewWriter.getWritingView().onBuffer((Buffer) byteBuf, reduceID); + } + + public Supplier<ByteBuf> getBufferSupplier(ChannelID channelID) { + return () -> { + checkState(servingChannels.containsKey(channelID), "Channel is not under serving."); + return servingChannels.get(channelID).getWritingView().getBufferSupplier().pollBuffer(); + }; + } + + public void regionStart(ChannelID channelID, int regionIdx, boolean isBroadcast) { + DataViewWriter dataViewWriter = servingChannels.get(channelID); + checkState( + dataViewWriter != null, + () -> String.format("Write-channel %s is not under serving.", channelID)); + dataViewWriter.getWritingView().regionStarted(regionIdx, isBroadcast); + } + + public void regionFinish(ChannelID channelID) { + DataViewWriter dataViewWriter = servingChannels.get(channelID); + checkState( + dataViewWriter != null, + () -> String.format("Write-channel %s is not under serving.", channelID)); + dataViewWriter.getWritingView().regionFinished(); + } + + public void writeFinish(ChannelID channelID, Runnable committedListener) { + DataViewWriter dataViewWriter = servingChannels.get(channelID); + checkState( + dataViewWriter != null, + () -> String.format("Write-channel %s is not under serving.", channelID)); + dataViewWriter.getWritingView().finish(committedListener::run); + servingChannels.remove(channelID); + NetworkMetrics.numWritingFlows().dec(); + } + + public int getNumServingChannels() { + return servingChannels.size(); + } + + public void closeAbnormallyIfUnderServing(ChannelID channelID) { + DataViewWriter dataViewWriter = servingChannels.get(channelID); + if (dataViewWriter != null) { + dataViewWriter + .getWritingView() + .onError( + new Exception( + String.format( + "(channel: %s) Channel closed abnormally.", + channelID))); + servingChannels.remove(channelID); + NetworkMetrics.numWritingFlows().dec(); + } + } + + public void
releaseOnError(Throwable cause, ChannelID channelID) { + if (channelID == null) { + Set<ChannelID> channelIDs = servingChannels.keySet(); + LOG.error( + "Release channels -- {} on error.", + channelIDs.stream().map(ChannelID::toString).collect(Collectors.joining(", ")), + cause); + for (DataViewWriter dataViewWriter : servingChannels.values()) { + CommonUtils.runQuietly(() -> dataViewWriter.getWritingView().onError(cause), true); + } + NetworkMetrics.numWritingFlows().dec(getNumServingChannels()); + servingChannels.clear(); + } else if (servingChannels.containsKey(channelID)) { + LOG.error("Release channel -- {} on error.", channelID, cause); + CommonUtils.runQuietly( + () -> servingChannels.get(channelID).getWritingView().onError(cause), true); + servingChannels.remove(channelID); + NetworkMetrics.numWritingFlows().dec(); + } + } +} diff --git a/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/AbstractNettyTest.java b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/AbstractNettyTest.java new file mode 100644 index 00000000..dc7734c7 --- /dev/null +++ b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/AbstractNettyTest.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.functions.RunnableWithException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; + +import org.junit.After; +import org.junit.Before; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.util.ArrayDeque; +import java.util.HashSet; +import java.util.List; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +/** A facility class for Netty tests.
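+ * + * <p>Subclasses typically drive asynchronous assertions through {@code checkUntil}, which + * retries a failing assertion for up to 20 seconds (100 attempts, 200 ms apart). A sketch, + * where {@code serverHandler} stands for any test double: + * + * <pre>{@code + * checkUntil(() -> assertEquals(1, serverHandler.numMessages())); + * }</pre>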
*/ +public class AbstractNettyTest { + + private static final Logger LOG = LoggerFactory.getLogger(AbstractNettyTest.class); + + protected InetSocketAddress address; + + protected JobID jobID; + + protected DataSetID dataSetID; + + protected ExecutorService executor; + + protected NettyConfig nettyConfig; + + protected TransferBufferPool transferBufferPool; + + private int prevAvailableBuffers; + + @Before + public void setup() throws Exception { + jobID = new JobID(CommonUtils.randomBytes(32)); + dataSetID = new DataSetID(CommonUtils.randomBytes(32)); + executor = Executors.newSingleThreadExecutor(); + nettyConfig = new NettyConfig(new Configuration()); + transferBufferPool = new TestTransferBufferPool(64, 64); + prevAvailableBuffers = transferBufferPool.numBuffers(); + } + + @After + public void tearDown() throws Exception { + executor.shutdown(); + executor.awaitTermination(10, TimeUnit.SECONDS); + assertEquals(prevAvailableBuffers, transferBufferPool.numBuffers()); + transferBufferPool.destroy(); + } + + protected void runAsync(RunnableWithException runnable) { + executor.submit( + () -> { + try { + runnable.run(); + } catch (Throwable e) { + LOG.info("", e); + } + }); + } + + protected ByteBuf requestBuffer() { + return transferBufferPool.requestBuffer(); + } + + protected Queue<ByteBuf> constructBuffers(int numBuffers, int numLongsPerBuffer) { + Queue<ByteBuf> res = new ArrayDeque<>(); + int number = 0; + for (int i = 0; i < numBuffers; i++) { + ByteBuf buffer = transferBufferPool.requestBuffer(); + for (int j = 0; j < numLongsPerBuffer; j++) { + buffer.writeLong(number++); + } + res.add(buffer); + } + return res; + } + + protected void verifyBuffers(int numBuffers, int numLongsPerBuffer, List<ByteBuf> buffers) { + int a = 0; + for (int i = 0; i < numBuffers; i++) { + ByteBuf buffer = buffers.get(i); + assertEquals(8 * numLongsPerBuffer, buffer.readableBytes()); + for (int j = 0; j < numLongsPerBuffer; j++) { + assertEquals(a++, buffer.readLong()); + } + } + } + + protected void checkUntil(Runnable runnable) throws InterruptedException { + Throwable lastThrowable = null; + for (int i = 0; i < 100; i++) { + try { + runnable.run(); + return; + } catch (Throwable t) { + lastThrowable = t; + Thread.sleep(200); + } + } + LOG.info("", lastThrowable); + fail(); + } + + protected void delayCheck(Runnable runnable) throws InterruptedException { + Thread.sleep(100); + runnable.run(); + } + + public static int getAvailablePort() { + return getAvailablePorts(1)[0]; + } + + public static int[] getAvailablePorts(int num) { + Set<Integer> ports = new HashSet<>(); + for (int i = 0; i < 100; i++) { + try (ServerSocket serverSocket = new ServerSocket(0)) { + int port = serverSocket.getLocalPort(); + if (port != 0) { + ports.add(port); + } + if (ports.size() >= num) { + break; + } + } catch (IOException ignored) { + } + } + if (ports.size() == num) { + int[] ret = new int[num]; + int i = 0; + for (int port : ports) { + ret[i++] = port; + } + return ret; + } else { + throw new RuntimeException("Could not find free permitted ports on the machine."); + } + } +} diff --git a/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/ConnectionManagerTest.java b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/ConnectionManagerTest.java new file mode 100644 index 00000000..a0e5ae82 --- /dev/null +++ b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/ConnectionManagerTest.java @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +import com.alibaba.flink.shuffle.core.ids.ChannelID; + +import org.apache.flink.shaded.netty4.io.netty.channel.Channel; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture; + +import org.junit.Test; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static junit.framework.Assert.assertNotNull; +import static junit.framework.Assert.assertNull; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** Test for {@link ConnectionManager}. */ +public class ConnectionManagerTest { + + private final int maxRetries = 10; + + private final Duration connectionRetryWait = Duration.ofMillis(100); + + @Test + public void testConnectionRetry() throws Exception { + for (int numFailsFirstConnects : new int[] {0, 1, maxRetries, maxRetries + 1}) { + ConnectionManager connMgr = createConnectionManager(numFailsFirstConnects); + connMgr.start(); + Channel ch = null; + try { + InetSocketAddress address = new InetSocketAddress(InetAddress.getLocalHost(), 0); + ch = connMgr.getChannel(new ChannelID(), address); + } catch (Exception ignored) { + } + if (numFailsFirstConnects < maxRetries) { + assertNotNull(ch); + } else { + assertNull(ch); + } + connMgr.shutdown(); + } + } + + @Test + public void testReuse() throws Exception { + InetSocketAddress address0 = new InetSocketAddress(InetAddress.getLocalHost(), 0); + InetSocketAddress address1 = new InetSocketAddress(InetAddress.getLocalHost(), 1); + + ConnectionManager connMgr = createConnectionManager(0); + connMgr.start(); + + ChannelID channelID00 = new ChannelID(); + Channel ch00 = connMgr.getChannel(channelID00, address0); + assertNotNull(ch00); + assertEquals(1, connMgr.numPhysicalConnections()); + + ChannelID channelID01 = new ChannelID(); + Channel ch01 = connMgr.getChannel(channelID01, address0); + assertNotNull(ch01); + assertEquals(1, connMgr.numPhysicalConnections()); + assertEquals(ch00, ch01); + + ChannelID channelID10 = new ChannelID(); + Channel ch10 = connMgr.getChannel(channelID10, address1); + assertNotNull(ch10); + assertEquals(2, connMgr.numPhysicalConnections()); + assertNotEquals(ch00, ch10); + + connMgr.releaseChannel(address0, channelID01); + assertEquals(2, connMgr.numPhysicalConnections()); + + connMgr.releaseChannel(address0, channelID00); + assertEquals(1, connMgr.numPhysicalConnections()); + + connMgr.releaseChannel(address1, 
channelID10); + assertEquals(0, connMgr.numPhysicalConnections()); + + connMgr.shutdown(); + } + + @Test + public void testReconnectDelay() throws Exception { + long startTime = System.nanoTime(); + int numFailsFirstConnects = 2; + ConnectionManager connMgr = createConnectionManager(numFailsFirstConnects); + connMgr.start(); + InetSocketAddress address = new InetSocketAddress(InetAddress.getLocalHost(), 0); + connMgr.getChannel(new ChannelID(), address); + connMgr.shutdown(); + long duration = System.nanoTime() - startTime; + long delay = connectionRetryWait.toNanos() * numFailsFirstConnects; + String msg = String.format("Retry duration (%d) < delay (%d)", duration, delay); + assertTrue(msg, duration >= delay); + } + + @Test + public void testMultipleConcurrentConnect() throws Exception { + ConnectionManager connMgr = createConnectionManager(0); + connMgr.start(); + List<InetSocketAddress> addrs = new ArrayList<>(); + for (int i = 0; i < 4; i++) { + addrs.add(new InetSocketAddress(InetAddress.getLocalHost(), i)); + } + List<Thread> threads = new ArrayList<>(); + List<ChannelID> channelIDs = new ArrayList<>(); + AtomicReference<Throwable> cause = new AtomicReference<>(null); + for (int i = 0; i < 24; i++) { + final int idx = i % 4; + ChannelID channelID = new ChannelID(); + channelIDs.add(channelID); + Runnable r = + () -> { + try { + assertNotNull(connMgr.getChannel(channelID, addrs.get(idx))); + } catch (Throwable t) { + cause.set(t); + } + }; + threads.add(new Thread(r)); + } + for (Thread t : threads) { + t.start(); + } + for (Thread t : threads) { + t.join(); + } + assertNull(cause.get()); + assertEquals(4, connMgr.numPhysicalConnections()); + for (int i = 0; i < 24; i++) { + connMgr.releaseChannel(addrs.get(i % 4), channelIDs.get(i)); + } + assertEquals(0, connMgr.numPhysicalConnections()); + } + + private ConnectionManager createConnectionManager(int numFailsFirstConnects) { + NettyClient mockedNettyClient = mock(NettyClient.class); + AtomicInteger failedTimes = new AtomicInteger(0); + when(mockedNettyClient.connect(any(InetSocketAddress.class))) + .thenAnswer( + invoke -> { + Channel ch = mock(Channel.class); + when(ch.writeAndFlush(any(Object.class))) + .thenReturn(mock(ChannelFuture.class)); + return mockChannelFuture(ch, failedTimes, numFailsFirstConnects); + }); + return new ConnectionManager(null, null, maxRetries, connectionRetryWait) { + @Override + public synchronized void start() throws IOException { + this.nettyClient = mockedNettyClient; + } + }; + } + + private ChannelFuture mockChannelFuture( + Channel channel, AtomicInteger failedTimes, int numFailsFirstConnects) + throws Exception { + ChannelFuture channelFuture = mock(ChannelFuture.class); + when(channelFuture.sync()) + .thenAnswer( + invoke -> { + if (failedTimes.get() < numFailsFirstConnects) { + failedTimes.incrementAndGet(); + throw new Exception("Connection failure."); + } else { + return channelFuture; + } + }); + when(channelFuture.channel()).thenReturn(channel); + return channelFuture; + } +} diff --git a/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/DecoderDelegateTest.java b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/DecoderDelegateTest.java new file mode 100644 index 00000000..4647d925 --- /dev/null +++ b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/DecoderDelegateTest.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.config.TransferOptions; +import com.alibaba.flink.shuffle.core.ids.ChannelID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.transfer.TransferMessage.ReadHandshakeRequest; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.channel.Channel; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.net.InetAddress; +import java.net.InetSocketAddress; + +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.currentProtocolVersion; +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyBufferSize; +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyExtraMessage; +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyOffset; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +/** Test for {@link DecoderDelegate}. */ +public class DecoderDelegateTest extends AbstractNettyTest { + + private NettyServer nettyServer; + + private DummyChannelInboundHandlerAdaptor serverH; + + private MapPartitionID mapID; + + private DecoderDelegate decoderDelegate; + + private FakeDecoder decoder; + + private NettyClient nettyClient; + + private Channel channel; + + @Override + @Before + public void setup() throws Exception { + super.setup(); + + mapID = new MapPartitionID(CommonUtils.randomBytes(32)); + + int dataPort = initShuffleServer(); + address = new InetSocketAddress(InetAddress.getLocalHost(), dataPort); + + nettyClient = new NettyClient(nettyConfig); + nettyClient.init(() -> new ChannelHandler[] {new TransferMessageEncoder()}); + channel = nettyClient.connect(address).await().channel(); + } + + @Override + @After + public void tearDown() throws Exception { + super.tearDown(); + nettyServer.shutdown(); + channel.close(); + nettyClient.shutdown(); + } + + @Test + public void testDecoderException() throws Exception { + // Send ReadHandshakeRequest -- FakeDecoder throws when processing message. 
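+ // The DecoderDelegate is expected to close the active decoder when it throws; + // the checkUntil below pins that contract down via FakeDecoder.isClosed.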
+ channel.writeAndFlush( + new ReadHandshakeRequest( + currentProtocolVersion(), + new ChannelID(), + dataSetID, + mapID, + 0, + 0, + 1, + emptyBufferSize(), + emptyOffset(), + emptyExtraMessage())); + checkUntil(() -> assertTrue(decoder.isClosed)); + } + + @Test + public void testCloseConnection() throws Exception { + checkUntil(() -> assertNotNull(decoderDelegate)); + decoder = new FakeDecoder(); + decoderDelegate.setCurrentDecoder(decoder); + channel.close(); + checkUntil(() -> assertTrue(decoder.isClosed)); + } + + private int initShuffleServer() throws Exception { + serverH = new DummyChannelInboundHandlerAdaptor(); + int dataPort = getAvailablePort(); + nettyConfig.getConfig().setInteger(TransferOptions.SERVER_DATA_PORT, dataPort); + nettyServer = + new NettyServer(null, nettyConfig) { + @Override + public ChannelHandler[] getServerHandlers() { + decoder = new FakeDecoder(); + decoderDelegate = new DecoderDelegate(ignore -> decoder); + return new ChannelHandler[] {decoderDelegate, serverH}; + } + }; + nettyServer.start(); + return dataPort; + } + + private static class FakeDecoder extends TransferMessageDecoder { + + volatile boolean isClosed = false; + + @Override + public DecodingResult onChannelRead(ByteBuf data) throws Exception { + throw new Exception("Expected exception."); + } + + @Override + public void close() { + isClosed = true; + } + } +} diff --git a/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/DummyChannelInboundHandlerAdaptor.java b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/DummyChannelInboundHandlerAdaptor.java new file mode 100644 index 00000000..7fa6fca3 --- /dev/null +++ b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/DummyChannelInboundHandlerAdaptor.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +import com.alibaba.flink.shuffle.transfer.TransferMessage.CloseConnection; +import com.alibaba.flink.shuffle.transfer.TransferMessage.ReadData; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteData; + +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** A dummy {@link ChannelInboundHandlerAdapter} for testing. 
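+ * + * <p>Tests install it as the last inbound handler of a pipeline and assert on the captured + * messages, e.g. (sketch; the pipeline and the {@code checkUntil} helper are assumed from the + * surrounding tests): + * + * <pre>{@code + * DummyChannelInboundHandlerAdaptor serverH = new DummyChannelInboundHandlerAdaptor(); + * // server pipeline: new ChannelHandler[] {decoderDelegate, serverH} + * checkUntil(() -> assertTrue(serverH.getLastMsg() instanceof ReadHandshakeRequest)); + * }</pre>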
*/ +public class DummyChannelInboundHandlerAdaptor extends ChannelInboundHandlerAdapter { + + private volatile ChannelHandlerContext currentCtx; + + private final List<Object> messages = Collections.synchronizedList(new ArrayList<>()); + + private volatile Throwable cause; + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + this.cause = cause; + } + + @Override + public void channelActive(ChannelHandlerContext ctx) { + currentCtx = ctx; + } + + @Override + public synchronized void channelRead(ChannelHandlerContext ctx, Object msg) { + messages.add(msg); + if (msg instanceof CloseConnection) { + ctx.channel().close(); + } + } + + public synchronized Object getMsg(int index) { + if (index >= messages.size()) { + return null; + } + return messages.get(index); + } + + public synchronized Object getLastMsg() { + if (messages.isEmpty()) { + return null; + } + return messages.get(messages.size() - 1); + } + + public synchronized List<Object> getMessages() { + return messages; + } + + public synchronized int numMessages() { + return messages.size(); + } + + public synchronized boolean isEmpty() { + return messages.isEmpty(); + } + + public void send(Object obj) { + currentCtx.writeAndFlush(obj); + } + + public boolean isConnected() { + return currentCtx != null; + } + + public void close() { + messages.forEach( + msg -> { + if (msg instanceof WriteData) { + ((WriteData) msg).getBuffer().release(); + } + if (msg instanceof ReadData) { + ((ReadData) msg).getBuffer().release(); + } + }); + currentCtx.close().awaitUninterruptibly(); + } +} diff --git a/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/EncodingDecodingTest.java b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/EncodingDecodingTest.java new file mode 100644 index 00000000..512fdab2 --- /dev/null +++ b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/EncodingDecodingTest.java @@ -0,0 +1,628 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package com.alibaba.flink.shuffle.transfer; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.ids.ChannelID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.transfer.TransferMessage.CloseChannel; +import com.alibaba.flink.shuffle.transfer.TransferMessage.CloseConnection; +import com.alibaba.flink.shuffle.transfer.TransferMessage.ErrorResponse; +import com.alibaba.flink.shuffle.transfer.TransferMessage.Heartbeat; +import com.alibaba.flink.shuffle.transfer.TransferMessage.ReadAddCredit; +import com.alibaba.flink.shuffle.transfer.TransferMessage.ReadData; +import com.alibaba.flink.shuffle.transfer.TransferMessage.ReadHandshakeRequest; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteAddCredit; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteData; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteFinish; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteFinishCommit; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteHandshakeRequest; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteRegionFinish; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteRegionStart; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBufAllocator; +import org.apache.flink.shaded.netty4.io.netty.buffer.PooledByteBufAllocator; +import org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel; +import org.apache.flink.shaded.netty4.io.netty.util.ReferenceCounted; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.currentProtocolVersion; +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyExtraMessage; +import static com.alibaba.flink.shuffle.common.utils.StringUtils.bytesToString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** Test for {@link TransferMessage} encoding and decoding. 
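+ * + * <p>Each case round-trips messages through an {@code EmbeddedChannel}: encode them to + * ByteBufs, re-slice ("chop") the resulting byte stream at an arbitrary chunk size to mimic + * TCP fragmentation, feed the chunks back in, and verify the decoded messages. Condensed + * sketch (names as defined by the helpers below): + * + * <pre>{@code + * List<ByteBuf> encoded = encode(channel, messages); + * chop(encoded, targetChopSize).forEach(channel::writeInbound); + * TransferMessage received = channel.readInbound(); + * }</pre>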
*/ +public class EncodingDecodingTest { + + private static final ByteBufAllocator ALLOCATOR = new PooledByteBufAllocator(); + + private static final Random random = new Random(); + + private final CreditListener creditListener = new TestCreditListener(); + + @Test + public void testReadData() { + testReadData(100, 128, 128); + testReadData(100, 128, 100); + testReadData(100, 128, 200); + testReadData(1000, 128, 1024); + testReadData(1000, 128, 1000); + testReadData(1000, 1024, 128); + testReadData(1000, 1024, 4096); + testReadData(1000, 1024, 4000); + } + + @Test + public void testWriteData() { + testWriteData(100, 128, 128); + testWriteData(100, 128, 100); + testWriteData(100, 128, 200); + testWriteData(1000, 128, 1024); + testWriteData(1000, 128, 1000); + testWriteData(1000, 1024, 128); + testWriteData(1000, 1024, 4096); + testWriteData(1000, 1024, 4000); + } + + @Test + public void testErrorResponse() { + int version = currentProtocolVersion(); + ChannelID channelID = new ChannelID(); + String errMsg = "My Exception."; + String extraInfo = bytesToString(CommonUtils.randomBytes(32)); + Supplier messageBuilder = + () -> new ErrorResponse(version, channelID, errMsg.getBytes(), extraInfo); + Consumer messageVerifier = + msg -> { + ErrorResponse errRsp = (ErrorResponse) msg; + assertEquals(version, errRsp.getVersion()); + assertEquals(channelID, errRsp.getChannelID()); + assertEquals(errMsg, new String(errRsp.getErrorMessageBytes())); + assertEquals(extraInfo, errRsp.getExtraInfo()); + }; + testCommonMessage(messageBuilder, messageVerifier, 100, 10); + testCommonMessage(messageBuilder, messageVerifier, 100, 16); + testCommonMessage(messageBuilder, messageVerifier, 100, 20); + testCommonMessage(messageBuilder, messageVerifier, 100, 40); + testCommonMessage(messageBuilder, messageVerifier, 100, 100); + testCommonMessage(messageBuilder, messageVerifier, 100, 1024); + } + + @Test + public void testWriteHandshakeRequest() { + int version = currentProtocolVersion(); + ChannelID channelID = new ChannelID(); + JobID jobID = new JobID(CommonUtils.randomBytes(32)); + DataSetID dataSetID = new DataSetID(CommonUtils.randomBytes(32)); + MapPartitionID mapID = new MapPartitionID(CommonUtils.randomBytes(16)); + int numSubs = 1234; + int bufferSize = random.nextInt(); + String dataPartitionTypeFactory = bytesToString(CommonUtils.randomBytes(32)); + String extraInfo = bytesToString(CommonUtils.randomBytes(32)); + Supplier messageBuilder = + () -> + new WriteHandshakeRequest( + version, + channelID, + jobID, + dataSetID, + mapID, + numSubs, + bufferSize, + dataPartitionTypeFactory, + extraInfo); + Consumer messageVerifier = + msg -> { + WriteHandshakeRequest tmp = (WriteHandshakeRequest) msg; + assertEquals(version, tmp.getVersion()); + assertEquals(channelID, tmp.getChannelID()); + assertEquals(jobID, tmp.getJobID()); + assertEquals(dataSetID, tmp.getDataSetID()); + assertEquals(mapID, tmp.getMapID()); + assertEquals(numSubs, tmp.getNumSubs()); + assertEquals(bufferSize, tmp.getBufferSize()); + assertEquals(dataPartitionTypeFactory, tmp.getDataPartitionType()); + assertEquals(extraInfo, tmp.getExtraInfo()); + }; + testCommonMessage(messageBuilder, messageVerifier, 100, 20); + testCommonMessage(messageBuilder, messageVerifier, 100, 32); + testCommonMessage(messageBuilder, messageVerifier, 100, 96); + testCommonMessage(messageBuilder, messageVerifier, 100, 100); + testCommonMessage(messageBuilder, messageVerifier, 100, 200); + testCommonMessage(messageBuilder, messageVerifier, 100, 1024); + } + + @Test + 
public void testWriteAddCredit() { + int version = currentProtocolVersion(); + ChannelID channelID = new ChannelID(); + int credit = 123; + int regionIdx = 789; + String extraInfo = bytesToString(CommonUtils.randomBytes(32)); + Supplier messageBuilder = + () -> new WriteAddCredit(version, channelID, credit, regionIdx, extraInfo); + Consumer messageVerifier = + msg -> { + WriteAddCredit tmp = (WriteAddCredit) msg; + assertEquals(version, tmp.getVersion()); + assertEquals(channelID, tmp.getChannelID()); + assertEquals(credit, tmp.getCredit()); + assertEquals(regionIdx, tmp.getRegionIdx()); + assertEquals(extraInfo, tmp.getExtraInfo()); + }; + testCommonMessage(messageBuilder, messageVerifier, 100, 10); + testCommonMessage(messageBuilder, messageVerifier, 100, 16); + testCommonMessage(messageBuilder, messageVerifier, 100, 20); + testCommonMessage(messageBuilder, messageVerifier, 100, 40); + testCommonMessage(messageBuilder, messageVerifier, 100, 100); + testCommonMessage(messageBuilder, messageVerifier, 100, 1024); + } + + @Test + public void testWriteRegionStart() { + int version = currentProtocolVersion(); + ChannelID channelID = new ChannelID(); + int regionIdx = 123; + boolean isBroadcast = false; + String extraInfo = bytesToString(CommonUtils.randomBytes(32)); + Supplier messageBuilder = + () -> new WriteRegionStart(version, channelID, regionIdx, isBroadcast, extraInfo); + Consumer messageVerifier = + msg -> { + WriteRegionStart tmp = (WriteRegionStart) msg; + assertEquals(version, tmp.getVersion()); + assertEquals(channelID, tmp.getChannelID()); + assertEquals(regionIdx, tmp.getRegionIdx()); + assertEquals(isBroadcast, tmp.isBroadcast()); + assertEquals(extraInfo, tmp.getExtraInfo()); + }; + testCommonMessage(messageBuilder, messageVerifier, 100, 10); + testCommonMessage(messageBuilder, messageVerifier, 100, 16); + testCommonMessage(messageBuilder, messageVerifier, 100, 20); + testCommonMessage(messageBuilder, messageVerifier, 100, 40); + testCommonMessage(messageBuilder, messageVerifier, 100, 100); + testCommonMessage(messageBuilder, messageVerifier, 100, 1024); + } + + @Test + public void testWriteRegionFinish() { + int version = currentProtocolVersion(); + ChannelID channelID = new ChannelID(); + String extraInfo = bytesToString(CommonUtils.randomBytes(32)); + Supplier messageBuilder = + () -> new WriteRegionFinish(version, channelID, extraInfo); + Consumer messageVerifier = + msg -> { + WriteRegionFinish tmp = (WriteRegionFinish) msg; + assertEquals(version, tmp.getVersion()); + assertEquals(channelID, tmp.getChannelID()); + assertEquals(extraInfo, tmp.getExtraInfo()); + }; + testCommonMessage(messageBuilder, messageVerifier, 100, 10); + testCommonMessage(messageBuilder, messageVerifier, 100, 16); + testCommonMessage(messageBuilder, messageVerifier, 100, 20); + testCommonMessage(messageBuilder, messageVerifier, 100, 32); + testCommonMessage(messageBuilder, messageVerifier, 100, 100); + testCommonMessage(messageBuilder, messageVerifier, 100, 1024); + } + + @Test + public void testWriteFinish() { + int version = currentProtocolVersion(); + ChannelID channelID = new ChannelID(); + String extraInfo = bytesToString(CommonUtils.randomBytes(32)); + Supplier messageBuilder = () -> new WriteFinish(version, channelID, extraInfo); + Consumer messageVerifier = + msg -> { + WriteFinish tmp = (WriteFinish) msg; + assertEquals(version, tmp.getVersion()); + assertEquals(channelID, tmp.getChannelID()); + assertEquals(extraInfo, tmp.getExtraInfo()); + }; + testCommonMessage(messageBuilder, 
messageVerifier, 100, 10); + testCommonMessage(messageBuilder, messageVerifier, 100, 16); + testCommonMessage(messageBuilder, messageVerifier, 100, 20); + testCommonMessage(messageBuilder, messageVerifier, 100, 32); + testCommonMessage(messageBuilder, messageVerifier, 100, 100); + testCommonMessage(messageBuilder, messageVerifier, 100, 1024); + } + + @Test + public void testWriteFinishCommit() { + int version = currentProtocolVersion(); + ChannelID channelID = new ChannelID(); + String extraInfo = bytesToString(CommonUtils.randomBytes(32)); + Supplier messageBuilder = + () -> new WriteFinishCommit(version, channelID, extraInfo); + Consumer messageVerifier = + msg -> { + WriteFinishCommit tmp = (WriteFinishCommit) msg; + assertEquals(version, tmp.getVersion()); + assertEquals(channelID, tmp.getChannelID()); + assertEquals(extraInfo, tmp.getExtraInfo()); + }; + testCommonMessage(messageBuilder, messageVerifier, 100, 10); + testCommonMessage(messageBuilder, messageVerifier, 100, 16); + testCommonMessage(messageBuilder, messageVerifier, 100, 20); + testCommonMessage(messageBuilder, messageVerifier, 100, 32); + testCommonMessage(messageBuilder, messageVerifier, 100, 100); + testCommonMessage(messageBuilder, messageVerifier, 100, 1024); + } + + @Test + public void testReadHandshakeRequest() { + int version = currentProtocolVersion(); + ChannelID channelID = new ChannelID(); + DataSetID dataSetID = new DataSetID(CommonUtils.randomBytes(32)); + MapPartitionID mapID = new MapPartitionID(CommonUtils.randomBytes(16)); + int startSubIdx = 123; + int endSubIdx = 456; + int initialCredit = 789; + int bufferSize = random.nextInt(); + int offset = random.nextInt(); + String extraInfo = bytesToString(CommonUtils.randomBytes(32)); + Supplier messageBuilder = + () -> + new ReadHandshakeRequest( + version, + channelID, + dataSetID, + mapID, + startSubIdx, + endSubIdx, + initialCredit, + bufferSize, + offset, + extraInfo); + Consumer messageVerifier = + msg -> { + ReadHandshakeRequest tmp = (ReadHandshakeRequest) msg; + assertEquals(version, tmp.getVersion()); + assertEquals(channelID, tmp.getChannelID()); + assertEquals(dataSetID, tmp.getDataSetID()); + assertEquals(mapID, tmp.getMapID()); + assertEquals(startSubIdx, tmp.getStartSubIdx()); + assertEquals(endSubIdx, tmp.getEndSubIdx()); + assertEquals(initialCredit, tmp.getInitialCredit()); + assertEquals(bufferSize, tmp.getBufferSize()); + assertEquals(offset, tmp.getOffset()); + assertEquals(extraInfo, tmp.getExtraInfo()); + }; + testCommonMessage(messageBuilder, messageVerifier, 100, 20); + testCommonMessage(messageBuilder, messageVerifier, 100, 72); + testCommonMessage(messageBuilder, messageVerifier, 100, 144); + testCommonMessage(messageBuilder, messageVerifier, 100, 200); + testCommonMessage(messageBuilder, messageVerifier, 100, 1024); + } + + @Test + public void testReadAddCredit() { + int version = currentProtocolVersion(); + ChannelID channelID = new ChannelID(); + int credit = 123; + String extraInfo = bytesToString(CommonUtils.randomBytes(32)); + Supplier messageBuilder = + () -> new ReadAddCredit(version, channelID, credit, extraInfo); + Consumer messageVerifier = + msg -> { + ReadAddCredit tmp = (ReadAddCredit) msg; + assertEquals(version, tmp.getVersion()); + assertEquals(channelID, tmp.getChannelID()); + assertEquals(credit, tmp.getCredit()); + assertEquals(extraInfo, tmp.getExtraInfo()); + }; + testCommonMessage(messageBuilder, messageVerifier, 100, 10); + testCommonMessage(messageBuilder, messageVerifier, 100, 16); + 
testCommonMessage(messageBuilder, messageVerifier, 100, 20); + testCommonMessage(messageBuilder, messageVerifier, 100, 40); + testCommonMessage(messageBuilder, messageVerifier, 100, 100); + testCommonMessage(messageBuilder, messageVerifier, 100, 1024); + } + + @Test + public void testCloseChannel() { + int version = currentProtocolVersion(); + ChannelID channelID = new ChannelID(); + String extraInfo = bytesToString(CommonUtils.randomBytes(32)); + Supplier messageBuilder = + () -> new CloseChannel(version, channelID, extraInfo); + Consumer messageVerifier = + msg -> { + CloseChannel tmp = (CloseChannel) msg; + assertEquals(version, tmp.getVersion()); + assertEquals(channelID, tmp.getChannelID()); + assertEquals(extraInfo, tmp.getExtraInfo()); + }; + testCommonMessage(messageBuilder, messageVerifier, 100, 10); + testCommonMessage(messageBuilder, messageVerifier, 100, 16); + testCommonMessage(messageBuilder, messageVerifier, 100, 20); + testCommonMessage(messageBuilder, messageVerifier, 100, 32); + testCommonMessage(messageBuilder, messageVerifier, 100, 100); + testCommonMessage(messageBuilder, messageVerifier, 100, 1024); + } + + @Test + public void testCloseConnection() { + int version = currentProtocolVersion(); + String extraInfo = emptyExtraMessage(); + Supplier messageBuilder = CloseConnection::new; + Consumer messageVerifier = + msg -> { + assertTrue(msg instanceof CloseConnection); + CloseConnection tmp = (CloseConnection) msg; + assertEquals(version, tmp.getVersion()); + assertEquals(extraInfo, tmp.getExtraInfo()); + }; + + testCommonMessage(messageBuilder, messageVerifier, 100, 1); + testCommonMessage(messageBuilder, messageVerifier, 100, 2); + testCommonMessage(messageBuilder, messageVerifier, 100, 3); + testCommonMessage(messageBuilder, messageVerifier, 100, 4); + testCommonMessage(messageBuilder, messageVerifier, 100, 8); + testCommonMessage(messageBuilder, messageVerifier, 100, 16); + testCommonMessage(messageBuilder, messageVerifier, 100, 100); + testCommonMessage(messageBuilder, messageVerifier, 100, 1024); + } + + @Test + public void testHeartbeat() { + int version = currentProtocolVersion(); + String extraInfo = emptyExtraMessage(); + Supplier messageBuilder = Heartbeat::new; + Consumer messageVerifier = + msg -> { + assertTrue(msg instanceof Heartbeat); + Heartbeat tmp = (Heartbeat) msg; + assertEquals(version, tmp.getVersion()); + assertEquals(extraInfo, tmp.getExtraInfo()); + }; + + testCommonMessage(messageBuilder, messageVerifier, 100, 1); + testCommonMessage(messageBuilder, messageVerifier, 100, 2); + testCommonMessage(messageBuilder, messageVerifier, 100, 3); + testCommonMessage(messageBuilder, messageVerifier, 100, 4); + testCommonMessage(messageBuilder, messageVerifier, 100, 8); + testCommonMessage(messageBuilder, messageVerifier, 100, 16); + testCommonMessage(messageBuilder, messageVerifier, 100, 100); + testCommonMessage(messageBuilder, messageVerifier, 100, 1024); + } + + private void testCommonMessage( + Supplier messageBuilder, + Consumer messageVerifier, + int numMessages, + int targetChopSize) { + List messages = new ArrayList<>(); + for (int i = 0; i < numMessages; i++) { + messages.add(messageBuilder.get()); + } + EmbeddedChannel channel = + new EmbeddedChannel(new TransferMessageEncoder(), getDecoderDelegate(null)); + List encoded = encode(channel, messages); + List chopped = chop(encoded, targetChopSize); + chopped.forEach(channel::writeInbound); + List receivedList = new ArrayList<>(); + TransferMessage received = null; + while ((received = 
channel.readInbound()) != null) { + receivedList.add(received); + } + assertEquals(numMessages, receivedList.size()); + receivedList.forEach(messageVerifier); + } + + private void testReadData(int numReadDatas, int bufferSize, int targetChopSize) { + int version = currentProtocolVersion(); + ChannelID channelID = new ChannelID(); + int offset = random.nextInt(); + String extraInfo = bytesToString(CommonUtils.randomBytes(32)); + int backlog = 789; + + TransferBufferPool transferBufferPool = + new TestTransferBufferPool(numReadDatas * 2, bufferSize); + EmbeddedChannel channel = + new EmbeddedChannel( + new TransferMessageEncoder(), getDecoderDelegate(transferBufferPool)); + + List<ByteBuf> buffers = constructBuffers(transferBufferPool, numReadDatas); + int expectSize = getTotalSize(buffers); + List<ReadData> readDatas = + buffers.stream() + .map( + byteBuf -> + new ReadData( + version, + channelID, + backlog, + byteBuf.readableBytes(), + offset, + byteBuf, + extraInfo)) + .collect(Collectors.toList()); + List<ByteBuf> encoded = encode(channel, readDatas); + + List<ByteBuf> chopped = chop(encoded, targetChopSize); + chopped.forEach(channel::writeInbound); + List<ReadData> receivedList = new ArrayList<>(); + ReadData received = null; + while ((received = channel.readInbound()) != null) { + receivedList.add(received); + } + assertEquals(numReadDatas, receivedList.size()); + receivedList.forEach( + readData -> { + assertEquals(version, readData.getVersion()); + assertEquals(channelID, readData.getChannelID()); + assertEquals(offset, readData.getOffset()); + assertEquals(backlog, readData.getBacklog()); + assertEquals(extraInfo, readData.getExtraInfo()); + }); + List<ByteBuf> byteBufs = + receivedList.stream().map(ReadData::getBuffer).collect(Collectors.toList()); + verifyBuffers(byteBufs, expectSize); + } + + private void testWriteData(int numWriteDatas, int bufferSize, int targetChopSize) { + int version = currentProtocolVersion(); + ChannelID channelID = new ChannelID(); + String extraInfo = bytesToString(CommonUtils.randomBytes(32)); + int subIdx = 789; + + TransferBufferPool transferBufferPool = + new TestTransferBufferPool(numWriteDatas * 2, bufferSize); + EmbeddedChannel channel = + new EmbeddedChannel( + new TransferMessageEncoder(), getDecoderDelegate(transferBufferPool)); + + List<ByteBuf> buffers = constructBuffers(transferBufferPool, numWriteDatas); + int expectSize = getTotalSize(buffers); + List<WriteData> writeDatas = + buffers.stream() + .map( + b -> + new WriteData( + version, + channelID, + b, + subIdx, + b.readableBytes(), + false, + extraInfo)) + .collect(Collectors.toList()); + List<ByteBuf> encoded = encode(channel, writeDatas); + + List<ByteBuf> chopped = chop(encoded, targetChopSize); + chopped.forEach(channel::writeInbound); + List<WriteData> receivedList = new ArrayList<>(); + WriteData received = null; + while ((received = channel.readInbound()) != null) { + receivedList.add(received); + } + assertEquals(numWriteDatas, receivedList.size()); + receivedList.forEach( + writeData -> { + assertEquals(version, writeData.getVersion()); + assertEquals(channelID, writeData.getChannelID()); + assertEquals(subIdx, writeData.getSubIdx()); + assertEquals(extraInfo, writeData.getExtraInfo()); + }); + List<ByteBuf> byteBufs = + receivedList.stream().map(WriteData::getBuffer).collect(Collectors.toList()); + verifyBuffers(byteBufs, expectSize); + } + + private List<ByteBuf> constructBuffers(TransferBufferPool bufferPool, int numBuffers) { + List<ByteBuf> res = new ArrayList<>(); + long a = 0; + for (int i = 0; i < numBuffers; i++) { + ByteBuf buffer = bufferPool.requestBuffer(); + while (buffer.capacity() - buffer.writerIndex() > 8) { + buffer.writeLong(a++); + } + res.add(buffer); + } + return res; + } + + private int getTotalSize(List<ByteBuf> buffers) { + return buffers.stream().map(ByteBuf::readableBytes).reduce(0, Integer::sum); + } + + private List<ByteBuf> encode( + EmbeddedChannel channel, List<? extends TransferMessage> messages) { + List<ByteBuf> ret = new ArrayList<>(); + for (Object msg : messages) { + channel.writeOutbound(msg); + ByteBuf encoded; + while ((encoded = channel.readOutbound()) != null) { + ret.add(encoded); + } + } + return ret; + } + + private List<ByteBuf> chop(List<ByteBuf> buffers, int targetSize) { + List<ByteBuf> ret = new ArrayList<>(); + ByteBuf tmp = ALLOCATOR.directBuffer(targetSize); + for (ByteBuf buffer : buffers) { + while (buffer.readableBytes() != 0) { + int numBytes = Math.min(tmp.writableBytes(), buffer.readableBytes()); + tmp.writeBytes(buffer, numBytes); + if (tmp.writableBytes() == 0) { + ret.add(tmp); + tmp = ALLOCATOR.directBuffer(targetSize); + } + } + buffer.release(); + } + if (tmp.readableBytes() == 0) { + tmp.release(); + } else { + ret.add(tmp); + } + return ret; + } + + private void verifyBuffers(List<ByteBuf> buffers, int expectedSize) { + assertEquals(expectedSize, getTotalSize(buffers)); + int a = 0; + ByteBuf tmp = ALLOCATOR.directBuffer(1024); + for (ByteBuf buffer : buffers) { + while (buffer.readableBytes() != 0) { + int numBytes = Math.min(tmp.writableBytes(), buffer.readableBytes()); + tmp.writeBytes(buffer, numBytes); + if (tmp.writableBytes() == 0) { + while (tmp.readableBytes() > 0) { + assertEquals(a++, tmp.readLong()); + } + tmp.clear(); + } + } + } + tmp.release(); + buffers.forEach(ReferenceCounted::release); + } + + private DecoderDelegate getDecoderDelegate(TransferBufferPool bufferPool) { + Function<Byte, TransferMessageDecoder> messageDecoder = + msgID -> { + switch (msgID) { + case ReadData.ID: + return new ShuffleReadDataDecoder(ignore -> bufferPool::requestBuffer); + case WriteData.ID: + return new ShuffleWriteDataDecoder(ignore -> bufferPool::requestBuffer); + default: + return new CommonTransferMessageDecoder(); + } + }; + return new DecoderDelegate(messageDecoder); + } +} diff --git a/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/FakedDataPartitionReadingView.java b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/FakedDataPartitionReadingView.java new file mode 100644 index 00000000..c34c388d --- /dev/null +++ b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/FakedDataPartitionReadingView.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package com.alibaba.flink.shuffle.transfer; + +import com.alibaba.flink.shuffle.core.listener.BacklogListener; +import com.alibaba.flink.shuffle.core.listener.DataListener; +import com.alibaba.flink.shuffle.core.listener.FailureListener; +import com.alibaba.flink.shuffle.core.memory.Buffer; +import com.alibaba.flink.shuffle.core.storage.BufferWithBacklog; +import com.alibaba.flink.shuffle.core.storage.DataPartitionReadingView; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; + +import java.util.ArrayDeque; +import java.util.List; +import java.util.Queue; + +/** A {@link DataPartitionReadingView} for test. */ +public class FakedDataPartitionReadingView implements DataPartitionReadingView { + + private Queue<ByteBuf> buffersToSend = new ArrayDeque<>(); + private final DataListener dataListener; + private final BacklogListener backlogListener; + private final FailureListener failureListener; + private Throwable cause; + private boolean noMoreData; + + public FakedDataPartitionReadingView( + DataListener dataListener, + BacklogListener backlogListener, + FailureListener failureListener) { + this.dataListener = dataListener; + this.backlogListener = backlogListener; + this.failureListener = failureListener; + } + + public synchronized void notifyBuffers(List<ByteBuf> buffers) { + int numPrevResidual = buffersToSend.size(); + buffersToSend.addAll(buffers); + if (numPrevResidual == 0) { + dataListener.notifyDataAvailable(); + } + } + + public synchronized void notifyBuffer(ByteBuf buffer) { + if (buffersToSend.isEmpty()) { + buffersToSend.add(buffer); + dataListener.notifyDataAvailable(); + } else { + buffersToSend.add(buffer); + } + } + + public void notifyBacklog(int backlog) { + backlogListener.notifyBacklog(backlog); + } + + @Override + public synchronized BufferWithBacklog nextBuffer() { + ByteBuf polled = buffersToSend.poll(); + if (polled == null) { + return null; + } + return new BufferWithBacklog((Buffer) polled, buffersToSend.size()); + } + + @Override + public synchronized void onError(Throwable throwable) { + while (!buffersToSend.isEmpty()) { + buffersToSend.poll().release(); + } + cause = throwable; + } + + public void setNoMoreData(boolean value) { + noMoreData = value; + } + + @Override + public synchronized boolean isFinished() { + return noMoreData && buffersToSend.isEmpty(); + } + + public Throwable getError() { + return cause; + } + + public void triggerFailure(Throwable t) { + failureListener.notifyFailure(t); + } + + public void setBuffersToSend(Queue<ByteBuf> buffersToSend) { + this.buffersToSend = buffersToSend; + } + + public DataListener getDataAvailableListener() { + return dataListener; + } +} diff --git a/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/FakedDataPartitionWritingView.java b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/FakedDataPartitionWritingView.java new file mode 100644 index 00000000..86e1a2d5 --- /dev/null +++ b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/FakedDataPartitionWritingView.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.ids.ReducePartitionID; +import com.alibaba.flink.shuffle.core.listener.DataCommitListener; +import com.alibaba.flink.shuffle.core.listener.DataRegionCreditListener; +import com.alibaba.flink.shuffle.core.listener.FailureListener; +import com.alibaba.flink.shuffle.core.memory.Buffer; +import com.alibaba.flink.shuffle.core.memory.BufferSupplier; +import com.alibaba.flink.shuffle.core.storage.DataPartitionWritingView; + +import java.util.List; +import java.util.function.Consumer; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkState; + +/** A {@link DataPartitionWritingView} for test. */ +public class FakedDataPartitionWritingView implements DataPartitionWritingView { + + private final List<Buffer> buffers; + private volatile int regionStartCount; + private volatile int regionFinishCount; + private volatile boolean isFinished; + private final Consumer<Buffer> dataHandler; + private final DataRegionCreditListener dataRegionCreditListener; + private final FailureListener failureListener; + private Throwable cause; + + public FakedDataPartitionWritingView( + DataSetID dataSetID, + MapPartitionID mapPartitionID, + Consumer<Buffer> dataHandler, + DataRegionCreditListener dataRegionCreditListener, + FailureListener failureListener, + List<Buffer> buffers) { + + this.dataHandler = dataHandler; + this.dataRegionCreditListener = dataRegionCreditListener; + this.failureListener = failureListener; + this.regionStartCount = 0; + this.regionFinishCount = 0; + this.isFinished = false; + this.cause = null; + this.buffers = buffers; + } + + @Override + public void onBuffer(Buffer buffer, ReducePartitionID reducePartitionID) { + dataHandler.accept(buffer); + buffer.clear(); + buffers.add(buffer); + dataRegionCreditListener.notifyCredits(1, regionFinishCount); + } + + @Override + public void regionStarted(int dataRegionIndex, boolean isBroadcastRegion) { + regionStartCount++; + dataRegionCreditListener.notifyCredits(1, regionFinishCount); + } + + @Override + public void regionFinished() { + regionFinishCount++; + } + + @Override + public void finish(DataCommitListener listener) { + listener.notifyDataCommitted(); + release(); + isFinished = true; + } + + private void release() { + while (!buffers.isEmpty()) { + buffers.remove(0).release(); + } + } + + @Override + public void onError(Throwable throwable) { + release(); + cause = throwable; + } + + @Override + public BufferSupplier getBufferSupplier() { + return () -> { + checkState(!buffers.isEmpty(), "No buffers available."); + return buffers.remove(0); + }; + } + + public List<Buffer> getCandidateBuffers() { + return buffers; + } + + public int getRegionStartCount() { + return regionStartCount; + } + + public int getRegionFinishCount() { + return regionFinishCount; + } + + public boolean isFinished() { + return isFinished; + } + + public void triggerFailure(Throwable t) { + failureListener.notifyFailure(t); + } + + public Throwable getError() { + return
cause; + } +} diff --git a/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/ShuffleReadClientTest.java b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/ShuffleReadClientTest.java new file mode 100644 index 00000000..ac41db86 --- /dev/null +++ b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/ShuffleReadClientTest.java @@ -0,0 +1,398 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.config.TransferOptions; +import com.alibaba.flink.shuffle.core.ids.ChannelID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.transfer.TransferMessage.CloseChannel; +import com.alibaba.flink.shuffle.transfer.TransferMessage.CloseConnection; +import com.alibaba.flink.shuffle.transfer.TransferMessage.ErrorResponse; +import com.alibaba.flink.shuffle.transfer.TransferMessage.ReadAddCredit; +import com.alibaba.flink.shuffle.transfer.TransferMessage.ReadData; +import com.alibaba.flink.shuffle.transfer.TransferMessage.ReadHandshakeRequest; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler; +import org.apache.flink.shaded.netty4.io.netty.util.ReferenceCounted; + +import org.apache.commons.lang3.tuple.Pair; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.currentProtocolVersion; +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyBufferSize; +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyExtraMessage; +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyOffset; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** Test for {@link ShuffleReadClient}. 
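+ * + * <p>The server side is emulated by a {@code DummyChannelInboundHandlerAdaptor}: the client + * announces credit with ReadAddCredit after {@code backlogReceived}, and the test pushes at + * most that many ReadData frames back. A condensed sketch of the flow exercised below (the + * credit granted is capped by the client's buffer pool size): + * + * <pre>{@code + * client.connect(); + * client.open();                 // sends ReadHandshakeRequest + * client.backlogReceived(25);    // triggers ReadAddCredit, here capped at 20 buffers + * }</pre>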
*/ +public class ShuffleReadClientTest extends AbstractNettyTest { + + private static final int NUM_BUFFERS = 20; + + private static final int BUFFER_SIZE = 64; + + private NettyServer nettyServer; + + private ShuffleReadClient client; + + private ConnectionManager connManager; + + private TransferBufferPool clientBufferPool; + + private DummyChannelInboundHandlerAdaptor serverHandler; + + private List<ByteBuf> readDatas; + + private final AtomicReference<Throwable> cause = new AtomicReference<>(); + + @Override + @Before + public void setup() throws Exception { + super.setup(); + int dataPort = getAvailablePort(); + Pair<NettyServer, DummyChannelInboundHandlerAdaptor> pair = initShuffleServer(dataPort); + nettyServer = pair.getLeft(); + serverHandler = pair.getRight(); + + MapPartitionID mapID = new MapPartitionID(CommonUtils.randomBytes(16)); + readDatas = new ArrayList<>(); + clientBufferPool = new TestTransferBufferPool(NUM_BUFFERS, BUFFER_SIZE); + connManager = ConnectionManager.createReadConnectionManager(nettyConfig, false); + connManager.start(); + address = new InetSocketAddress(InetAddress.getLocalHost(), dataPort); + client = + new ShuffleReadClient( + address, + dataSetID, + mapID, + 0, + 0, + emptyBufferSize(), + clientBufferPool, + connManager, + readDatas::add, + cause::set); + } + + @Override + @After + public void tearDown() throws Exception { + super.tearDown(); + nettyServer.shutdown(); + connManager.shutdown(); + assertEquals(20, clientBufferPool.numBuffers()); + clientBufferPool.destroy(); + } + + /** Basic routine. */ + @Test + public void testReadDataAndSendCredit() throws Exception { + assertTrue(serverHandler.isEmpty()); + + // Client sends ReadHandshakeRequest. + client.connect(); + client.open(); + checkUntil(() -> assertEquals(1, serverHandler.numMessages())); + assertTrue(serverHandler.getLastMsg() instanceof ReadHandshakeRequest); + assertEquals(0, ((ReadHandshakeRequest) serverHandler.getLastMsg()).getInitialCredit()); + + // Construct 25 buffers for sending. + Queue<ByteBuf> serverBuffers = constructBuffers(25, 3); + client.backlogReceived(25); + checkUntil(() -> assertEquals(2, serverHandler.numMessages())); + assertEquals(20, ((ReadAddCredit) serverHandler.getMsg(1)).getCredit()); + + // Server sends ReadData.
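+ // With a pool of NUM_BUFFERS (20) buffers and a backlog of 25, only 20 credits can be + // announced up front; the remaining 5 are granted below in 2 + 2 + 1 increments as + // received buffers are released back to the pool.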
+ for (int i = 0; i < 20; i++) { + ByteBuf buffer = serverBuffers.poll(); + serverHandler.send( + new ReadData( + currentProtocolVersion(), + client.getChannelID(), + 1, + buffer.readableBytes(), + emptyOffset(), + buffer, + emptyExtraMessage())); + } + checkUntil(() -> assertEquals(20, readDatas.size())); + readDatas.forEach(ReferenceCounted::release); + readDatas.clear(); + + checkUntil(() -> assertEquals(5, serverHandler.numMessages())); + assertEquals(2, ((ReadAddCredit) serverHandler.getMsg(2)).getCredit()); + assertEquals(2, ((ReadAddCredit) serverHandler.getMsg(3)).getCredit()); + assertEquals(1, ((ReadAddCredit) serverHandler.getMsg(4)).getCredit()); + + for (int i = 0; i < 5; i++) { + ByteBuf buffer = serverBuffers.poll(); + serverHandler.send( + new ReadData( + currentProtocolVersion(), + client.getChannelID(), + 1, + buffer.readableBytes(), + emptyOffset(), + buffer, + emptyExtraMessage())); + } + checkUntil(() -> assertEquals(5, readDatas.size())); + while (!readDatas.isEmpty()) { + readDatas.remove(0).release(); + } + delayCheck(() -> assertEquals(5, serverHandler.numMessages())); + + client.close(); + checkUntil(() -> assertEquals(7, serverHandler.numMessages())); + assertTrue(serverHandler.getMsg(5) instanceof CloseChannel); + assertTrue(serverHandler.getMsg(6) instanceof CloseConnection); + } + + /** Release all buffers on close. */ + @Test + public void testReleaseAllBuffersWhenClose() throws Exception { + int prevAvailableBuffers = transferBufferPool.numBuffers(); + + // Client sends ReadHandshakeRequest. + client.connect(); + client.open(); + checkUntil(() -> assertTrue(serverHandler.isConnected())); + + // Close client. + client.close(); + assertEquals(prevAvailableBuffers, transferBufferPool.numBuffers()); + } + + /** Close connection from server. */ + @Test + public void testCloseConnectionFromServer() throws Exception { + // Client sends ReadHandshakeRequest. + client.connect(); + client.open(); + checkUntil(() -> assertTrue(serverHandler.isConnected())); + + serverHandler.close(); + checkUntil(() -> assertTrue(cause.get() instanceof IOException)); + client.close(); + } + + /** ErrorResponse from server. */ + @Test + public void testErrorResponseFromServer() throws Exception { + // Client sends ReadHandshakeRequest. + client.connect(); + client.open(); + checkUntil(() -> assertTrue(serverHandler.isConnected())); + + // Server sends ErrorResponse. + String errMsg = "Expected exception."; + serverHandler.send( + new ErrorResponse( + currentProtocolVersion(), + client.getChannelID(), + errMsg.getBytes(), + emptyExtraMessage())); + checkUntil( + () -> { + assertTrue(cause.get().getCause() instanceof IOException); + assertEquals(errMsg, ((IOException) cause.get().getCause()).getMessage()); + }); + client.close(); + } + + // Feed data after close. + @Test + public void testFeedDataAfterClose() throws Exception { + client.connect(); + client.open(); + checkUntil(() -> assertTrue(serverHandler.isConnected())); + + client.close(); + + ByteBuf byteBuf = constructBuffers(1, 3).poll(); + client.dataReceived( + new ReadData( + currentProtocolVersion(), + client.getChannelID(), + 0, + byteBuf.readableBytes(), + emptyOffset(), + byteBuf, + emptyExtraMessage())); + assertEquals(0, byteBuf.refCnt()); + + // Test close multiple times. + client.close(); + } + + // Avoid deadlock caused by competition between the ReadClient's upper-layer lock and the + // limited Netty thread resources.
+ @Test + public void testDeadlockAvoidance() throws Exception { + DataSetID dataSetID = new DataSetID(CommonUtils.randomBytes(32)); + Object lock = new Object(); + nettyConfig.getConfig().setInteger(TransferOptions.NUM_THREADS_CLIENT, 1); + ConnectionManager connManager = + ConnectionManager.createReadConnectionManager(nettyConfig, false); + connManager.start(); + + // Prepare client0 + MapPartitionID mapID0 = new MapPartitionID(CommonUtils.randomBytes(16)); + int dataPort0 = getAvailablePort(); + Pair<NettyServer, DummyChannelInboundHandlerAdaptor> pair0 = initShuffleServer(dataPort0); + NettyServer nettyServer0 = pair0.getLeft(); + DummyChannelInboundHandlerAdaptor serverH0 = pair0.getRight(); + InetSocketAddress address0 = new InetSocketAddress(InetAddress.getLocalHost(), dataPort0); + List<ByteBuf> readData0 = new ArrayList<>(); + Consumer<ByteBuf> dataListener = + byteBuf -> { + synchronized (lock) { + readData0.add(byteBuf); + } + }; + TransferBufferPool clientBufferPool0 = new TestTransferBufferPool(NUM_BUFFERS, BUFFER_SIZE); + ShuffleReadClient shuffleReadClient0 = + new ShuffleReadClient( + address0, + dataSetID, + mapID0, + 0, + 0, + emptyBufferSize(), + clientBufferPool0, + connManager, + dataListener, + ignore -> {}); + shuffleReadClient0.connect(); + shuffleReadClient0.open(); + checkUntil(() -> assertTrue(serverH0.isConnected())); + + // Prepare client1 + MapPartitionID mapID1 = new MapPartitionID(CommonUtils.randomBytes(16)); + int dataPort1 = getAvailablePort(); + Pair<NettyServer, DummyChannelInboundHandlerAdaptor> pair1 = initShuffleServer(dataPort1); + NettyServer nettyServer1 = pair1.getLeft(); + InetSocketAddress address1 = new InetSocketAddress(InetAddress.getLocalHost(), dataPort1); + TransferBufferPool clientBufferPool1 = new TestTransferBufferPool(NUM_BUFFERS, BUFFER_SIZE); + ShuffleReadClient shuffleReadClient1 = + new ShuffleReadClient( + address1, + dataSetID, + mapID1, + 0, + 0, + emptyBufferSize(), + clientBufferPool1, + connManager, + ignore -> {}, + ignore -> {}); + shuffleReadClient1.connect(); + + // Testing procedure.
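+ // While the test thread holds the lock, a ReadData message makes the single Netty client + // thread block inside dataListener above; shuffleReadClient1.open() must still complete + // on the test thread, which would deadlock if open() required that Netty thread.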
+ synchronized (lock) { + Thread t = + new Thread( + () -> { + ChannelID channelID0 = shuffleReadClient0.getChannelID(); + Queue<ByteBuf> serverBuffers = constructBuffers(1, 3); + ByteBuf byteBuf = serverBuffers.poll(); + serverH0.send( + new ReadData( + currentProtocolVersion(), + channelID0, + 0, + byteBuf.readableBytes(), + emptyOffset(), + byteBuf, + emptyExtraMessage())); + }); + t.start(); + while (getAllThreads(NettyConfig.CLIENT_THREAD_GROUP_NAME, Thread.State.BLOCKED) + .isEmpty()) { + Thread.sleep(1000); + } + shuffleReadClient1.open(); + } + + readData0.forEach(ReferenceCounted::release); + + shuffleReadClient0.close(); + shuffleReadClient1.close(); + nettyServer0.shutdown(); + nettyServer1.shutdown(); + clientBufferPool0.destroy(); + clientBufferPool1.destroy(); + } + + private Pair<NettyServer, DummyChannelInboundHandlerAdaptor> initShuffleServer(int dataPort) + throws Exception { + DummyChannelInboundHandlerAdaptor serverH = new DummyChannelInboundHandlerAdaptor(); + nettyConfig.getConfig().setInteger(TransferOptions.SERVER_DATA_PORT, dataPort); + NettyServer nettyServer = + new NettyServer(null, nettyConfig) { + @Override + public ChannelHandler[] getServerHandlers() { + return new ChannelHandler[] { + new TransferMessageEncoder(), + DecoderDelegate.serverDecoderDelegate(null), + serverH + }; + } + }; + nettyServer.start(); + return Pair.of(nettyServer, serverH); + } + + private List<Thread> getAllThreads(String namePrefix, Thread.State state) { + ThreadGroup group = Thread.currentThread().getThreadGroup(); + ThreadGroup topGroup = group; + while (group != null) { + topGroup = group; + group = group.getParent(); + } + int estimatedSize = topGroup.activeCount() * 2; + Thread[] slackList = new Thread[estimatedSize]; + int actualSize = topGroup.enumerate(slackList); + Thread[] list = new Thread[actualSize]; + System.arraycopy(slackList, 0, list, 0, actualSize); + + List<Thread> ret = new ArrayList<>(); + for (Thread t : list) { + if (t.getName().startsWith(namePrefix) && t.getState() == state) { + ret.add(t); + } + } + return ret; + } +} diff --git a/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/ShuffleServerTest.java b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/ShuffleServerTest.java new file mode 100644 index 00000000..27ae2f26 --- /dev/null +++ b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/ShuffleServerTest.java @@ -0,0 +1,704 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package com.alibaba.flink.shuffle.transfer; + +import com.alibaba.flink.shuffle.common.exception.ShuffleException; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.config.TransferOptions; +import com.alibaba.flink.shuffle.core.ids.ChannelID; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.core.memory.Buffer; +import com.alibaba.flink.shuffle.core.storage.DataPartitionReadingView; +import com.alibaba.flink.shuffle.core.storage.DataPartitionWritingView; +import com.alibaba.flink.shuffle.core.storage.ReadingViewContext; +import com.alibaba.flink.shuffle.core.storage.WritingViewContext; +import com.alibaba.flink.shuffle.transfer.TransferMessage.CloseChannel; +import com.alibaba.flink.shuffle.transfer.TransferMessage.ErrorResponse; +import com.alibaba.flink.shuffle.transfer.TransferMessage.ReadAddCredit; +import com.alibaba.flink.shuffle.transfer.TransferMessage.ReadData; +import com.alibaba.flink.shuffle.transfer.TransferMessage.ReadHandshakeRequest; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteAddCredit; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteData; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteFinish; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteHandshakeRequest; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteRegionFinish; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteRegionStart; +import com.alibaba.flink.shuffle.transfer.utils.NoOpPartitionedDataStore; + +import org.apache.flink.shaded.netty4.io.netty.buffer.AbstractReferenceCountedByteBuf; +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.channel.Channel; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler; +import org.apache.flink.shaded.netty4.io.netty.util.ReferenceCounted; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.List; +import java.util.Queue; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.currentProtocolVersion; +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyBufferSize; +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyDataPartitionType; +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyExtraMessage; +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyOffset; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +/** Test for shuffle server. 
*/ +public class ShuffleServerTest extends AbstractNettyTest { + + private ChannelID channelID; + + private MapPartitionID mapID; + + private MapPartitionID mapIDToFail; + + private int numSubs; + + private NettyServer nettyServer; + + private FakedPartitionDataStore dataStore; + + private NettyClient writeNettyClient; + + private NettyClient readNettyClient; + + private DummyChannelInboundHandlerAdaptor writeClientH; + + private DummyChannelInboundHandlerAdaptor readClientH; + + private Channel readChannel; + + private Channel writeChannel; + + private final CreditListener creditListener = new TestCreditListener(); + + @Override + @Before + public void setup() throws Exception { + super.setup(); + channelID = new ChannelID(); + mapID = new MapPartitionID(CommonUtils.randomBytes(32)); + mapIDToFail = new MapPartitionID(CommonUtils.randomBytes(32)); + numSubs = 2; + int dataPort = initServer(); + address = new InetSocketAddress(InetAddress.getLocalHost(), dataPort); + } + + @Override + @After + public void tearDown() throws Exception { + dataStore.shutDown(true); + nettyServer.shutdown(); + if (writeNettyClient != null) { + writeNettyClient.shutdown(); + } + if (readNettyClient != null) { + readNettyClient.shutdown(); + } + super.tearDown(); + } + + /** Basic writing routine. */ + @Test + public void testWritingRoutine() throws Exception { + initWriteClient(); + + // Client sends WriteHandshakeRequest, receives WriteAddCredit. + writeChannel.writeAndFlush( + new WriteHandshakeRequest( + currentProtocolVersion(), + channelID, + jobID, + dataSetID, + mapID, + numSubs, + emptyBufferSize(), + emptyDataPartitionType(), + emptyExtraMessage())); + Queue<ByteBuf> buffersToSend = constructBuffers(2, 3); + + // Client sends WriteRegionStart. + writeChannel.writeAndFlush( + new WriteRegionStart( + currentProtocolVersion(), channelID, 0, false, emptyExtraMessage())); + checkUntil(() -> assertEquals(1, dataStore.writingView.getRegionStartCount())); + checkUntil(() -> assertEquals(1, writeClientH.numMessages())); + assertTrue(writeClientH.getMsg(0) instanceof WriteAddCredit); + WriteAddCredit writeAddCredit = (WriteAddCredit) writeClientH.getMsg(0); + assertEquals(1, writeAddCredit.getCredit()); + + // Client sends WriteData, receives WriteAddCredit. + ByteBuf buffer = buffersToSend.poll(); + writeChannel.writeAndFlush( + new WriteData( + currentProtocolVersion(), + channelID, + buffer, + 0, + checkNotNull(buffer).readableBytes(), + false, + emptyExtraMessage())); + checkUntil(() -> assertEquals(1, dataStore.receivedBuffers.size())); + checkUntil(() -> assertEquals(2, writeClientH.numMessages())); + assertTrue(writeClientH.getMsg(1) instanceof WriteAddCredit); + writeAddCredit = (WriteAddCredit) writeClientH.getMsg(1); + assertEquals(1, writeAddCredit.getCredit()); + + // Client sends WriteData, receives WriteAddCredit. + buffer = buffersToSend.poll(); + writeChannel.writeAndFlush( + new WriteData( + currentProtocolVersion(), + channelID, + buffer, + 0, + checkNotNull(buffer).readableBytes(), + false, + emptyExtraMessage())); + checkUntil(() -> assertEquals(2, dataStore.receivedBuffers.size())); + checkUntil(() -> assertEquals(3, writeClientH.numMessages())); + assertTrue(writeClientH.getMsg(2) instanceof WriteAddCredit); + writeAddCredit = (WriteAddCredit) writeClientH.getMsg(2); + assertEquals(1, writeAddCredit.getCredit()); + + // Client sends WriteRegionFinish and WriteFinish.
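+ // The server-side writing view should count the finished region and, once WriteFinish is + // processed, commit the data and release its buffers (see FakedDataPartitionWritingView).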
+ writeChannel.writeAndFlush( + new WriteRegionFinish(currentProtocolVersion(), channelID, emptyExtraMessage())); + checkUntil(() -> assertEquals(1, dataStore.writingView.getRegionFinishCount())); + writeChannel.writeAndFlush( + new WriteFinish(currentProtocolVersion(), channelID, emptyExtraMessage())); + checkUntil(() -> assertTrue(dataStore.writingView.isFinished())); + } + + /** + * A failure in the {@link DataPartitionWritingView} should trigger an ErrorResponse sent to + * the client. + */ + @Test + public void testWritingButWritingViewTriggerFailure() throws Exception { + initWriteClient(); + + // Client sends WriteHandshakeRequest, receives WriteAddCredit. + writeChannel.writeAndFlush( + new WriteHandshakeRequest( + currentProtocolVersion(), + channelID, + jobID, + dataSetID, + mapID, + numSubs, + emptyBufferSize(), + emptyDataPartitionType(), + emptyExtraMessage())); + checkUntil(() -> assertTrue(writeClientH.isConnected())); + checkUntil(() -> assertNotNull(dataStore.writingView)); + + // Failure in dataStore. + dataStore.writingView.triggerFailure(new Exception("Expected exception")); + checkUntil(() -> assertEquals(1, writeClientH.numMessages())); + assertTrue(writeClientH.getLastMsg().toString().contains("Exception: Expected exception")); + } + + /** The {@link DataPartitionWritingView} should receive a {@link Throwable} when the connection breaks. */ + @Test + public void testWritingViewReceiveThrowableWhenBrokenConnection() throws Exception { + initWriteClient(); + int prevBuffers = transferBufferPool.numBuffers(); + + // Client sends WriteHandshakeRequest, receives WriteAddCredit. + writeChannel.writeAndFlush( + new WriteHandshakeRequest( + currentProtocolVersion(), + channelID, + jobID, + dataSetID, + mapID, + numSubs, + emptyBufferSize(), + emptyDataPartitionType(), + emptyExtraMessage())); + checkUntil(() -> assertTrue(writeClientH.isConnected())); + checkUntil(() -> assertNotNull(dataStore.writingView)); + checkUntil(() -> assertEquals(prevBuffers - 1, transferBufferPool.numBuffers())); + + // Client closes connection. + writeClientH.close(); + Class<?> clazz = ShuffleException.class; + checkUntil(() -> assertEquals(clazz, dataStore.writingView.getError().getClass())); + checkUntil(() -> assertEquals(prevBuffers, transferBufferPool.numBuffers())); + } + + // Test CloseChannel. + @Test + public void testWritingViewReceiveThrowableWhenChannelClosed() throws Exception { + initWriteClient(); + + checkUntil(() -> assertTrue(writeClientH.isConnected())); + writeChannel.writeAndFlush( + new WriteHandshakeRequest( + currentProtocolVersion(), + channelID, + jobID, + dataSetID, + mapID, + numSubs, + emptyBufferSize(), + emptyDataPartitionType(), + emptyExtraMessage())); + writeClientH.send( + new CloseChannel(currentProtocolVersion(), channelID, emptyExtraMessage())); + checkUntil(() -> assertNotNull(dataStore.writingView.getError())); + assertTrue( + dataStore + .writingView + .getError() + .getMessage() + .contains("Channel closed abnormally")); + } + + // Decoding error should result in ErrorResponse.
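+ // The test below drains the writing view's candidate buffers first, so the server-side + // decoder cannot obtain a buffer for the incoming WriteData and fails with "No buffers + // available.", which must be surfaced to the client as an ErrorResponse.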
+ @Test + public void testDecodingErrorWillCauseErrorResponse() throws Exception { + initWriteClient(); + + writeChannel.writeAndFlush( + new WriteHandshakeRequest( + currentProtocolVersion(), + channelID, + jobID, + dataSetID, + mapID, + numSubs, + emptyBufferSize(), + emptyDataPartitionType(), + emptyExtraMessage())); + writeChannel.writeAndFlush( + new WriteRegionStart( + currentProtocolVersion(), channelID, 0, false, emptyExtraMessage())); + checkUntil(() -> assertTrue(writeClientH.isConnected())); + checkUntil(() -> assertNotNull(dataStore.writingView)); + List<Buffer> candidateBuffers = dataStore.writingView.getCandidateBuffers(); + candidateBuffers.forEach(AbstractReferenceCountedByteBuf::release); + candidateBuffers.clear(); + + Queue<ByteBuf> buffersToSend = constructBuffers(1, 3); + ByteBuf buffer = buffersToSend.poll(); + writeChannel.writeAndFlush( + new WriteData( + currentProtocolVersion(), + channelID, + buffer, + 0, + checkNotNull(buffer).readableBytes(), + false, + emptyExtraMessage())); + checkUntil(() -> assertEquals(2, writeClientH.numMessages())); + assertTrue(writeClientH.getLastMsg().getClass() == ErrorResponse.class); + assertTrue( + writeClientH + .getLastMsg() + .toString() + .contains("java.lang.IllegalStateException: No buffers available.")); + } + + /** Writing channels sharing the same connection do not disturb each other. */ + @Test + public void testWritingChannelsCanShareSameConnection() throws Exception { + initWriteClient(); + + writeChannel.writeAndFlush( + new WriteHandshakeRequest( + currentProtocolVersion(), + channelID, + jobID, + dataSetID, + mapID, + numSubs, + emptyBufferSize(), + emptyDataPartitionType(), + emptyExtraMessage())); + writeChannel.writeAndFlush( + new WriteRegionStart( + currentProtocolVersion(), channelID, 0, false, emptyExtraMessage())); + checkUntil(() -> assertEquals(1, writeClientH.numMessages())); + assertTrue(writeClientH.getMsg(0) instanceof WriteAddCredit); + ChannelID channelID1 = new ChannelID(); + writeChannel.writeAndFlush( + new WriteHandshakeRequest( + currentProtocolVersion(), + channelID1, + jobID, + dataSetID, + mapIDToFail, + numSubs, + emptyBufferSize(), + emptyDataPartitionType(), + emptyExtraMessage())); + checkUntil(() -> assertEquals(2, writeClientH.numMessages())); + assertTrue(writeClientH.getMsg(1) instanceof ErrorResponse); + assertEquals(channelID1, ((ErrorResponse) writeClientH.getMsg(1)).getChannelID()); + + writeChannel.writeAndFlush( + new WriteRegionFinish(currentProtocolVersion(), channelID, emptyExtraMessage())); + writeChannel.writeAndFlush( + new WriteFinish(currentProtocolVersion(), channelID, emptyExtraMessage())); + checkUntil(() -> assertTrue(dataStore.writingView.isFinished())); + } + + /** Basic reading routine. */ + @Test + public void testReadingRoutine() throws Exception { + initReadClient(); + List<ByteBuf> buffersToReceive = new ArrayList<>(constructBuffers(3, 3)); + + // Client sends ReadHandshakeRequest. + readChannel.writeAndFlush( + new ReadHandshakeRequest( + currentProtocolVersion(), + channelID, + dataSetID, + mapID, + 0, + 0, + 1, + emptyBufferSize(), + emptyOffset(), + emptyExtraMessage())); + checkUntil(() -> assertNotNull(dataStore.readingView)); + + // Server sends ReadData. + dataStore.readingView.notifyBuffers(buffersToReceive.subList(0, 2)); + checkUntil(() -> assertEquals(1, readClientH.numMessages())); + assertTrue(readClientH.getMsg(0) instanceof ReadData); + + // Client sends ReadAddCredit. + // Server sends ReadData.
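+ // The two added credits cover the second buffer of the first batch and the buffer + // notified afterwards.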
+ readChannel.writeAndFlush( + new ReadAddCredit(currentProtocolVersion(), channelID, 2, emptyExtraMessage())); + checkUntil(() -> assertEquals(2, readClientH.numMessages())); + assertTrue(readClientH.getMsg(1) instanceof ReadData); + + // Server sends ReadData. + dataStore.readingView.notifyBuffers(buffersToReceive.subList(2, 3)); + checkUntil(() -> assertEquals(3, readClientH.numMessages())); + checkUntil(() -> assertTrue(readClientH.getMsg(2) instanceof ReadData)); + + dataStore.readingView.setNoMoreData(true); + assertTrue(dataStore.readingView.isFinished()); + + List<ByteBuf> buffers = + readClientH.getMessages().stream() + .filter(obj -> obj instanceof ReadData) + .map(obj -> ((ReadData) obj).getBuffer()) + .collect(Collectors.toList()); + verifyBuffers(3, 3, buffers); + buffers.forEach(ReferenceCounted::release); + } + + /** + * A failure in the {@link DataPartitionReadingView} should trigger an ErrorResponse sent to + * the client. + */ + @Test + public void testReadingButReadingViewTriggerFailure() throws Exception { + initReadClient(); + + // Client sends ReadHandshakeRequest. + readChannel.writeAndFlush( + new ReadHandshakeRequest( + currentProtocolVersion(), + channelID, + dataSetID, + mapID, + 0, + 0, + 1, + emptyBufferSize(), + emptyOffset(), + emptyExtraMessage())); + checkUntil(() -> assertNotNull(dataStore.readingView)); + + dataStore.readingView.triggerFailure(new Exception("Expected exception")); + + checkUntil(() -> assertEquals(1, readClientH.numMessages())); + checkUntil( + () -> + assertTrue( + readClientH + .getLastMsg() + .toString() + .contains("Expected exception"))); + } + + /** The {@link DataPartitionReadingView} should receive a {@link Throwable} when the connection breaks. */ + @Test + public void testReadingViewReceiveThrowableWhenBrokenConnection() throws Exception { + initReadClient(); + + // Client sends ReadHandshakeRequest. + readChannel.writeAndFlush( + new ReadHandshakeRequest( + currentProtocolVersion(), + channelID, + dataSetID, + mapID, + 0, + 0, + 1, + emptyBufferSize(), + emptyOffset(), + emptyExtraMessage())); + checkUntil(() -> assertNotNull(dataStore.readingView)); + + // Client closes connection. + readClientH.close(); + Class<?> clazz = ShuffleException.class; + checkUntil(() -> assertEquals(clazz, dataStore.readingView.getError().getClass())); + } + + // Test CloseChannel. + @Test + public void testReadingViewReceiveThrowableWhenChannelClosed() throws Exception { + initReadClient(); + + readChannel.writeAndFlush( + new ReadHandshakeRequest( + currentProtocolVersion(), + channelID, + dataSetID, + mapID, + 0, + 0, + 1, + emptyBufferSize(), + emptyOffset(), + emptyExtraMessage())); + checkUntil(() -> assertNotNull(dataStore.readingView)); + + readChannel.writeAndFlush( + new CloseChannel(currentProtocolVersion(), channelID, emptyExtraMessage())); + checkUntil(() -> assertNotNull(dataStore.readingView.getError())); + assertTrue( + dataStore + .readingView + .getError() + .getMessage() + .contains("Channel closed abnormally")); + } + + /** DataSender error should result in ErrorResponse.
*/ + @Test + public void testDataSenderErrorWillCauseErrorResponse() throws Exception { + initReadClient(); + readChannel.writeAndFlush( + new ReadHandshakeRequest( + currentProtocolVersion(), + channelID, + dataSetID, + mapID, + 0, + 0, + 1, + emptyBufferSize(), + emptyOffset(), + emptyExtraMessage())); + checkUntil(() -> assertNotNull(dataStore.readingView)); + dataStore.readingView.setBuffersToSend( + new ArrayDeque<ByteBuf>() { + @Override + public ByteBuf poll() { + throw new ShuffleException("Poll failure"); + } + }); + dataStore.readingView.getDataAvailableListener().notifyDataAvailable(); + checkUntil(() -> assertEquals(1, readClientH.numMessages())); + assertTrue(readClientH.getLastMsg().getClass() == ErrorResponse.class); + assertTrue(readClientH.getLastMsg().toString().contains("Poll failure")); + } + + /** Reading channels sharing the same connection do not disturb each other. */ + @Test + public void testReadingChannelsCanShareSameConnection() throws Exception { + initReadClient(); + List<ByteBuf> buffersToReceive = new ArrayList<>(constructBuffers(3, 3)); + + // Client sends ReadHandshakeRequest. + readChannel.writeAndFlush( + new ReadHandshakeRequest( + currentProtocolVersion(), + channelID, + dataSetID, + mapID, + 0, + 0, + 1, + emptyBufferSize(), + emptyOffset(), + emptyExtraMessage())); + checkUntil(() -> assertNotNull(dataStore.readingView)); + + // Server sends ReadData. + dataStore.readingView.notifyBuffers(buffersToReceive.subList(0, 2)); + checkUntil(() -> assertEquals(1, readClientH.numMessages())); + assertTrue(readClientH.getMsg(0) instanceof ReadData); + + // Client sends a handshake to fail the logic channel + ChannelID channelID1 = new ChannelID(); + readChannel.writeAndFlush( + new ReadHandshakeRequest( + currentProtocolVersion(), + channelID1, + dataSetID, + mapIDToFail, + 0, + 0, + 1, + emptyBufferSize(), + emptyOffset(), + emptyExtraMessage())); + checkUntil(() -> assertEquals(2, readClientH.numMessages())); + assertTrue(readClientH.getMsg(1) instanceof ErrorResponse); + assertEquals(channelID1, ((ErrorResponse) readClientH.getMsg(1)).getChannelID()); + + // Client sends ReadAddCredit. + // Server sends ReadData. + readChannel.writeAndFlush( + new ReadAddCredit(currentProtocolVersion(), channelID, 2, emptyExtraMessage())); + checkUntil(() -> assertEquals(3, readClientH.numMessages())); + assertTrue(readClientH.getMsg(2) instanceof ReadData); + + // Server sends ReadData.
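+ // The failed handshake on channelID1 above must not disturb this channel: the same + // reading view keeps serving ReadData on channelID.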
+ dataStore.readingView.notifyBuffers(buffersToReceive.subList(2, 3)); + checkUntil(() -> assertEquals(4, readClientH.numMessages())); + checkUntil(() -> assertTrue(readClientH.getMsg(3) instanceof ReadData)); + + dataStore.readingView.setNoMoreData(true); + assertTrue(dataStore.readingView.isFinished()); + + List<ByteBuf> buffers = + readClientH.getMessages().stream() + .filter(obj -> obj instanceof ReadData) + .map(obj -> ((ReadData) obj).getBuffer()) + .collect(Collectors.toList()); + verifyBuffers(3, 3, buffers); + buffers.forEach(ReferenceCounted::release); + } + + private void initWriteClient() throws Exception { + writeClientH = new DummyChannelInboundHandlerAdaptor(); + writeNettyClient = new NettyClient(nettyConfig); + writeNettyClient.init( + () -> + new ChannelHandler[] { + new TransferMessageEncoder(), + DecoderDelegate.writeClientDecoderDelegate(), + writeClientH + }); + writeChannel = writeNettyClient.connect(address).await().channel(); + } + + private void initReadClient() throws Exception { + readClientH = new DummyChannelInboundHandlerAdaptor(); + readNettyClient = new NettyClient(nettyConfig); + readNettyClient.init( + () -> + new ChannelHandler[] { + new TransferMessageEncoder(), + DecoderDelegate.readClientDecoderDelegate( + ignore -> () -> transferBufferPool.requestBuffer()), + readClientH + }); + readChannel = readNettyClient.connect(address).await().channel(); + } + + private int initServer() throws Exception { + dataStore = new FakedPartitionDataStore(); + int dataPort = getAvailablePort(); + nettyConfig.getConfig().setInteger(TransferOptions.SERVER_DATA_PORT, dataPort); + nettyServer = new NettyServer(dataStore, nettyConfig); + nettyServer.disableHeartbeat(); + nettyServer.start(); + return dataPort; + } + + private class FakedPartitionDataStore extends NoOpPartitionedDataStore { + + private final ArrayDeque<ByteBuf> receivedBuffers; + + private FakedDataPartitionWritingView writingView; + + private FakedDataPartitionReadingView readingView; + + public FakedPartitionDataStore() { + this.receivedBuffers = new ArrayDeque<>(); + } + + @Override + public DataPartitionWritingView createDataPartitionWritingView(WritingViewContext context) { + if (context.getMapPartitionID().equals(mapIDToFail)) { + throw new ShuffleException("Fail a writing handshake."); + } + + Consumer<Buffer> dataHandler = + buffer -> { + ByteBuf buffer0 = transferBufferPool.requestBuffer(); + buffer0.writeBytes(buffer); + dataStore.receivedBuffers.add(buffer0); + }; + List<Buffer> buffers = new ArrayList<>(); + buffers.add((Buffer) transferBufferPool.requestBuffer()); + writingView = + new FakedDataPartitionWritingView( + context.getDataSetID(), + context.getMapPartitionID(), + dataHandler, + context.getDataRegionCreditListener(), + context.getFailureListener(), + buffers); + return writingView; + } + + @Override + public DataPartitionReadingView createDataPartitionReadingView(ReadingViewContext context) { + if (context.getPartitionID().equals(mapIDToFail)) { + throw new ShuffleException("Fail a reading handshake."); + } + readingView = + new FakedDataPartitionReadingView( + context.getDataListener(), + context.getBacklogListener(), + context.getFailureListener()); + return readingView; + } + + @Override + public void shutDown(boolean releaseData) { + receivedBuffers.forEach(ReferenceCounted::release); + } + } +} diff --git a/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/ShuffleWriteClientTest.java b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/ShuffleWriteClientTest.java new file mode 100644
index 00000000..548e49ff --- /dev/null +++ b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/ShuffleWriteClientTest.java @@ -0,0 +1,370 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.config.TransferOptions; +import com.alibaba.flink.shuffle.core.ids.MapPartitionID; +import com.alibaba.flink.shuffle.transfer.TransferMessage.CloseChannel; +import com.alibaba.flink.shuffle.transfer.TransferMessage.CloseConnection; +import com.alibaba.flink.shuffle.transfer.TransferMessage.ErrorResponse; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteAddCredit; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteData; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteFinish; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteFinishCommit; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteHandshakeRequest; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteRegionFinish; +import com.alibaba.flink.shuffle.transfer.TransferMessage.WriteRegionStart; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.channel.Channel; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.List; +import java.util.Queue; + +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.currentProtocolVersion; +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyBufferSize; +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyDataPartitionType; +import static com.alibaba.flink.shuffle.common.utils.ProtocolUtils.emptyExtraMessage; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +/** Test for {@link ShuffleWriteClient}. 
*/ +public class ShuffleWriteClientTest extends AbstractNettyTest { + + private NettyServer nettyServer; + + private ConnectionManager connManager; + + private ShuffleWriteClient client0; + + private ShuffleWriteClient client1; + + private volatile DummyChannelInboundHandlerAdaptor serverH; + + private final CreditListener creditListener = new TestCreditListener(); + + @Override + @Before + public void setup() throws Exception { + super.setup(); + int dataPort = initShuffleServer(); + MapPartitionID mapID0 = new MapPartitionID(CommonUtils.randomBytes(16)); + MapPartitionID mapID1 = new MapPartitionID(CommonUtils.randomBytes(16)); + int subsNum = 2; + connManager = ConnectionManager.createWriteConnectionManager(nettyConfig, false); + connManager.start(); + address = new InetSocketAddress(InetAddress.getLocalHost(), dataPort); + client0 = + new ShuffleWriteClient( + address, + jobID, + dataSetID, + mapID0, + subsNum, + emptyBufferSize(), + emptyDataPartitionType(), + connManager); + client1 = + new ShuffleWriteClient( + address, + jobID, + dataSetID, + mapID1, + subsNum, + emptyBufferSize(), + emptyDataPartitionType(), + connManager); + } + + @Override + @After + public void tearDown() throws Exception { + connManager.shutdown(); + serverH.close(); + nettyServer.shutdown(); + super.tearDown(); + } + + /** Basic routine. */ + @Test + public void testWriteDataAndSendCredit() throws Exception { + Queue<ByteBuf> buffersToSend = constructBuffers(4, 3); + int subIdx = 0; + + checkUntil(() -> assertTrue(serverH.isEmpty())); + + // Client sends WriteHandshakeRequest. + runAsync(() -> client0.open()); + checkUntil(() -> assertEquals(1, serverH.numMessages())); + assertTrue(serverH.isConnected()); + assertTrue(serverH.getLastMsg() instanceof WriteHandshakeRequest); + + // Client sends WriteRegionStart. + runAsync(() -> client0.regionStart(false)); + checkUntil(() -> assertEquals(2, serverH.numMessages())); + assertTrue(serverH.getLastMsg() instanceof WriteRegionStart); + + // Client sends WriteData. + runAsync(() -> client0.write(buffersToSend.poll(), subIdx)); + checkUntil(() -> assertTrue(client0.isWaitingForCredit())); + + // Server sends 1 credit. + serverH.send( + new WriteAddCredit( + currentProtocolVersion(), + client0.getChannelID(), + 1, + 0, + emptyExtraMessage())); + checkUntil(() -> assertEquals(3, serverH.numMessages())); + checkUntil(() -> assertTrue(serverH.getLastMsg() instanceof WriteData)); + + // Client sends a WriteData. + runAsync(() -> client0.write(buffersToSend.poll(), subIdx)); + checkUntil(() -> assertTrue(client0.isWaitingForCredit())); + + // Server sends 1 credit. + // Client sends WriteData and WriteRegionFinish.
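+ // write() blocks until credit arrives (isWaitingForCredit above), so the pending + // WriteData can only be flushed once the WriteAddCredit below is received; + // regionFinish() itself needs no credit and follows it.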
+ serverH.send( + new WriteAddCredit( + currentProtocolVersion(), + client0.getChannelID(), + 1, + 0, + emptyExtraMessage())); + runAsync(() -> client0.regionFinish()); + checkUntil(() -> assertEquals(5, serverH.numMessages())); + assertTrue(serverH.getMsg(3) instanceof WriteData); + assertTrue(serverH.getMsg(4) instanceof WriteRegionFinish); + + // Server sends outdated credits + serverH.send( + new WriteAddCredit( + currentProtocolVersion(), + client0.getChannelID(), + 1, + 0, + emptyExtraMessage())); + serverH.send( + new WriteAddCredit( + currentProtocolVersion(), + client0.getChannelID(), + 1, + 0, + emptyExtraMessage())); + runAsync( + () -> { + client0.write(buffersToSend.poll(), subIdx); + client0.write(buffersToSend.poll(), subIdx); + client0.regionFinish(); + client0.finish(); + }); + checkUntil(() -> assertTrue(client0.isWaitingForCredit())); + + // Server sends 2 credits. + // Client sends WriteData and WriteRegionFinish and WriteFinish. + serverH.send( + new WriteAddCredit( + currentProtocolVersion(), + client0.getChannelID(), + 2, + 1, + emptyExtraMessage())); + + checkUntil(() -> assertEquals(9, serverH.numMessages())); + assertTrue(serverH.getMsg(5) instanceof WriteData); + assertTrue(serverH.getMsg(6) instanceof WriteData); + assertTrue(serverH.getMsg(7) instanceof WriteRegionFinish); + assertTrue(serverH.getMsg(8) instanceof WriteFinish); + assertTrue(client0.isWaitingForFinishCommit()); + + // Server sends WriteFinishCommit. + serverH.send( + new WriteFinishCommit( + currentProtocolVersion(), client0.getChannelID(), emptyExtraMessage())); + checkUntil(() -> assertFalse(client0.isWaitingForFinishCommit())); + + client0.close(); + checkUntil(() -> assertEquals(11, serverH.numMessages())); + assertTrue(serverH.getMsg(9) instanceof CloseChannel); + assertTrue(serverH.getMsg(10) instanceof CloseConnection); + + // verify buffers + List<ByteBuf> receivedBuffers = new ArrayList<>(); + serverH.getMessages().stream() + .filter(o -> o instanceof WriteData) + .forEach(obj -> receivedBuffers.add(((WriteData) obj).getBuffer())); + verifyBuffers(4, 3, receivedBuffers); + } + + /** Client receives a {@link Throwable} when the connection breaks. */ + @Test + public void testClientReceiveThrowableWhenBrokenConnection() throws Exception { + // Client sends WriteHandshakeRequest. + runAsync(() -> client0.open()); + checkUntil(() -> assertTrue(serverH.isConnected())); + + // Server closes connection. + serverH.close(); + checkUntil(() -> assertTrue(client0.getCause() instanceof IOException)); + + client0.close(); + } + + /** A client blocked waiting on its lock gets notified when the connection breaks. */ + @Test + public void testClientGetsNotifiedLockWhenBrokenConnection() throws Exception { + Queue<ByteBuf> buffersToSend = constructBuffers(1, 3); + int subIdx = 0; + + // Client sends WriteHandshakeRequest. + runAsync(() -> client0.open()); + checkUntil(() -> assertTrue(serverH.isConnected())); + + // Client sends WriteRegionStart. + runAsync(() -> client0.regionStart(false)); + + // Client sends WriteData. + ByteBuf byteBuf = buffersToSend.poll(); + runAsync(() -> client0.write(byteBuf, subIdx)); + checkUntil(() -> assertTrue(client0.isWaitingForCredit())); + + // Server closes connection. + serverH.close(); + checkUntil(() -> assertTrue(!client0.isWaitingForCredit())); + + client0.close(); + } + + /** ErrorResponse from server. */ + @Test + public void testReceiveErrorResponseFromServer() throws Exception { + // Client sends WriteHandshakeRequest.
+ runAsync(() -> client0.open()); + checkUntil(() -> assertTrue(serverH.isConnected())); + + // ErrorResponse from server. + String errMsg = "Expected exception."; + serverH.send( + new ErrorResponse( + currentProtocolVersion(), + client0.getChannelID(), + errMsg.getBytes(), + emptyExtraMessage())); + checkUntil( + () -> { + assertTrue(client0.getCause() instanceof IOException); + assertTrue(client0.getCause().getCause().getMessage().contains(errMsg)); + }); + + Queue<ByteBuf> buffersToSend = constructBuffers(1, 3); + assertThrows(Exception.class, () -> client0.write(buffersToSend.poll(), 0)); + + client0.close(); + } + + /** Multiple channels share the same physical connection. */ + @Test + public void testMultipleChannelsSharedSamePhysicalConnection() throws Exception { + runAsync(() -> client0.open()); + runAsync(() -> client1.open()); + checkUntil(() -> assertEquals(2, serverH.numMessages())); + + Channel channel0 = connManager.getChannel(client0.getChannelID(), address); + Channel channel1 = connManager.getChannel(client1.getChannelID(), address); + WriteClientHandler clientHandler = channel0.pipeline().get(WriteClientHandler.class); + assertEquals(channel0, channel1); + + serverH.send( + new WriteAddCredit( + currentProtocolVersion(), + client0.getChannelID(), + 2, + 0, + emptyExtraMessage())); + serverH.send( + new WriteAddCredit( + currentProtocolVersion(), + client1.getChannelID(), + 2, + 0, + emptyExtraMessage())); + Queue<ByteBuf> buffersToSend = constructBuffers(4, 3); + runAsync(() -> client0.write(buffersToSend.poll(), 0)); + runAsync(() -> client1.write(buffersToSend.poll(), 0)); + checkUntil(() -> assertEquals(4, serverH.numMessages())); + assertTrue(serverH.getMsg(0) instanceof WriteHandshakeRequest); + assertTrue(serverH.getMsg(1) instanceof WriteHandshakeRequest); + assertTrue(serverH.getMsg(2) instanceof WriteData); + assertTrue(serverH.getMsg(3) instanceof WriteData); + + runAsync(() -> client0.write(buffersToSend.poll(), 1)); + checkUntil(() -> assertEquals(5, serverH.numMessages())); + assertTrue(serverH.getMsg(4) instanceof WriteData); + + client0.close(); + assertFalse(clientHandler.isRegistered(client0.getChannelID())); + assertTrue(channel1.isActive()); + assertTrue(clientHandler.isRegistered(client1.getChannelID())); + checkUntil(() -> assertEquals(6, serverH.numMessages())); + assertTrue(serverH.getMsg(5) instanceof CloseChannel); + + runAsync(() -> client1.write(buffersToSend.poll(), 1)); + checkUntil(() -> assertEquals(7, serverH.numMessages())); + assertTrue(serverH.getMsg(6) instanceof WriteData); + + client1.close(); + checkUntil(() -> assertEquals(9, serverH.numMessages())); + assertTrue(serverH.getMsg(7) instanceof CloseChannel); + assertTrue(serverH.getMsg(8) instanceof CloseConnection); + checkUntil(() -> assertFalse(channel1.isActive())); + assertFalse(clientHandler.isRegistered(client1.getChannelID())); + } + + private int initShuffleServer() throws Exception { + serverH = new DummyChannelInboundHandlerAdaptor(); + int dataPort = getAvailablePort(); + nettyConfig.getConfig().setInteger(TransferOptions.SERVER_DATA_PORT, dataPort); + nettyServer = + new NettyServer(null, nettyConfig) { + @Override + public ChannelHandler[] getServerHandlers() { + return new ChannelHandler[] { + new TransferMessageEncoder(), + DecoderDelegate.serverDecoderDelegate(ignore -> () -> requestBuffer()), + serverH + }; + } + }; + nettyServer.start(); + return dataPort; + } +} diff --git a/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/TestCreditListener.java
b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/TestCreditListener.java new file mode 100644 index 00000000..4a5e0ed6 --- /dev/null +++ b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/TestCreditListener.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +/** A {@link CreditListener} implementation for tests. */ +public class TestCreditListener extends CreditListener { + + @Override + public void notifyAvailableCredits(int numCredits) {} +} diff --git a/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/TestTransferBufferPool.java b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/TestTransferBufferPool.java new file mode 100644 index 00000000..480858ef --- /dev/null +++ b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/TestTransferBufferPool.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.core.memory.Buffer; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** A {@link TransferBufferPool} as a testing utility. 
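+ * + * <p>It pre-allocates a fixed number of direct buffers; {@link #requestBufferBlocking()} + * polls the pool, sleeping 10 ms between attempts, until a buffer becomes available.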
*/ +public class TestTransferBufferPool extends TransferBufferPool { + + public TestTransferBufferPool(int numBuffers, int bufferSize) { + super(Collections.emptyList()); + + List<ByteBuf> buffers = new ArrayList<>(allocateBuffers(numBuffers, bufferSize)); + addBuffers(buffers); + } + + public ByteBuf requestBufferBlocking() { + while (true) { + ByteBuf byteBuf = requestBuffer(); + if (byteBuf != null) { + return byteBuf; + } + CommonUtils.runQuietly(() -> Thread.sleep(10)); + } + } + + private List<ByteBuf> allocateBuffers(int numBuffers, int bufferSize) { + List<ByteBuf> buffers = new ArrayList<>(numBuffers); + for (int i = 0; i < numBuffers; i++) { + ByteBuf byteBuf = new Buffer(ByteBuffer.allocateDirect(bufferSize), this, 0); + buffers.add(byteBuf); + } + return buffers; + } +} diff --git a/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/utils/NoOpPartitionedDataStore.java b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/utils/NoOpPartitionedDataStore.java new file mode 100644 index 00000000..39853191 --- /dev/null +++ b/shuffle-transfer/src/test/java/com/alibaba/flink/shuffle/transfer/utils/NoOpPartitionedDataStore.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.transfer.utils; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.core.executor.SingleThreadExecutorPool; +import com.alibaba.flink.shuffle.core.ids.DataPartitionID; +import com.alibaba.flink.shuffle.core.ids.DataSetID; +import com.alibaba.flink.shuffle.core.ids.JobID; +import com.alibaba.flink.shuffle.core.memory.BufferDispatcher; +import com.alibaba.flink.shuffle.core.storage.DataPartitionMeta; +import com.alibaba.flink.shuffle.core.storage.DataPartitionReadingView; +import com.alibaba.flink.shuffle.core.storage.DataPartitionWritingView; +import com.alibaba.flink.shuffle.core.storage.PartitionedDataStore; +import com.alibaba.flink.shuffle.core.storage.ReadingViewContext; +import com.alibaba.flink.shuffle.core.storage.StorageMeta; +import com.alibaba.flink.shuffle.core.storage.WritingViewContext; + +import javax.annotation.Nullable; + +/** An empty partitioned data store used for tests.
*/ +public class NoOpPartitionedDataStore implements PartitionedDataStore { + + @Override + public DataPartitionWritingView createDataPartitionWritingView(WritingViewContext context) + throws Exception { + return null; + } + + @Override + public DataPartitionReadingView createDataPartitionReadingView(ReadingViewContext context) + throws Exception { + return null; + } + + @Override + public boolean isDataPartitionConsumable(DataPartitionMeta partitionMeta) { + return false; + } + + @Override + public void addDataPartition(DataPartitionMeta partitionMeta) throws Exception {} + + @Override + public void removeDataPartition(DataPartitionMeta partitionMeta) {} + + @Override + public void releaseDataPartition( + DataSetID dataSetID, DataPartitionID partitionID, @Nullable Throwable throwable) {} + + @Override + public void releaseDataSet(DataSetID dataSetID, @Nullable Throwable throwable) {} + + @Override + public void releaseDataByJobID(JobID jobID, @Nullable Throwable throwable) {} + + @Override + public void shutDown(boolean releaseData) {} + + @Override + public boolean isShutDown() { + return false; + } + + @Override + public Configuration getConfiguration() { + return null; + } + + @Override + public BufferDispatcher getWritingBufferDispatcher() { + return null; + } + + @Override + public BufferDispatcher getReadingBufferDispatcher() { + return null; + } + + @Override + public SingleThreadExecutorPool getExecutorPool(StorageMeta storageMeta) { + return null; + } +} diff --git a/shuffle-transfer/src/test/resources/log4j2-test.properties b/shuffle-transfer/src/test/resources/log4j2-test.properties new file mode 100644 index 00000000..d7fcb327 --- /dev/null +++ b/shuffle-transfer/src/test/resources/log4j2-test.properties @@ -0,0 +1,26 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################ +# Set root logger level to OFF to not flood build logs +# set manually to INFO for debugging purposes +rootLogger.level=OFF +rootLogger.appenderRef.test.ref=TestLogger +appender.testlogger.name=TestLogger +appender.testlogger.type=CONSOLE +appender.testlogger.target=SYSTEM_ERR +appender.testlogger.layout.type=PatternLayout +appender.testlogger.layout.pattern=%-4r [%t] %-5p %c %x - %m%n diff --git a/shuffle-yarn/pom.xml b/shuffle-yarn/pom.xml new file mode 100644 index 00000000..25f4f1d0 --- /dev/null +++ b/shuffle-yarn/pom.xml @@ -0,0 +1,295 @@ +<?xml version="1.0" encoding="UTF-8"?> +<!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --> +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <parent> + <groupId>com.alibaba.flink.shuffle</groupId> + <artifactId>flink-shuffle-parent</artifactId> + <version>1.0-SNAPSHOT</version> + </parent> + <modelVersion>4.0.0</modelVersion> + + <artifactId>shuffle-yarn</artifactId> + + <dependencies> + <dependency> + <groupId>com.alibaba.flink.shuffle</groupId> + <artifactId>shuffle-coordinator</artifactId> + <version>${project.version}</version> + <exclusions> + <exclusion> + <groupId>org.apache.flink</groupId> + <artifactId>*</artifactId> + </exclusion> + </exclusions> + </dependency> + + <dependency> + <groupId>com.alibaba.flink.shuffle</groupId> + <artifactId>shuffle-common</artifactId> + <version>${project.version}</version> + </dependency> + + <dependency> + <groupId>com.alibaba.flink.shuffle</groupId> + <artifactId>shuffle-plugin</artifactId> + <version>${project.version}</version> + <exclusions> + <exclusion> + <groupId>org.apache.flink</groupId> + <artifactId>*</artifactId> + </exclusion> + </exclusions> + </dependency> + + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-yarn_${scala.binary.version}</artifactId> + <version>${flink.version}</version> + <scope>test</scope> + <exclusions> + <exclusion> + <groupId>org.apache.commons</groupId> + <artifactId>commons-math3</artifactId> + </exclusion> + </exclusions> + </dependency> + + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-test-utils-junit</artifactId> + <version>${flink.version}</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.apache.curator</groupId> + <artifactId>curator-test</artifactId> + <version>${curator.version}</version> + <exclusions> + <exclusion> + <groupId>log4j</groupId> + <artifactId>log4j</artifactId> + </exclusion> + </exclusions> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-hdfs</artifactId> + <version>${hadoop.version}</version> + <scope>provided</scope> + <exclusions> + <exclusion> + <groupId>jdk.tools</groupId> + <artifactId>jdk.tools</artifactId> + </exclusion> + <exclusion> + <groupId>log4j</groupId> + <artifactId>log4j</artifactId> + </exclusion> + <exclusion> + <groupId>org.slf4j</groupId> + <artifactId>slf4j-log4j12</artifactId> + </exclusion> + </exclusions> + </dependency> + + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-client</artifactId> + <version>${hadoop.version}</version> + <scope>provided</scope> + <exclusions> + <exclusion> + <groupId>jdk.tools</groupId> + <artifactId>jdk.tools</artifactId> + </exclusion> + <exclusion> + <groupId>log4j</groupId> + <artifactId>log4j</artifactId> + </exclusion> + <exclusion> + <groupId>org.slf4j</groupId> + <artifactId>slf4j-log4j12</artifactId> + </exclusion> + </exclusions> + </dependency> + + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-minicluster</artifactId> + <version>${hadoop.version}</version> + <scope>provided</scope> + <exclusions> + <exclusion> + <groupId>log4j</groupId> + <artifactId>log4j</artifactId> + </exclusion> + <exclusion> + <groupId>org.slf4j</groupId> + <artifactId>slf4j-log4j12</artifactId> + </exclusion> + <exclusion> + <groupId>org.apache.commons</groupId> + <artifactId>commons-math3</artifactId> + </exclusion> + </exclusions> + </dependency> + + <dependency> + <groupId>commons-cli</groupId> + <artifactId>commons-cli</artifactId> + <version>${commons.cli.version}</version> + </dependency> + + <dependency> + <groupId>org.apache.logging.log4j</groupId> + <artifactId>log4j-1.2-api</artifactId> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>com.google.guava</groupId> + <artifactId>guava</artifactId> + <version>${java.guava.version}</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-runtime</artifactId> + <version>${flink.version}</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-examples-batch_${scala.binary.version}</artifactId> + <version>${flink.version}</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>com.alibaba.flink.shuffle</groupId> + <artifactId>shuffle-core</artifactId> + <version>${project.version}</version> + <type>test-jar</type> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-runtime</artifactId> + <version>${flink.version}</version> + <type>test-jar</type> + <scope>test</scope> + </dependency> + </dependencies> + + <build> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-dependency-plugin</artifactId> + <executions> + <execution> + <id>copy</id> + <phase>process-test-resources</phase> + <goals> + <goal>copy</goal> + </goals> + <configuration> + <artifactItems> + <artifactItem> + <groupId>org.apache.flink</groupId> + <artifactId>flink-examples-batch_${scala.binary.version}</artifactId> + <type>jar</type> + <classifier>WordCount</classifier> + <overWrite>true</overWrite> + <destFileName>BatchWordCount.jar</destFileName> + </artifactItem> + </artifactItems> + <outputDirectory>${project.build.directory}/programs</outputDirectory> + <overWriteReleases>false</overWriteReleases> + <overWriteSnapshots>true</overWriteSnapshots> + </configuration> + </execution> + <execution> + <id>store-classpath-in-target-for-tests</id> + <phase>process-test-resources</phase> + <goals> + <goal>build-classpath</goal> + </goals> + <configuration> + <outputFile>${project.build.directory}/yarn.classpath</outputFile> + <excludeGroupIds>org.apache.flink</excludeGroupIds> + </configuration> + </execution> + </executions> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-shade-plugin</artifactId> + <executions> + <execution> + <id>shade-remote-shuffle</id> + <phase>package</phase> + <goals> + <goal>shade</goal> + </goals> + <configuration> + <shadeTestJar>false</shadeTestJar> + <createDependencyReducedPom>false</createDependencyReducedPom> + <finalName>${project.artifactId}-${project.version}</finalName> + <artifactSet> + <includes> + <include>commons-cli:commons-cli</include> + </includes> + </artifactSet> + <relocations> + <relocation> + <pattern>org.apache.commons.cli</pattern> + <shadedPattern>com.alibaba.flink.shuffle.yarn.shaded.org.apache.commons.cli</shadedPattern> + </relocation> + </relocations> + </configuration> + </execution> + </executions> + </plugin> + </plugins> + </build> +</project> diff --git a/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/entry/manager/AppClient.java b/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/entry/manager/AppClient.java new file mode 100644 index 00000000..c45c6178 --- /dev/null +++ b/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/entry/manager/AppClient.java @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.yarn.entry.manager; + +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleManager; +import com.alibaba.flink.shuffle.coordinator.utils.EnvironmentInformation; +import com.alibaba.flink.shuffle.yarn.utils.DeployOnYarnUtils; +import com.alibaba.flink.shuffle.yarn.utils.YarnConstants; + +import org.apache.commons.cli.ParseException; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.io.DataOutputBuffer; +import org.apache.hadoop.security.Credentials; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.security.token.Token; +import org.apache.hadoop.yarn.api.ApplicationConstants; +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment; +import org.apache.hadoop.yarn.api.records.ApplicationId; +import org.apache.hadoop.yarn.api.records.ApplicationSubmissionContext; +import org.apache.hadoop.yarn.api.records.ContainerLaunchContext; +import org.apache.hadoop.yarn.api.records.LocalResource; +import org.apache.hadoop.yarn.api.records.LocalResourceType; +import org.apache.hadoop.yarn.api.records.Priority; +import org.apache.hadoop.yarn.api.records.Resource; +import org.apache.hadoop.yarn.client.api.YarnClient; +import org.apache.hadoop.yarn.client.api.YarnClientApplication; +import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.exceptions.YarnException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Vector; + +/** + * Client for application submission to YARN. + * + *
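For orientation, here is a minimal driver sketch for this class. It is illustrative only: the home-dir value is a made-up path, error handling is elided, and it uses the public (String[], YarnConfiguration) constructor that the class declares for tests.

    import com.alibaba.flink.shuffle.yarn.entry.manager.AppClient;
    import org.apache.hadoop.yarn.conf.YarnConfiguration;

    public class AppClientDemo {
        public static void main(String[] args) throws Exception {
            // "-D key=value" pairs are parsed by DeployOnYarnUtils.parseParameters; the home
            // dir must contain the shuffle jars and conf directory (the path is hypothetical).
            String[] submitArgs = {"-D", "remote-shuffle.yarn.manager-home-dir=/opt/remote-shuffle"};
            AppClient client = new AppClient(submitArgs, new YarnConfiguration());
            // Hands the application over to YARN and returns true on success.
            System.out.println("submitted: " + client.submitApplication());
        }
    }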

This client is meant to submit an application to start a {@link YarnShuffleManagerRunner} as + * an Application Master. + */ +public class AppClient { + private static final Logger LOG = LoggerFactory.getLogger(AppClient.class); + + private final Configuration hadoopConf; + + private final YarnClient yarnClient; + + /** Main class to start the application master. */ + private final String appMasterMainClass; + + private final AppClientEnvs appEnvs; + + private ApplicationId appId; + + /** + * Use {@link YarnShuffleManagerRunner} as the application master main class. {@link + * ShuffleManager} will be started in the container. + */ + AppClient(String[] args) throws IOException, ParseException { + this(args, new YarnConfiguration()); + } + + /** Declared public for testing. */ + public AppClient(String[] args, YarnConfiguration hadoopConf) + throws IOException, ParseException { + EnvironmentInformation.logEnvironmentInfo(LOG, "Yarn Shuffle Manager", args); + this.appMasterMainClass = YarnShuffleManagerRunner.class.getCanonicalName(); + this.hadoopConf = hadoopConf; + this.appEnvs = new AppClientEnvs(hadoopConf, args); + this.hadoopConf.set( + YarnConstants.MANAGER_AM_MAX_ATTEMPTS_KEY, + String.valueOf(appEnvs.getAmMaxAttempts())); + this.yarnClient = YarnClient.createYarnClient(); + yarnClient.init(hadoopConf); + } + + /** Main method for submission. */ + boolean run() { + boolean success; + try { + success = submitApplication(); + } catch (Exception e) { + LOG.error("Submitting the application failed, ", e); + success = false; + } + return success; + } + + /** Declared public for testing. */ + public boolean submitApplication() throws IOException, YarnException, URISyntaxException { + yarnClient.start(); + + // Get a new application id + final YarnClientApplication app = yarnClient.createApplication(); + + // Set the application name + final ApplicationSubmissionContext appContext = app.getApplicationSubmissionContext(); + appId = appContext.getApplicationId(); + + // Only one application master container is running in the application and there are no + // other containers in it, so the keep-container flag is set as false + appContext.setKeepContainersAcrossApplicationAttempts(false); + appContext.setApplicationName(appEnvs.getAppName()); + + // Set local resources for the application master + final Map<String, LocalResource> localResources = new HashMap<>(); + + final FileSystem fs = FileSystem.get(hadoopConf); + + prepareAppLocalResources(appId, localResources, fs); + + // Set up the container launch context for the application master + final ContainerLaunchContext amContainer = + ContainerLaunchContext.newInstance( + localResources, generateEnvs(), generateCommands(), null, null, null); + + // Set up the resource capability according to the memory size and vcore count + final Resource capability = + Resource.newInstance( + appEnvs.getAmMemory() + appEnvs.getMemoryOverhead(), appEnvs.getAmVCores()); + appContext.setResource(capability); + + // Set up security tokens + if (UserGroupInformation.isSecurityEnabled()) { + addTokensToAmContainer(fs, amContainer); + } + + appContext.setAMContainerSpec(amContainer); + + final Priority pri = Priority.newInstance(appEnvs.getAmPriority()); + appContext.setPriority(pri); + + // Set the application queue + appContext.setQueue(appEnvs.getAmQueue()); + + LOG.info("Submitting application, the app id is " + appId.toString()); + yarnClient.submitApplication(appContext); + return true; + } + + private void addTokensToAmContainer(FileSystem fs, ContainerLaunchContext 
amContainer) + throws IOException { + // Note: Credentials class is marked as LimitedPrivate for HDFS and MapReduce + final Credentials credentials = new Credentials(); + final String tokenRenewer = hadoopConf.get(YarnConfiguration.RM_PRINCIPAL); + if (tokenRenewer == null || tokenRenewer.length() == 0) { + throw new IOException( + "Can't get Master Kerberos principal for the RM to use as renewer"); + } + + // For now, only getting tokens for the default file-system. + final Token<?>[] tokens = fs.addDelegationTokens(tokenRenewer, credentials); + if (tokens != null) { + for (Token<?> token : tokens) { + LOG.info("Got delegation token for " + fs.getUri() + ", token: " + token); + } + } + final DataOutputBuffer dob = new DataOutputBuffer(); + credentials.writeTokenStorageToStream(dob); + final ByteBuffer fsTokens = ByteBuffer.wrap(dob.getData(), 0, dob.getLength()); + amContainer.setTokens(fsTokens); + } + + private void prepareAppLocalResources( + ApplicationId appId, Map<String, LocalResource> localResources, FileSystem fs) + throws IOException, URISyntaxException { + appEnvs.prepareDirAndFilesAmNeeded(appId); + // Copy jars to the filesystem + // Create a local resource to point to the destination jar path + DeployOnYarnUtils.addFrameworkToDistributedCache( + appEnvs.getShuffleHomeDirInHdfs(), + localResources, + LocalResourceType.ARCHIVE, + YarnConstants.MANAGER_AM_REMOTE_SHUFFLE_PATH_NAME, + hadoopConf); + + // Set the log4j properties if needed + DeployOnYarnUtils.addFrameworkToDistributedCache( + appEnvs.getLog4jPropertyFile(), + localResources, + LocalResourceType.FILE, + YarnConstants.MANAGER_AM_LOG4J_FILE_NAME, + hadoopConf); + } + + /** Set the env variables to be set up in the environment where the application master runs. */ + private Map<String, String> generateEnvs() { + final Map<String, String> env = new HashMap<>(); + final String classPaths = DeployOnYarnUtils.buildClassPathEnv(hadoopConf); + env.put(YarnConstants.MANAGER_APP_ENV_CLASS_PATH_KEY, classPaths); + LOG.info("Set the classpath for the application master, classpath: " + classPaths); + return env; + } + + private List<String> generateCommands() { + // Set the necessary commands to execute the application master + Vector<CharSequence> vargs = new Vector<>(30); + vargs.add("echo classpath: $CLASSPATH;echo path: $PATH;"); + // Set java executable commands + vargs.add(Environment.JAVA_HOME.$$() + "/bin/java"); + // Set Xmx and Xms + vargs.add("-Xmx" + appEnvs.getAmMemory() + "m"); + vargs.add("-Xms" + appEnvs.getAmMemory() + "m"); + // Set log4j properties + vargs.add("-Dlog4j.configurationFile=" + YarnConstants.MANAGER_AM_LOG4J_FILE_NAME); + // Set other jvm options if needed + if (appEnvs.getAmJvmOptions() != null && !appEnvs.getAmJvmOptions().isEmpty()) { + vargs.add(appEnvs.getAmJvmOptions()); + } + vargs.add( + " -cp $CLASSPATH:'./" + YarnConstants.MANAGER_AM_REMOTE_SHUFFLE_PATH_NAME + "/*'"); + // Set class name + vargs.add(appMasterMainClass); + vargs.add( + "-D" + + YarnConstants.MANAGER_AM_APPID_TIMESTAMP_KEY + + "=" + + appId.getClusterTimestamp()); + vargs.add("-D" + YarnConstants.MANAGER_AM_APPID_ID_KEY + "=" + appId.getId()); + // Set shuffle manager other options + vargs.add(appEnvs.getShuffleManagerArgs()); + + vargs.add("1>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/AppMaster.stdout"); + vargs.add("2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/AppMaster.stderr"); + + // Get final commands + final StringBuilder command = new StringBuilder(); + for (CharSequence str : vargs) { + command.append(str).append(" "); + } + + LOG.info("Start Shuffle Manager by command: " + 
command.toString()); + final List commands = new ArrayList<>(); + commands.add(command.toString()); + return commands; + } +} diff --git a/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/entry/manager/AppClientEnvs.java b/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/entry/manager/AppClientEnvs.java new file mode 100644 index 00000000..43fed9e3 --- /dev/null +++ b/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/entry/manager/AppClientEnvs.java @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.yarn.entry.manager; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.yarn.utils.DeployOnYarnUtils; +import com.alibaba.flink.shuffle.yarn.utils.YarnConstants; + +import org.apache.commons.cli.ParseException; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.yarn.api.records.ApplicationId; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkNotNull; +import static com.alibaba.flink.shuffle.core.utils.ConfigurationParserUtils.DEFAULT_SHUFFLE_CONF_DIR; +import static com.alibaba.flink.shuffle.yarn.utils.YarnConstants.MANAGER_AM_MEMORY_JVM_OPTIONS_DEFAULT; +import static com.alibaba.flink.shuffle.yarn.utils.YarnConstants.MANAGER_AM_MEMORY_JVM_OPTIONS_KEY; + +/** + * Arg parser used by {@link AppClient}. The class will parse all arguments in {@link YarnConstants} + */ +public class AppClientEnvs { + private static final Logger LOG = LoggerFactory.getLogger(AppClientEnvs.class); + + private final org.apache.hadoop.conf.Configuration hadoopConf; + + private final FileSystem fs; + + /** Local home dir contains all jars, confs etc. */ + private final String localShuffleHomeDir; + + /** Application name to start Shuffle Manager. */ + private final String appName; + + /** Application master priority. */ + private final int amPriority; + + /** Queue for Application master. */ + private final String amQueue; + + /** Application master max attempts count when encountering a exception. */ + private final int amMaxAttempts; + + /** Virtual core count to start Application master. */ + private final int amVCores; + + /** Other JVM options of application master. */ + private final String amJvmOptions; + + /** Memory resource to start Application master. */ + private int amMemory; + + /** Overhead memory to start Application master. */ + private int memoryOverhead; + + /** + * Shuffle manager args. Specified by pattern of -D remote-shuffle.yarn.manager-start-opts="-D + * a.b.c1=v1 -D a.b.c2=v2". 
+ */ + private final StringBuilder shuffleManagerArgs; + + /** + * When the input arguments are illegal, the map will store the overridden keys and values. + */ + private Map<String, Integer> overrideOptions; + + /** + * Application directory in HDFS which contains all jars and configurations. The directory path + * is "fs.getHomeDirectory()/appName/appId/MANAGER_HOME_DIR/AM_HDFS_TMP_DIR". + */ + private String shuffleHomeDirInHdfs; + + /** + * Log4j properties file. The file should be contained in the specified MANAGER_HOME_DIR + * directory. + */ + private String log4jPropertyFile; + + /** + * An HDFS jar for starting the App Master. The jar file should be contained in the specified + * MANAGER_HOME_DIR directory. + */ + private String amJarFilePath; + + public AppClientEnvs(final org.apache.hadoop.conf.Configuration hadoopConf, final String[] args) + throws IOException, ParseException { + Configuration conf = DeployOnYarnUtils.parseParameters(args); + this.localShuffleHomeDir = checkNotNull(conf.getString(YarnConstants.MANAGER_HOME_DIR)); + loadOptionsInConfigurationFile(conf, localShuffleHomeDir); + + this.hadoopConf = hadoopConf; + this.fs = FileSystem.get(hadoopConf); + this.amVCores = YarnConstants.MANAGER_AM_VCORE_COUNT; + this.appName = + conf.getString( + YarnConstants.MANAGER_APP_NAME_KEY, YarnConstants.MANAGER_APP_NAME_DEFAULT); + this.amPriority = + conf.getInteger( + YarnConstants.MANAGER_APP_PRIORITY_KEY, + YarnConstants.MANAGER_APP_PRIORITY_DEFAULT); + this.amQueue = + conf.getString( + YarnConstants.MANAGER_APP_QUEUE_NAME_KEY, + YarnConstants.MANAGER_APP_QUEUE_NAME_DEFAULT); + this.amMemory = + conf.getInteger( + YarnConstants.MANAGER_AM_MEMORY_SIZE_KEY, + YarnConstants.MANAGER_AM_MEMORY_SIZE_DEFAULT); + this.memoryOverhead = + conf.getInteger( + YarnConstants.MANAGER_AM_MEMORY_OVERHEAD_SIZE_KEY, + YarnConstants.MANAGER_AM_MEMORY_OVERHEAD_SIZE_DEFAULT); + this.amMaxAttempts = + conf.getInteger( + YarnConstants.MANAGER_AM_MAX_ATTEMPTS_VAL_KEY, + YarnConstants.MANAGER_AM_MAX_ATTEMPTS_VAL_DEFAULT); + this.amJvmOptions = + conf.getString( + MANAGER_AM_MEMORY_JVM_OPTIONS_KEY, MANAGER_AM_MEMORY_JVM_OPTIONS_DEFAULT); + this.shuffleManagerArgs = new StringBuilder(); + this.overrideOptions = new HashMap<>(); + + // Check the input arguments, if any argument is wrong, modify it or throw an exception. + checkArguments(); + generateShuffleManagerArgString(conf); + } + + private static void loadOptionsInConfigurationFile(Configuration conf, String localHomeDir) + throws IOException { + String confFile = localHomeDir + "/" + DEFAULT_SHUFFLE_CONF_DIR; + Configuration confFromFile = new Configuration(confFile); + conf.addAll(confFromFile); + LOG.info("Loaded " + confFromFile.toMap().size() + " options from " + confFile); + } + + /** + * Refactor the Yarn home directory containing jars and other resources. Move all libs and + * configurations to a new temporary directory. After reorganizing the input directory, we + * locate the main jar to execute, the log4j properties file, etc. 
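To make the upload-and-flatten sequence described above concrete, here is a sketch of the underlying helper calls, all defined in DeployOnYarnUtils later in this patch; fs, appId and hadoopConf stand for an initialized FileSystem, ApplicationId and Hadoop Configuration, and the local path is hypothetical.

    String remoteDir = DeployOnYarnUtils.uploadLocalDirToHDFS(
            fs, "/opt/remote-shuffle", appId.toString(),
            YarnConstants.MANAGER_AM_REMOTE_SHUFFLE_PATH_NAME);
    // Copies every file under lib/, conf/ etc. into the flat <remoteDir>/tmpDir directory.
    String amHomeDir = DeployOnYarnUtils.refactorDirectoryHierarchy(fs, remoteDir, hadoopConf);
    String amJar = DeployOnYarnUtils.findApplicationMasterJar(fs, amHomeDir); // shuffle-dist-*.jar
    String log4jFile = DeployOnYarnUtils.findLog4jPropertyFile(fs, amHomeDir); // log4j2.properties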
+ */ + public void prepareDirAndFilesAmNeeded(ApplicationId appId) throws IOException { + String remoteShuffleDir = + DeployOnYarnUtils.uploadLocalDirToHDFS( + fs, + localShuffleHomeDir, + appId.toString(), + YarnConstants.MANAGER_AM_REMOTE_SHUFFLE_PATH_NAME); + shuffleHomeDirInHdfs = + DeployOnYarnUtils.refactorDirectoryHierarchy(fs, remoteShuffleDir, hadoopConf); + amJarFilePath = DeployOnYarnUtils.findApplicationMasterJar(fs, shuffleHomeDirInHdfs); + log4jPropertyFile = DeployOnYarnUtils.findLog4jPropertyFile(fs, shuffleHomeDirInHdfs); + LOG.info( + "Remote shuffle home directory in HDFS is " + + shuffleHomeDirInHdfs + + ", AM jar is " + + amJarFilePath + + ", log4j properties file is " + + log4jPropertyFile); + } + + private void checkArguments() { + if (amVCores < 1) { + throw new IllegalArgumentException( + "Invalid virtual cores specified for application master, exiting."); + } + + if (localShuffleHomeDir == null) { + throw new IllegalArgumentException( + "Shuffle local home directory is not specified. The specified directory path " + + "should contains all jars and files, e.g. lib, conf, log. Specify it by " + + YarnConstants.MANAGER_HOME_DIR); + } + + if (amMemory < YarnConstants.MIN_VALID_AM_MEMORY_SIZE_MB) { + LOG.warn( + "The input AM memory size is too small, the minimum size " + + YarnConstants.MIN_VALID_AM_MEMORY_SIZE_MB + + " mb will be used."); + amMemory = YarnConstants.MIN_VALID_AM_MEMORY_SIZE_MB; + overrideOptions.put(YarnConstants.MANAGER_AM_MEMORY_SIZE_KEY, amMemory); + } + + if (memoryOverhead < YarnConstants.MIN_VALID_AM_MEMORY_SIZE_MB) { + LOG.warn( + "The input overhead memory size is too small, the minimum size " + + YarnConstants.MIN_VALID_AM_MEMORY_SIZE_MB + + " mb will be used."); + memoryOverhead = YarnConstants.MIN_VALID_AM_MEMORY_SIZE_MB; + overrideOptions.put(YarnConstants.MANAGER_AM_MEMORY_OVERHEAD_SIZE_KEY, memoryOverhead); + } + } + + private void generateShuffleManagerArgString(final Configuration conf) { + Map confMap = conf.toMap(); + for (String optionKey : confMap.keySet()) { + shuffleManagerArgs.append(" -D ").append(optionKey).append("="); + if (overrideOptions.containsKey(optionKey)) { + shuffleManagerArgs.append(overrideOptions.get(optionKey)); + } else { + shuffleManagerArgs.append(confMap.get(optionKey)); + } + } + } + + String getAppName() { + return appName; + } + + int getAmPriority() { + return amPriority; + } + + String getAmQueue() { + return amQueue; + } + + int getAmMemory() { + return amMemory; + } + + int getMemoryOverhead() { + return memoryOverhead; + } + + int getAmMaxAttempts() { + return amMaxAttempts; + } + + int getAmVCores() { + return amVCores; + } + + String getAmJvmOptions() { + return amJvmOptions; + } + + String getShuffleHomeDirInHdfs() { + return shuffleHomeDirInHdfs; + } + + String getShuffleManagerArgs() { + return shuffleManagerArgs.toString(); + } + + String getLog4jPropertyFile() { + return log4jPropertyFile; + } +} diff --git a/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/entry/manager/NMCallbackHandler.java b/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/entry/manager/NMCallbackHandler.java new file mode 100644 index 00000000..258f42d8 --- /dev/null +++ b/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/entry/manager/NMCallbackHandler.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.yarn.entry.manager; + +import org.apache.hadoop.yarn.api.records.ContainerId; +import org.apache.hadoop.yarn.api.records.ContainerStatus; +import org.apache.hadoop.yarn.client.api.async.NMClientAsync; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.util.Map; + +/** + * This is the Yarn NM callback handler class. By implementing this interface, we can control + * which logic the job executes in each container state. + * + *

This class does not implement special operations according to different states, but we still + * need to implement this interface {@link NMClientAsync.CallbackHandler} to transfer the job state + * into the STARTED state. + */ +public class NMCallbackHandler implements NMClientAsync.CallbackHandler { + private static final Logger LOG = LoggerFactory.getLogger(NMCallbackHandler.class); + + @Override + public void onContainerStarted( + ContainerId containerId, Map<String, ByteBuffer> allServiceResponse) { + LOG.info("The " + containerId + " is started"); + } + + @Override + public void onContainerStatusReceived( + ContainerId containerId, ContainerStatus containerStatus) { + LOG.info("Received " + containerId + " status: " + containerStatus); + } + + @Override + public void onContainerStopped(ContainerId containerId) { + LOG.info("The " + containerId + " is stopped"); + } + + @Override + public void onStartContainerError(ContainerId containerId, Throwable t) { + LOG.error("Start " + containerId + " failed, ", t); + } + + @Override + public void onGetContainerStatusError(ContainerId containerId, Throwable t) { + LOG.error("Get " + containerId + " status failed, ", t); + } + + @Override + public void onStopContainerError(ContainerId containerId, Throwable t) { + LOG.error("Stop " + containerId + " failed, ", t); + } +} diff --git a/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/entry/manager/RMCallbackHandler.java b/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/entry/manager/RMCallbackHandler.java new file mode 100644 index 00000000..037bb188 --- /dev/null +++ b/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/entry/manager/RMCallbackHandler.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.yarn.entry.manager; + +import org.apache.hadoop.util.StringUtils; +import org.apache.hadoop.yarn.api.records.Container; +import org.apache.hadoop.yarn.api.records.ContainerStatus; +import org.apache.hadoop.yarn.api.records.NodeReport; +import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; + +/** + * This is the Yarn RM callback handler class. By implementing this interface, we can control + * which logic the job executes in each state. + * + *

This class does not implement special operations according to different states, but we still + * need to implement this interface {@link AMRMClientAsync.CallbackHandler} to register the + * container with the Yarn resource manager. After registration, it continuously sends heartbeat + * signals to the resource manager, indicating that the current application is in the RUNNING + * state. + */ +public class RMCallbackHandler implements AMRMClientAsync.CallbackHandler { + private static final Logger LOG = LoggerFactory.getLogger(RMCallbackHandler.class); + + @Override + public void onContainersCompleted(List<ContainerStatus> statuses) { + LOG.info( + "All containers are completed, container status: " + + StringUtils.join(", ", statuses)); + } + + @Override + public void onContainersAllocated(List<Container> containers) { + LOG.info("All containers are allocated, num: " + containers.size()); + } + + @Override + public void onShutdownRequest() { + LOG.info("Received shutdown request"); + } + + @Override + public void onNodesUpdated(List<NodeReport> updatedNodes) { + LOG.info( + "Nodes are updated and the node reports are " + + StringUtils.join(",", updatedNodes)); + } + + @Override + public float getProgress() { + return 0; + } + + @Override + public void onError(Throwable e) { + LOG.error("Encountered an error, ", e); + } +} diff --git a/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/entry/manager/YarnShuffleManagerEntrypoint.java b/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/entry/manager/YarnShuffleManagerEntrypoint.java new file mode 100644 index 00000000..fadffa00 --- /dev/null +++ b/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/entry/manager/YarnShuffleManagerEntrypoint.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.yarn.entry.manager; + +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleManagerRunner; +import com.alibaba.flink.shuffle.yarn.utils.YarnConstants; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Yarn deployment entry point of {@link ShuffleManagerRunner}. Reference doc: + * https://hadoop.apache.org/docs/r2.7.1/hadoop-yarn/hadoop-yarn-site/WritingYarnApplications.html + * + *

Starting Shuffle Manager by submitting a Yarn application, and it will run in the application + * master container. Manager related configurations in {@link YarnConstants} should be specified + * when submitting the application. + */ +public class YarnShuffleManagerEntrypoint { + private static final Logger LOG = LoggerFactory.getLogger(YarnShuffleManagerEntrypoint.class); + + private static boolean runAppClient(String[] args) throws Exception { + AppClient client = new AppClient(args); + return client.run(); + } + + public static void main(String[] args) { + boolean success = false; + try { + success = runAppClient(args); + } catch (Throwable t) { + LOG.error("Starting Shuffle Manager application encountered an error, ", t); + return; + } + if (success) { + LOG.info("Start Shuffle Manager application successfully"); + return; + } + LOG.error("Start Shuffle Manager on Yarn failed"); + } +} diff --git a/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/entry/manager/YarnShuffleManagerRunner.java b/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/entry/manager/YarnShuffleManagerRunner.java new file mode 100644 index 00000000..578e3556 --- /dev/null +++ b/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/entry/manager/YarnShuffleManagerRunner.java @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.yarn.entry.manager; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.common.utils.JvmShutdownSafeguard; +import com.alibaba.flink.shuffle.common.utils.SignalHandler; +import com.alibaba.flink.shuffle.coordinator.manager.ShuffleManager; +import com.alibaba.flink.shuffle.coordinator.manager.entrypoint.ShuffleManagerEntrypoint; +import com.alibaba.flink.shuffle.coordinator.utils.EnvironmentInformation; +import com.alibaba.flink.shuffle.core.config.ManagerOptions; +import com.alibaba.flink.shuffle.yarn.utils.DeployOnYarnUtils; +import com.alibaba.flink.shuffle.yarn.utils.YarnConstants; + +import org.apache.commons.cli.ParseException; +import org.apache.hadoop.net.NetUtils; +import org.apache.hadoop.yarn.api.records.ApplicationAttemptReport; +import org.apache.hadoop.yarn.api.records.ApplicationId; +import org.apache.hadoop.yarn.api.records.ContainerReport; +import org.apache.hadoop.yarn.client.api.YarnClient; +import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync; +import org.apache.hadoop.yarn.client.api.async.NMClientAsync; +import org.apache.hadoop.yarn.client.api.async.impl.NMClientAsyncImpl; +import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.exceptions.YarnException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.util.List; + +import static com.alibaba.flink.shuffle.common.config.Configuration.REMOTE_SHUFFLE_CONF_FILENAME; +import static com.alibaba.flink.shuffle.yarn.utils.YarnConstants.MANAGER_AM_REMOTE_SHUFFLE_PATH_NAME; + +/** This class is meant to start {@link ShuffleManager} based on a yarn-based application master. */ +public class YarnShuffleManagerRunner { + private static final Logger LOG = LoggerFactory.getLogger(YarnShuffleManagerRunner.class); + + private final String[] args; + + private final Configuration conf; + + private String amHost; + + private static final int AM_RPC_PORT_DEFAULT = -1; + + public YarnShuffleManagerRunner(String[] args) throws ParseException { + this.args = args; + this.conf = mergeOptionsWithConfigurationFile(args); + } + + private Configuration mergeOptionsWithConfigurationFile(String[] args) throws ParseException { + Configuration optionsInArgs = DeployOnYarnUtils.parseParameters(args); + File confFile = new File(MANAGER_AM_REMOTE_SHUFFLE_PATH_NAME, REMOTE_SHUFFLE_CONF_FILENAME); + String confFilePath = confFile.getAbsolutePath(); + if (!confFile.exists()) { + LOG.info("Configuration file " + confFilePath + " is not exist"); + return optionsInArgs; + } + + Configuration mergedOptions = new Configuration(); + Configuration optionsInFile; + try { + optionsInFile = new Configuration(confFile.getParent()); + } catch (IOException ioe) { + LOG.error("Failed to load options in the configuration file " + confFilePath, ioe); + return optionsInArgs; + } + LOG.info( + "Loaded all options in the configuration file " + + confFilePath + + ", options count: " + + optionsInFile.toMap().size()); + // The options in the input parameters will override the options in the configuration file + mergedOptions.addAll(optionsInFile); + mergedOptions.addAll(optionsInArgs); + return mergedOptions; + } + + /** Main run function for the application master. 
*/ + public void run() throws Exception { + setShuffleManagerAddress(); + startShuffleManagerInternal(); + } + + private void setShuffleManagerAddress() throws IOException, YarnException { + amHost = getAMContainerHost(); + CommonUtils.checkArgument( + amHost != null && !amHost.isEmpty(), + "The Shuffle Manager address must not be empty."); + conf.setString(ManagerOptions.RPC_ADDRESS, amHost); + LOG.info( + "Set Shuffle Manager address " + + ManagerOptions.RPC_ADDRESS.key() + + " as " + + amHost); + } + + private String getAMContainerHost() throws IOException, YarnException { + long timestamp = conf.getLong(YarnConstants.MANAGER_AM_APPID_TIMESTAMP_KEY, -1L); + int id = conf.getInteger(YarnConstants.MANAGER_AM_APPID_ID_KEY, -1); + if (timestamp < 0 || id < 0) { + throw new IOException( + "Unable to resolve appid because id or timestamp is empty, id: " + + id + + " timestamp: " + + timestamp); + } + ApplicationId applicationId = ApplicationId.newInstance(timestamp, id); + + YarnClient yarnClient = YarnClient.createYarnClient(); + yarnClient.init(new YarnConfiguration()); + yarnClient.start(); + + // Get container report from Yarn + List attemptReports = + yarnClient.getApplicationAttempts(applicationId); + int reportIndex = findMaxAttemptIndex(attemptReports); + ContainerReport containerReport = + yarnClient.getContainerReport(attemptReports.get(reportIndex).getAMContainerId()); + + // Dump AM container report + LOG.info( + "Dump AM container report. id: " + + containerReport.getContainerId() + + " state: " + + containerReport.getContainerState() + + " assignedNode: " + + containerReport.getAssignedNode().getHost() + + " url: " + + containerReport.getLogUrl()); + return containerReport.getAssignedNode().getHost(); + } + + private int findMaxAttemptIndex(List attemptReports) { + int maxAttempt = 0; + for (ApplicationAttemptReport attemptReport : attemptReports) { + int currentAttemptId = attemptReport.getApplicationAttemptId().getAttemptId(); + maxAttempt = Math.max(currentAttemptId, maxAttempt); + } + for (int idx = 0; idx < attemptReports.size(); idx++) { + if (maxAttempt == attemptReports.get(idx).getApplicationAttemptId().getAttemptId()) { + return idx; + } + } + return attemptReports.size() - 1; + } + + private void startShuffleManagerInternal() throws Exception { + EnvironmentInformation.logEnvironmentInfo(LOG, "Shuffle Manager on Yarn", args); + SignalHandler.register(LOG); + JvmShutdownSafeguard.installAsShutdownHook(LOG); + + ShuffleManagerEntrypoint shuffleManagerEntrypoint = new ShuffleManagerEntrypoint(conf); + ShuffleManagerEntrypoint.runShuffleManagerEntrypoint(shuffleManagerEntrypoint); + LOG.info("Shuffle Manager on yarn starts successfully"); + + AMRMClientAsync.CallbackHandler allocListener = new RMCallbackHandler(); + AMRMClientAsync amRMClient = + AMRMClientAsync.createAMRMClientAsync( + conf.getInteger( + YarnConstants.MANAGER_RM_HEARTBEAT_INTERVAL_MS_KEY, + YarnConstants.MANAGER_RM_HEARTBEAT_INTERVAL_MS_DEFAULT), + allocListener); + amRMClient.init(new YarnConfiguration()); + amRMClient.start(); + + NMCallbackHandler containerListener = new NMCallbackHandler(); + NMClientAsync nmClientAsync = new NMClientAsyncImpl(containerListener); + nmClientAsync.init(new YarnConfiguration()); + nmClientAsync.start(); + + amRMClient.registerApplicationMaster(NetUtils.getHostname(), AM_RPC_PORT_DEFAULT, null); + } + + public static void main(String[] args) { + try { + YarnShuffleManagerRunner yarnShuffleManagerRunner = new YarnShuffleManagerRunner(args); + 
yarnShuffleManagerRunner.run(); + } catch (Throwable t) { + LOG.error("Encountered an error when starting Shuffle Manager, ", t); + } + } +} diff --git a/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/entry/worker/YarnShuffleWorkerEntrypoint.java b/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/entry/worker/YarnShuffleWorkerEntrypoint.java new file mode 100644 index 00000000..51668f34 --- /dev/null +++ b/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/entry/worker/YarnShuffleWorkerEntrypoint.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.yarn.entry.worker; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.common.utils.CommonUtils; +import com.alibaba.flink.shuffle.common.utils.FatalErrorExitUtils; +import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerRunner; +import com.alibaba.flink.shuffle.core.config.HighAvailabilityOptions; +import com.alibaba.flink.shuffle.core.config.MemoryOptions; +import com.alibaba.flink.shuffle.core.config.StorageOptions; +import com.alibaba.flink.shuffle.yarn.utils.DeployOnYarnUtils; +import com.alibaba.flink.shuffle.yarn.utils.YarnConstants; +import com.alibaba.flink.shuffle.yarn.utils.YarnOptions; + +import org.apache.hadoop.yarn.server.api.ApplicationInitializationContext; +import org.apache.hadoop.yarn.server.api.ApplicationTerminationContext; +import org.apache.hadoop.yarn.server.api.AuxiliaryService; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; + +/** + * Yarn deployment entry point of the shuffle worker runner. It should be deployed along with the + * Yarn NodeManager as an auxiliary service in the NodeManager. + * + *

Before starting it as an auxiliary service on Yarn, some specific configurations should be added + * to yarn-site.xml (a sketch follows this list): 1. Add yarn_remote_shuffle_worker_for_flink to + * yarn.nodemanager.aux-services. 2. Set + * yarn.nodemanager.aux-services.yarn_remote_shuffle_worker_for_flink.class as {@link + * YarnShuffleWorkerEntrypoint}. 3. Add HA mode options in {@link HighAvailabilityOptions}. 4. Add + * memory options in {@link MemoryOptions} and storage options in {@link StorageOptions}. + * + *
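For illustration, the first two settings expressed through Hadoop's Configuration API; in a real deployment they belong in yarn-site.xml, and the mapreduce_shuffle entry merely stands in for whatever auxiliary services the NodeManager already runs.

    org.apache.hadoop.conf.Configuration yarnSite = new org.apache.hadoop.conf.Configuration(false);
    yarnSite.set("yarn.nodemanager.aux-services",
            "mapreduce_shuffle,yarn_remote_shuffle_worker_for_flink");
    yarnSite.set("yarn.nodemanager.aux-services.yarn_remote_shuffle_worker_for_flink.class",
            "com.alibaba.flink.shuffle.yarn.entry.worker.YarnShuffleWorkerEntrypoint");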

When starting Node Manager, the Hadoop classpath should contain all compiled remote shuffle + * jars. Otherwise, {@link ClassNotFoundException} or other exceptions may be thrown out. + */ +public class YarnShuffleWorkerEntrypoint extends AuxiliaryService { + private static final Logger LOG = LoggerFactory.getLogger(YarnShuffleWorkerEntrypoint.class); + + private static volatile ShuffleWorkerRunner shuffleWorkerRunner; + + public YarnShuffleWorkerEntrypoint() { + super(YarnConstants.WORKER_AUXILIARY_SERVICE_NAME); + } + + /** Starts the shuffle worker with the given configuration. */ + @Override + protected void serviceInit(org.apache.hadoop.conf.Configuration hadoopConf) throws Exception { + LOG.info("Initializing Shuffle Worker on Yarn for Flink"); + final boolean stopOnFailure; + final Configuration configuration; + try { + configuration = Configuration.fromMap(DeployOnYarnUtils.hadoopConfToMaps(hadoopConf)); + stopOnFailure = configuration.getBoolean(YarnOptions.WORKER_STOP_ON_FAILURE); + } catch (Throwable t) { + LOG.error( + "Get configuration for " + + YarnOptions.WORKER_STOP_ON_FAILURE.key() + + " failed, ", + t); + return; + } + + FatalErrorExitUtils.setNeedStopProcess(stopOnFailure); + + try { + shuffleWorkerRunner = ShuffleWorkerRunner.runShuffleWorker(configuration); + } catch (Exception e) { + LOG.error("Failed to start Shuffle Worker on Yarn for Flink, ", e); + if (stopOnFailure) { + throw e; + } else { + noteFailure(e); + } + } catch (Throwable t) { + LOG.error( + "Failed to start Shuffle Worker on Yarn for Flink with the throwable error, ", + t); + } + } + + /** Currently this method is of no use. */ + @Override + public void initializeApplication(ApplicationInitializationContext initAppContext) {} + + /** Currently this method is of no use. */ + @Override + public void stopApplication(ApplicationTerminationContext stopAppContext) {} + + /** Close the shuffle worker. */ + @Override + protected void serviceStop() { + if (shuffleWorkerRunner == null) { + return; + } + + try { + shuffleWorkerRunner.close(); + } catch (Exception e) { + LOG.error("Close shuffle worker failed with error, ", e); + return; + } + LOG.info("Stop shuffle worker normally."); + } + + /** Currently this method is of no use. */ + @Override + public ByteBuffer getMetaData() { + return CommonUtils.allocateHeapByteBuffer(0); + } + + @Override + public String getName() { + return YarnConstants.WORKER_AUXILIARY_SERVICE_NAME; + } + + /** Only for tests. */ + public static ShuffleWorkerRunner getShuffleWorkerRunner() { + return shuffleWorkerRunner; + } +} diff --git a/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/utils/DeployOnYarnUtils.java b/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/utils/DeployOnYarnUtils.java new file mode 100644 index 00000000..e8e49af5 --- /dev/null +++ b/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/utils/DeployOnYarnUtils.java @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.yarn.utils; + +import com.alibaba.flink.shuffle.common.config.Configuration; + +import org.apache.commons.cli.DefaultParser; +import org.apache.commons.cli.Option; +import org.apache.commons.cli.Options; +import org.apache.commons.cli.ParseException; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.FileUtil; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.yarn.api.ApplicationConstants; +import org.apache.hadoop.yarn.api.records.LocalResource; +import org.apache.hadoop.yarn.api.records.LocalResourceType; +import org.apache.hadoop.yarn.api.records.LocalResourceVisibility; +import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.util.ConverterUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.HashMap; +import java.util.Map; + +/** Utility class that provides helper methods to work with Apache Hadoop YARN. */ +public class DeployOnYarnUtils { + private static final Logger LOG = LoggerFactory.getLogger(DeployOnYarnUtils.class); + + // --------------------------------------------------------------- + // Configuration utils + // --------------------------------------------------------------- + + private static final Option PARSE_PROPERTY_OPTION = + Option.builder("D") + .argName("property=value") + .numberOfArgs(2) + .valueSeparator('=') + .desc("use value for given property") + .build(); + + /** Parse hadoop configurations into maps. */ + public static Map hadoopConfToMaps( + final org.apache.hadoop.conf.Configuration conf) { + Map confMaps = new HashMap<>(); + for (Map.Entry entry : conf) { + confMaps.put(entry.getKey(), entry.getValue()); + } + return confMaps; + } + + /** + * Parsing the input arguments into {@link Configuration}. The method may throw a {@link + * ParseException} when inputting some wrong arguments. + * + * @param args The String array that shall be parsed. + * @return The {@link Configuration} configurations. + */ + public static Configuration parseParameters(String[] args) throws ParseException { + final DefaultParser defaultParser = new DefaultParser(); + final Options options = new Options(); + options.addOption(PARSE_PROPERTY_OPTION); + return new Configuration( + defaultParser + .parse(options, args, true) + .getOptionProperties(PARSE_PROPERTY_OPTION.getOpt())); + } + + // --------------------------------------------------------------- + // Yarn utils + // --------------------------------------------------------------- + + /** Build classpath string according to the input configurations. 
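Roughly, with the default YARN application classpath the built value expands to the following (abbreviated; <CPS> stands for ApplicationConstants.CLASS_PATH_SEPARATOR, and the literal file and directory names come from the YarnConstants defaults):

    // {{CLASSPATH}}<CPS>$PWD/log4j2.properties<CPS>$PWD<CPS>$HADOOP_CLIENT_CONF_DIR<CPS>$HADOOP_CONF_DIR
    //     <CPS>$JAVA_HOME/lib/tools.jar<CPS>$PWD/shuffleManager/<CPS>$PWD/shuffleManager/conf/
    //     <CPS>$PWD/shuffleManager/*<CPS>...entries from yarn.application.classpath...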
*/ + public static String buildClassPathEnv(org.apache.hadoop.conf.Configuration conf) { + StringBuilder classPathEnv = + new StringBuilder(ApplicationConstants.Environment.CLASSPATH.$$()) + .append(ApplicationConstants.CLASS_PATH_SEPARATOR) + .append("$PWD/" + YarnConstants.MANAGER_AM_LOG4J_FILE_NAME) + .append(ApplicationConstants.CLASS_PATH_SEPARATOR) + .append("$PWD") + .append(ApplicationConstants.CLASS_PATH_SEPARATOR) + .append("$HADOOP_CLIENT_CONF_DIR") + .append(ApplicationConstants.CLASS_PATH_SEPARATOR) + .append("$HADOOP_CONF_DIR") + .append(ApplicationConstants.CLASS_PATH_SEPARATOR) + .append("$JAVA_HOME/lib/tools.jar") + .append(ApplicationConstants.CLASS_PATH_SEPARATOR) + .append("$PWD/") + .append(YarnConstants.MANAGER_AM_REMOTE_SHUFFLE_PATH_NAME) + .append("/") + .append(ApplicationConstants.CLASS_PATH_SEPARATOR) + .append("$PWD/") + .append(YarnConstants.MANAGER_AM_REMOTE_SHUFFLE_PATH_NAME) + .append("/conf/") + .append(ApplicationConstants.CLASS_PATH_SEPARATOR) + .append("$PWD/") + .append(YarnConstants.MANAGER_AM_REMOTE_SHUFFLE_PATH_NAME) + .append("/*"); + for (String c : + conf.getStrings( + YarnConfiguration.YARN_APPLICATION_CLASSPATH, + YarnConfiguration.DEFAULT_YARN_CROSS_PLATFORM_APPLICATION_CLASSPATH)) { + classPathEnv.append(ApplicationConstants.CLASS_PATH_SEPARATOR); + classPathEnv.append(c.trim()); + } + return classPathEnv.toString(); + } + + public static void addFrameworkToDistributedCache( + String javaPathInHdfs, + Map localResources, + LocalResourceType resourceType, + String resourceKey, + org.apache.hadoop.conf.Configuration conf) + throws IOException, URISyntaxException { + FileSystem fs = FileSystem.get(conf); + URI uri = getURIFromHdfsPath(javaPathInHdfs, resourceKey, conf); + + FileStatus scFileStatus = fs.getFileStatus(new Path(uri.getPath())); + LocalResource scRsrc = + LocalResource.newInstance( + ConverterUtils.getYarnUrlFromURI(uri), + resourceType, + LocalResourceVisibility.PRIVATE, + scFileStatus.getLen(), + scFileStatus.getModificationTime()); + localResources.put(resourceKey, scRsrc); + } + + public static URI getURIFromHdfsPath( + String inputPath, String resourceKey, org.apache.hadoop.conf.Configuration conf) + throws IOException, URISyntaxException { + URI uri = FileSystem.get(conf).resolvePath(new Path(inputPath)).toUri(); + return new URI(uri.getScheme(), uri.getAuthority(), uri.getPath(), null, resourceKey); + } + + // --------------------------------------------------------------- + // HDFS utils + // --------------------------------------------------------------- + + /** + * Uploading specific path into HDFS target directory. The target directory path will be + * "fs.getHomeDirectory()/AM_REMOTE_SHUFFLE_DIST_DIR_NAME/appId/dstDir". + */ + public static String uploadLocalDirToHDFS( + FileSystem fs, String fileSrcPath, String appId, String fileDstPath) + throws IOException { + Path dst = + new Path( + fs.getHomeDirectory(), + YarnConstants.MANAGER_AM_REMOTE_SHUFFLE_PATH_NAME + + "/" + + appId + + "/" + + fileDstPath); + if (fs.exists(dst)) { + throw new IOException("Upload files failed, because the path " + dst + " is exist."); + } + fs.copyFromLocalFile(new Path(fileSrcPath), dst); + LOG.info("Upload local " + fileSrcPath + " to " + dst); + return dst.toString(); + } + + /** + * Refactoring the target directory. The original directory has multiple sub directories, and + * this method will move all files and jars in the sub directories into a new temporary + * directory. The AM will be startup based on this new temporary directory. 
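A hypothetical before/after of this flattening, with made-up user and application ids; only the tmpDir contents are used to launch the AM:

    // Before: /user/deploy/shuffleManager/application_1/shuffleManager/lib/shuffle-dist-1.0.jar
    //         /user/deploy/shuffleManager/application_1/shuffleManager/conf/log4j2.properties
    // After:  /user/deploy/shuffleManager/application_1/shuffleManager/tmpDir/shuffle-dist-1.0.jar
    //         /user/deploy/shuffleManager/application_1/shuffleManager/tmpDir/log4j2.properties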
+ */ + public static String refactorDirectoryHierarchy( + FileSystem fs, String shuffleHomeDir, org.apache.hadoop.conf.Configuration hadoopConf) + throws IOException { + Path targetDir = new Path(shuffleHomeDir); + Path newTargetDir = new Path(shuffleHomeDir + "/" + YarnConstants.MANAGER_AM_TMP_PATH_NAME); + if (!fs.exists(newTargetDir)) { + fs.mkdirs(newTargetDir); + } + return refactorDirectoryHierarchyInternal(fs, targetDir, newTargetDir, hadoopConf) + .toString(); + } + + private static Path refactorDirectoryHierarchyInternal( + FileSystem fs, + Path srcDir, + Path targetDir, + org.apache.hadoop.conf.Configuration hadoopConf) + throws IOException { + FileStatus[] fileStatuses = fs.listStatus(srcDir); + for (FileStatus fileStatus : fileStatuses) { + if (fs.isDirectory(fileStatus.getPath())) { + refactorDirectoryHierarchyInternal(fs, fileStatus.getPath(), targetDir, hadoopConf); + } else { + Path curPath = fileStatus.getPath(); + Path targetPath = new Path(targetDir, curPath.getName()); + FileUtil.copy(fs, curPath, fs, targetPath, false, hadoopConf); + } + } + return targetDir; + } + + /** + * Find out the AM jar to start Application Master in the input directory. If not found, this + * method will throw a {@link IOException}. + */ + public static String findApplicationMasterJar(FileSystem fs, String targetDir) + throws IOException { + for (FileStatus fileStatus : listFileStatus(fs, targetDir)) { + Path curPath = fileStatus.getPath(); + if (curPath.getName().startsWith(YarnConstants.MANAGER_AM_JAR_FILE_PREFIX) + && !curPath.getName() + .endsWith(YarnConstants.MANAGER_AM_JAR_FILE_EXCLUDE_SUFFIX)) { + return curPath.toString(); + } + } + throw new IOException("Can not find application master jar in the directory " + targetDir); + } + + /** + * Find out the log4j properties file in the input directory. If not found, this method will + * throw a {@link IOException}. + */ + public static String findLog4jPropertyFile(FileSystem fs, String targetDir) throws IOException { + for (FileStatus fileStatus : listFileStatus(fs, targetDir)) { + Path curPath = fileStatus.getPath(); + if (curPath.getName().equals(YarnConstants.MANAGER_AM_LOG4J_FILE_NAME)) { + return curPath.toString(); + } + } + + throw new IOException( + "Can not find " + + YarnConstants.MANAGER_AM_LOG4J_FILE_NAME + + " in the directory " + + targetDir); + } + + private static FileStatus[] listFileStatus(FileSystem fs, String targetDir) throws IOException { + return fs.listStatus(new Path(targetDir)); + } +} diff --git a/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/utils/YarnConstants.java b/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/utils/YarnConstants.java new file mode 100644 index 00000000..97c921cf --- /dev/null +++ b/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/utils/YarnConstants.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.yarn.utils; + +/** Constants for Yarn deployment. */ +public class YarnConstants { + + public static final String MANAGER_APP_ENV_CLASS_PATH_KEY = "CLASSPATH"; + + public static final String MANAGER_AM_REMOTE_SHUFFLE_PATH_NAME = "shuffleManager"; + + public static final String MANAGER_AM_TMP_PATH_NAME = "tmpDir"; + + public static final String MANAGER_AM_JAR_FILE_PREFIX = "shuffle-dist"; + + public static final String MANAGER_AM_JAR_FILE_EXCLUDE_SUFFIX = "tests.jar"; + + public static final String MANAGER_AM_LOG4J_FILE_NAME = "log4j2.properties"; + + public static final String MANAGER_AM_MAX_ATTEMPTS_KEY = "yarn.resourcemanager.am.max-attempts"; + + /** + * This parameter is the Id part of ApplicationId. This is only used to pass arguments from Yarn + * client to application master. + * + *
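The consuming side reassembles the id from these two values; the lines below mirror YarnShuffleManagerRunner#getAMContainerHost in this patch, where conf is the parsed shuffle Configuration:

    long timestamp = conf.getLong(YarnConstants.MANAGER_AM_APPID_TIMESTAMP_KEY, -1L);
    int id = conf.getInteger(YarnConstants.MANAGER_AM_APPID_ID_KEY, -1);
    ApplicationId applicationId = ApplicationId.newInstance(timestamp, id);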

It is combined with the parameter {@link #MANAGER_AM_APPID_TIMESTAMP_KEY} to form a + * complete ApplicationId. The application master resolves its host address from the + * ApplicationId and stores the address in Zookeeper. + */ + public static final String MANAGER_AM_APPID_ID_KEY = "managerAppidId"; + + /** + * This parameter is the Timestamp part of the ApplicationId. It is only used to pass arguments + * from the Yarn client to the application master. + * + *

It is combined with the parameter {@link #MANAGER_AM_APPID_ID_KEY} to form a complete + * ApplicationId. The application master resolves its host address from the ApplicationId and + * stores the address in Zookeeper. + */ + public static final String MANAGER_AM_APPID_TIMESTAMP_KEY = "managerAppidTimestamp"; + + /** + * Because only the Shuffle Manager is running in the application master and there are no other + * containers, only one VCore is required. + */ + public static final int MANAGER_AM_VCORE_COUNT = 1; + + /** + * Defines the name of the shuffle worker service. + * + *

It's used to: 1. Configure the shuffle service for the NodeManager in yarn-site.xml. 2. Name the + * auxiliary service of the shuffle worker inside the NodeManager. + */ + public static final String WORKER_AUXILIARY_SERVICE_NAME = + "yarn_remote_shuffle_worker_for_flink"; + + // ------------------------------------------------------------------------ + // Yarn Config Constants + // ------------------------------------------------------------------------ + /** + * Minimum valid size of memory in megabytes to start the application master. If the specified + * memory is less than this min value, the memory size will be replaced by the min value. + * + *

The values of configurations {@link #MANAGER_AM_MEMORY_SIZE_KEY} and {@link + * #MANAGER_AM_MEMORY_OVERHEAD_SIZE_KEY} can not be less than this value. + */ + public static final int MIN_VALID_AM_MEMORY_SIZE_MB = 128; + + /** + * Local home directory containing all jars, configuration files and other resources, which is + * used to start Shuffle Manager on Yarn. + */ + public static final String MANAGER_HOME_DIR = "remote-shuffle.yarn.manager-home-dir"; + + /** Application name when deploying Shuffle Manager service on Yarn. */ + public static final String MANAGER_APP_NAME_KEY = "remote-shuffle.yarn.manager-app-name"; + + public static final String MANAGER_APP_NAME_DEFAULT = "Flink-Remote-Shuffle-Manager"; + + /** Application priority when deploying Shuffle Manager service on Yarn. */ + public static final String MANAGER_APP_PRIORITY_KEY = + "remote-shuffle.yarn.manager-app-priority"; + + public static final int MANAGER_APP_PRIORITY_DEFAULT = 0; + + /** Application queue name when deploying Shuffle Manager service on Yarn. */ + public static final String MANAGER_APP_QUEUE_NAME_KEY = + "remote-shuffle.yarn.manager-app-queue-name"; + + public static final String MANAGER_APP_QUEUE_NAME_DEFAULT = "default"; + + /** + * Application master max attempt counts. The AM is a Shuffle Manager. In order to make Shuffle + * Manager run more stably, the attempt count is set as a very large value. + */ + public static final String MANAGER_AM_MAX_ATTEMPTS_VAL_KEY = + "remote-shuffle.yarn.manager-am-max-attempts"; + + public static final int MANAGER_AM_MAX_ATTEMPTS_VAL_DEFAULT = 1000000; + + /** + * Size of memory in megabytes allocated for starting application master, which is used to + * specify memory size for Shuffle Manager. If the configured value is smaller than + * MIN_VALID_AM_MEMORY_SIZE_MB, the memory size will be replaced by MIN_VALID_AM_MEMORY_SIZE_MB. + */ + public static final String MANAGER_AM_MEMORY_SIZE_KEY = + "remote-shuffle.yarn.manager-am-memory-size-mb"; + + public static final int MANAGER_AM_MEMORY_SIZE_DEFAULT = 2048; + + /** + * Size of overhead memory in megabytes allocated for starting application master, which is used + * to specify overhead memory size for ShuffleManager. If the configured value is smaller than + * MIN_VALID_AM_MEMORY_SIZE_MB, the overhead memory size will be replaced by + * MIN_VALID_AM_MEMORY_SIZE_MB. + */ + public static final String MANAGER_AM_MEMORY_OVERHEAD_SIZE_KEY = + "remote-shuffle.yarn.manager-am-memory-overhead-mb"; + + public static final int MANAGER_AM_MEMORY_OVERHEAD_SIZE_DEFAULT = 512; + + /** + * Use this option if you want to modify other JVM options for the ShuffleManager running in the + * application master. For example, you can configure JVM heap size, JVM GC logs, JVM GC + * operations, etc. + */ + public static final String MANAGER_AM_MEMORY_JVM_OPTIONS_KEY = + "remote-shuffle.yarn.manager-am-jvm-options"; + + public static final String MANAGER_AM_MEMORY_JVM_OPTIONS_DEFAULT = ""; + + /** + * Shuffle Manager is started in a container, the container will keep sending heartbeats to Yarn + * resource manager, and this parameter indicates the heartbeat interval. 
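The interval is handed straight to the asynchronous AM-RM client, as startShuffleManagerInternal does later in this patch (ContainerRequest here is org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest):

    AMRMClientAsync<ContainerRequest> amRMClient =
            AMRMClientAsync.createAMRMClientAsync(
                    conf.getInteger(
                            YarnConstants.MANAGER_RM_HEARTBEAT_INTERVAL_MS_KEY,
                            YarnConstants.MANAGER_RM_HEARTBEAT_INTERVAL_MS_DEFAULT),
                    new RMCallbackHandler());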
+ */ + public static final String MANAGER_RM_HEARTBEAT_INTERVAL_MS_KEY = + "remote-shuffle.yarn.manager-rm-heartbeat-interval-ms"; + + public static final int MANAGER_RM_HEARTBEAT_INTERVAL_MS_DEFAULT = 1000; +} diff --git a/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/utils/YarnOptions.java b/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/utils/YarnOptions.java new file mode 100644 index 00000000..3edfcabe --- /dev/null +++ b/shuffle-yarn/src/main/java/com/alibaba/flink/shuffle/yarn/utils/YarnOptions.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.yarn.utils; + +import com.alibaba.flink.shuffle.common.config.ConfigOption; + +/** This class holds configuration constants used by the remote shuffle on Yarn deployment. */ +public class YarnOptions { + + /** + * Flag indicating whether to throw the encountered exceptions to the upper Yarn service. The + * parameter's default value is false. If it is set to true, the upper Yarn service may be + * stopped because of the exceptions from the ShuffleWorker. Note: This parameter needs to be + * configured in yarn-site.xml. + */ + public static final ConfigOption<Boolean> WORKER_STOP_ON_FAILURE = + new ConfigOption<Boolean>("remote-shuffle.yarn.worker-stop-on-failure") + .defaultValue(false) + .description( + "Flag indicating whether to throw the encountered exceptions to the " + + "upper Yarn service. The parameter's default value is false. " + + "If it is set to true, the upper Yarn service may be stopped " + + "because of the exceptions from the ShuffleWorker. Note: This" + + " parameter needs to be configured in yarn-site.xml."); +} diff --git a/shuffle-yarn/src/test/java/com/alibaba/flink/shuffle/yarn/BatchJobTestBase.java b/shuffle-yarn/src/test/java/com/alibaba/flink/shuffle/yarn/BatchJobTestBase.java new file mode 100644 index 00000000..a0466908 --- /dev/null +++ b/shuffle-yarn/src/test/java/com/alibaba/flink/shuffle/yarn/BatchJobTestBase.java @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.yarn; + +import com.alibaba.flink.shuffle.common.config.Configuration; +import com.alibaba.flink.shuffle.core.config.ClusterOptions; +import com.alibaba.flink.shuffle.core.config.ManagerOptions; +import com.alibaba.flink.shuffle.core.config.StorageOptions; +import com.alibaba.flink.shuffle.core.config.WorkerOptions; +import com.alibaba.flink.shuffle.minicluster.ShuffleMiniCluster; +import com.alibaba.flink.shuffle.minicluster.ShuffleMiniClusterConfiguration; +import com.alibaba.flink.shuffle.plugin.RemoteShuffleServiceFactory; +import com.alibaba.flink.shuffle.plugin.config.PluginOptions; + +import org.apache.flink.api.common.RuntimeExecutionMode; +import org.apache.flink.configuration.ExecutionOptions; +import org.apache.flink.configuration.JobManagerOptions; +import org.apache.flink.configuration.MemorySize; +import org.apache.flink.configuration.NettyShuffleEnvironmentOptions; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.configuration.TaskManagerOptions; +import org.apache.flink.runtime.highavailability.HighAvailabilityServices; +import org.apache.flink.runtime.minicluster.MiniCluster; +import org.apache.flink.runtime.minicluster.TestingMiniCluster; +import org.apache.flink.runtime.minicluster.TestingMiniClusterConfiguration; +import org.apache.flink.runtime.shuffle.ShuffleServiceOptions; +import org.apache.flink.util.ExceptionUtils; + +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.TemporaryFolder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.InetAddress; +import java.time.Duration; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +import static com.alibaba.flink.shuffle.yarn.utils.TestTimeoutUtils.waitAllCompleted; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** A base class for batch job test cases that use the remote shuffle service.
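+ *
+ * <p>A minimal subclass sketch ({@code MyBatchJobITCase} is a hypothetical name; {@code setup()}
+ * and {@code shutdown()} are the two hooks subclasses must implement):
+ *
+ * <pre>{@code
+ * public class MyBatchJobITCase extends BatchJobTestBase {
+ *     @Override
+ *     void setup() {
+ *         // tune 'configuration' / 'flinkConfiguration' before the clusters start
+ *         numShuffleWorkers = 2;
+ *     }
+ *
+ *     @Override
+ *     void shutdown() {
+ *         // release resources created in setup()
+ *     }
+ * }
+ * }</pre>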
*/ +public abstract class BatchJobTestBase { + private static final Logger LOG = LoggerFactory.getLogger(BatchJobTestBase.class); + + @Rule public final TemporaryFolder temporaryFolder = new TemporaryFolder(); + + protected int numShuffleWorkers = 4; + + protected int numTaskManagers = 2; + + protected int numSlotsPerTaskManager = 2; + + protected final Configuration configuration = new Configuration(); + + protected final org.apache.flink.configuration.Configuration flinkConfiguration = + new org.apache.flink.configuration.Configuration(); + + protected MiniCluster flinkCluster; + + protected ShuffleMiniCluster shuffleCluster; + + protected Supplier<HighAvailabilityServices> highAvailabilityServicesSupplier = null; + + @Before + public void before() throws Exception { + // basic configuration + String address = InetAddress.getLocalHost().getHostAddress(); + configuration.setString( + StorageOptions.STORAGE_LOCAL_DATA_DIRS, + temporaryFolder.getRoot().getAbsolutePath()); + configuration.setString(ManagerOptions.RPC_ADDRESS, address); + configuration.setString(ManagerOptions.RPC_BIND_ADDRESS, address); + configuration.setString(WorkerOptions.BIND_HOST, address); + configuration.setString(WorkerOptions.HOST, address); + configuration.setInteger(ManagerOptions.RPC_PORT, ManagerOptions.RPC_PORT.defaultValue()); + configuration.setInteger( + ManagerOptions.RPC_BIND_PORT, ManagerOptions.RPC_PORT.defaultValue()); + configuration.setDuration(ClusterOptions.REGISTRATION_TIMEOUT, Duration.ofHours(1)); + + // flink basic configuration. + flinkConfiguration.set(ExecutionOptions.RUNTIME_MODE, RuntimeExecutionMode.BATCH); + flinkConfiguration.setString( + ShuffleServiceOptions.SHUFFLE_SERVICE_FACTORY_CLASS, + RemoteShuffleServiceFactory.class.getName()); + flinkConfiguration.setString(ManagerOptions.RPC_ADDRESS.key(), address); + flinkConfiguration.setLong(JobManagerOptions.SLOT_REQUEST_TIMEOUT, 5000L); + flinkConfiguration.setString(RestOptions.BIND_PORT, "0"); + flinkConfiguration.set(TaskManagerOptions.TOTAL_PROCESS_MEMORY, MemorySize.parse("1g")); + flinkConfiguration.set(JobManagerOptions.TOTAL_PROCESS_MEMORY, MemorySize.parse("1g")); + flinkConfiguration.set(TaskManagerOptions.NETWORK_MEMORY_FRACTION, 0.4F); + flinkConfiguration.setString(PluginOptions.MEMORY_PER_INPUT_GATE.key(), "8m"); + flinkConfiguration.setString(PluginOptions.MEMORY_PER_RESULT_PARTITION.key(), "8m"); + flinkConfiguration.setString( + NettyShuffleEnvironmentOptions.NETWORK_BUFFERS_MEMORY_MAX, "512mb"); + + // set up test-specific config.
+ setup(); + + asyncStartShuffleAndFlinkCluster(address); + } + + private void startShuffleAndFlinkCluster(String address) throws Exception { + ShuffleMiniClusterConfiguration clusterConf = + new ShuffleMiniClusterConfiguration.Builder() + .setConfiguration(configuration) + .setNumShuffleWorkers(numShuffleWorkers) + .setCommonBindAddress(address) + .build(); + shuffleCluster = new ShuffleMiniCluster(clusterConf); + shuffleCluster.start(); + + TestingMiniClusterConfiguration miniClusterConfiguration = + TestingMiniClusterConfiguration.newBuilder() + .setConfiguration(flinkConfiguration) + .setNumTaskManagers(numTaskManagers) + .setNumSlotsPerTaskManager(numSlotsPerTaskManager) + .build(); + + flinkCluster = + new TestingMiniCluster(miniClusterConfiguration, highAvailabilityServicesSupplier); + flinkCluster.start(); + } + + private void asyncStartShuffleAndFlinkCluster(String address) throws Exception { + long start = System.currentTimeMillis(); + CompletableFuture<Boolean> startClusterFuture = + CompletableFuture.supplyAsync( + () -> { + try { + startShuffleAndFlinkCluster(address); + return true; + } catch (Exception e) { + LOG.error("Failed to set up the shuffle and flink cluster, ", e); + return false; + } + }); + List<Boolean> results = + waitAllCompleted( + Collections.singletonList(startClusterFuture), 600, TimeUnit.SECONDS); + long duration = System.currentTimeMillis() - start; + LOG.info("Starting the shuffle and flink cluster took " + duration + " ms"); + assertEquals(1, results.size()); + assertTrue(results.get(0)); + } + + @After + public void after() { + Throwable exception = null; + + try { + if (flinkCluster != null) { + flinkCluster.close(); + } + } catch (Throwable throwable) { + exception = throwable; + } + + try { + if (shuffleCluster != null) { + shuffleCluster.close(); + } + } catch (Throwable throwable) { + exception = exception != null ? exception : throwable; + } + + if (exception != null) { + ExceptionUtils.rethrow(exception); + } + + shutdown(); + } + + abstract void setup() throws Exception; + + abstract void shutdown(); +} diff --git a/shuffle-yarn/src/test/java/com/alibaba/flink/shuffle/yarn/RemoteShuffleOnYarnTestCluster.java b/shuffle-yarn/src/test/java/com/alibaba/flink/shuffle/yarn/RemoteShuffleOnYarnTestCluster.java new file mode 100644 index 00000000..c2991a02 --- /dev/null +++ b/shuffle-yarn/src/test/java/com/alibaba/flink/shuffle/yarn/RemoteShuffleOnYarnTestCluster.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package com.alibaba.flink.shuffle.yarn; + +import com.alibaba.flink.shuffle.coordinator.worker.ShuffleWorkerRunner; +import com.alibaba.flink.shuffle.yarn.entry.manager.AppClient; +import com.alibaba.flink.shuffle.yarn.entry.worker.YarnShuffleWorkerEntrypoint; +import com.alibaba.flink.shuffle.yarn.utils.YarnConstants; + +import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFramework; +import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFrameworkFactory; +import org.apache.flink.shaded.curator4.org.apache.curator.retry.RetryNTimes; + +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.service.Service; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.rules.TemporaryFolder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.util.Arrays; + +import static com.alibaba.flink.shuffle.core.config.HighAvailabilityOptions.HA_MODE; +import static com.alibaba.flink.shuffle.core.config.HighAvailabilityOptions.HA_ZOOKEEPER_QUORUM; +import static com.alibaba.flink.shuffle.core.config.ManagerOptions.RPC_BIND_PORT; +import static com.alibaba.flink.shuffle.core.config.ManagerOptions.RPC_PORT; +import static com.alibaba.flink.shuffle.core.config.MemoryOptions.MEMORY_SIZE_FOR_DATA_READING; +import static com.alibaba.flink.shuffle.core.config.MemoryOptions.MEMORY_SIZE_FOR_DATA_WRITING; +import static com.alibaba.flink.shuffle.core.config.MemoryOptions.MIN_VALID_MEMORY_SIZE; +import static com.alibaba.flink.shuffle.core.config.StorageOptions.STORAGE_LOCAL_DATA_DIRS; +import static com.alibaba.flink.shuffle.yarn.utils.YarnConstants.MANAGER_AM_REMOTE_SHUFFLE_PATH_NAME; +import static com.alibaba.flink.shuffle.yarn.utils.YarnConstants.MANAGER_AM_TMP_PATH_NAME; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +/** IT case with ShuffleManager and ShuffleWorkers deployed on the Yarn framework. */ +public class RemoteShuffleOnYarnTestCluster extends YarnTestBase { + private static final Logger LOG = LoggerFactory.getLogger(RemoteShuffleOnYarnTestCluster.class); + + private static final TemporaryFolder tmp = new TemporaryFolder(); + + @BeforeClass + public static void setup() throws Exception { + client = startCuratorFramework(); + client.start(); + + setupConfigurations(); + + startYARNWithRetries(YARN_CONFIGURATION, false); + + deployShuffleManager(); + + checkServiceRunning(); + } + + private static void setupConfigurations() throws Exception { + YARN_CONFIGURATION.set(TEST_CLUSTER_NAME_KEY, "remote-shuffle-on-yarn-tests"); + + // HA service configurations + YARN_CONFIGURATION.set(HA_MODE.key(), "ZOOKEEPER"); + YARN_CONFIGURATION.set(HA_ZOOKEEPER_QUORUM.key(), zookeeperTestServer.getConnectString()); + + // Shuffle worker configurations + YARN_CONFIGURATION.set( + "yarn.nodemanager.aux-services", YarnConstants.WORKER_AUXILIARY_SERVICE_NAME); + YARN_CONFIGURATION.set( + "yarn.nodemanager.aux-services."
+ + YarnConstants.WORKER_AUXILIARY_SERVICE_NAME + + ".class", + YarnShuffleWorkerEntrypoint.class.getCanonicalName()); + YARN_CONFIGURATION.set(STORAGE_LOCAL_DATA_DIRS.key(), "[HDD]" + getStorageDataDirs()); + YARN_CONFIGURATION.set( + MEMORY_SIZE_FOR_DATA_WRITING.key(), MIN_VALID_MEMORY_SIZE.toString()); + YARN_CONFIGURATION.set( + MEMORY_SIZE_FOR_DATA_READING.key(), MIN_VALID_MEMORY_SIZE.toString()); + } + + /** Simulate the submission workflow to start the Shuffle Manager. */ + private static void deployShuffleManager() throws Exception { + File shuffleHomeDir = new File(findShuffleLocalHomeDir()); + assertTrue(shuffleHomeDir.exists()); + String mockArgs = + "-D " + + YarnConstants.MANAGER_HOME_DIR + + "=" + + findShuffleLocalHomeDir() + + " -D " + + YarnConstants.MANAGER_AM_MEMORY_SIZE_KEY + + "=128 -D " + + YarnConstants.MANAGER_AM_MEMORY_OVERHEAD_SIZE_KEY + + "=128 -D " + + YarnConstants.MANAGER_APP_QUEUE_NAME_KEY + + "=root.default -D " + + RPC_PORT.key() + + "=23123 -D " + + RPC_BIND_PORT.key() + + "=23123 -D " + + HA_MODE.key() + + "=ZOOKEEPER -D" + + HA_ZOOKEEPER_QUORUM.key() + + "=" + + zookeeperTestServer.getConnectString(); + AppClient client = new AppClient(mockArgs.split("\\s+"), YARN_CONFIGURATION); + assertTrue(client.submitApplication()); + } + + private static void checkServiceRunning() throws IOException { + assertSame(yarnCluster.getServiceState(), Service.STATE.STARTED); + checkFileInHdfsExists(); + checkShuffleWorkerRunning(); + LOG.info("All services are good"); + } + + private static void checkShuffleWorkerRunning() { + // Check that the Shuffle Worker is running as an auxiliary service of the Node Manager + Thread workerRunnerThread = + new Thread( + () -> { + try { + ShuffleWorkerRunner shuffleWorkerRunner = + YarnShuffleWorkerEntrypoint.getShuffleWorkerRunner(); + assertNotNull(shuffleWorkerRunner); + assertEquals( + shuffleWorkerRunner.getTerminationFuture().get(), + ShuffleWorkerRunner.Result.SUCCESS); + LOG.info( + "Shuffle Worker runner status: " + + shuffleWorkerRunner.getTerminationFuture().get()); + } catch (Exception e) { + LOG.error("Shuffle Worker encountered an exception, ", e); + Assert.fail(e.getMessage()); + } + }); + workerRunnerThread.start(); + } + + private static File getHomeDir() { + File homeDir = null; + try { + tmp.create(); + homeDir = tmp.newFolder(); + } catch (IOException e) { + e.printStackTrace(); + Assert.fail(e.getMessage()); + } + System.setProperty("user.home", homeDir.getAbsolutePath()); + return homeDir; + } + + private static String getStorageDataDirs() { + File storageDir = new File(getHomeDir().getAbsolutePath(), "dataStorage"); + if (!storageDir.exists()) { + assertTrue(storageDir.mkdirs()); + assertTrue(storageDir.exists()); + } + return storageDir.getAbsolutePath(); + } + + private static CuratorFramework startCuratorFramework() throws Exception { + return CuratorFrameworkFactory.builder() + .connectString(zookeeperTestServer.getConnectString()) + .retryPolicy(new RetryNTimes(50, 100)) + .build(); + } + + private static void checkFileInHdfsExists() throws IOException { + FileSystem fs = FileSystem.get(YARN_CONFIGURATION); + Path hdfsPath = new Path(fs.getHomeDirectory(), MANAGER_AM_REMOTE_SHUFFLE_PATH_NAME + "/"); + FileStatus[] appDirs = fs.listStatus(hdfsPath); + assertEquals(1, appDirs.length); + Path shuffleManagerWorkDir = + new Path(appDirs[0].getPath(), MANAGER_AM_REMOTE_SHUFFLE_PATH_NAME); + FileStatus[] workDir = fs.listStatus(shuffleManagerWorkDir); + assertTrue(workDir.length > 0); + FileStatus[] tmpDir = + fs.listStatus(new
Path(shuffleManagerWorkDir, MANAGER_AM_TMP_PATH_NAME)); + assertTrue(tmpDir.length > 0); + StringBuilder fileNames = new StringBuilder(); + Arrays.stream(tmpDir).forEach(curFile -> fileNames.append(",").append(curFile.getPath())); + assertTrue( + fileNames.toString().contains("shuffle-dist") + && fileNames.toString().contains("log4j") + && fileNames.toString().contains(MANAGER_AM_TMP_PATH_NAME) + && fileNames.toString().contains(MANAGER_AM_REMOTE_SHUFFLE_PATH_NAME)); + } + + private static String findShuffleLocalHomeDir() throws IOException { + File parentDir = + findSpecificDirectory("../shuffle-dist/target/", "flink-remote-shuffle-", "-bin"); + File found = + findSpecificDirectory(parentDir.getAbsolutePath(), "flink-remote-shuffle-", ""); + if (found == null) { + throw new IOException("Can't find lib in " + parentDir.getAbsolutePath()); + } + return found.getAbsolutePath(); + } + + private static File findSpecificDirectory(String startAt, String prefix, String suffix) { + File[] subFiles = (new File(startAt)).listFiles(); + assertNotNull(subFiles); + File found = null; + for (File curFile : subFiles) { + if (curFile.isDirectory()) { + if (curFile.getName().startsWith(prefix) && curFile.getName().endsWith(suffix)) { + found = curFile; + break; + } + found = findSpecificDirectory(curFile.getAbsolutePath(), prefix, suffix); + } + } + return found; + } +} diff --git a/shuffle-yarn/src/test/java/com/alibaba/flink/shuffle/yarn/WordCountITCaseTest.java b/shuffle-yarn/src/test/java/com/alibaba/flink/shuffle/yarn/WordCountITCaseTest.java new file mode 100644 index 00000000..076c17e0 --- /dev/null +++ b/shuffle-yarn/src/test/java/com/alibaba/flink/shuffle/yarn/WordCountITCaseTest.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.flink.shuffle.yarn; + +import com.alibaba.flink.shuffle.plugin.RemoteShuffleServiceFactory; + +import org.apache.flink.api.common.ExecutionMode; +import org.apache.flink.api.common.InputDependencyConstraint; +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.runtime.jobgraph.JobGraph; +import org.apache.flink.runtime.jobgraph.JobType; +import org.apache.flink.runtime.jobmaster.JobResult; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.SinkFunction; +import org.apache.flink.streaming.api.graph.GlobalStreamExchangeMode; +import org.apache.flink.streaming.api.graph.StreamGraph; +import org.apache.flink.streaming.api.graph.StreamingJobGraphGenerator; +import org.apache.flink.util.Collector; + +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +import static com.alibaba.flink.shuffle.common.utils.CommonUtils.checkArgument; +import static com.alibaba.flink.shuffle.core.config.HighAvailabilityOptions.HA_MODE; +import static com.alibaba.flink.shuffle.core.config.HighAvailabilityOptions.HA_ZOOKEEPER_QUORUM; +import static com.alibaba.flink.shuffle.yarn.utils.TestTimeoutUtils.waitAllCompleted; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** A simple word-count integration test. */ +public class WordCountITCaseTest extends BatchJobTestBase { + private static final Logger LOG = LoggerFactory.getLogger(WordCountITCaseTest.class); + + private static final int NUM_WORDS = 20; + + private static final int WORD_COUNT = 2000; + + @Override + public void setup() throws Exception { + asyncSetupShuffleClusterOnYarn(); + + configuration.setString(HA_MODE, "ZOOKEEPER"); + configuration.setString( + HA_ZOOKEEPER_QUORUM, YarnTestBase.zookeeperTestServer.getConnectString()); + + flinkConfiguration.setString( + "shuffle-service-factory.class", RemoteShuffleServiceFactory.class.getName()); + flinkConfiguration.setString("remote-shuffle.job.memory-per-gate", "8m"); + + flinkConfiguration.setString(HA_MODE.key(), "ZOOKEEPER"); + flinkConfiguration.setString( + HA_ZOOKEEPER_QUORUM.key(), YarnTestBase.zookeeperTestServer.getConnectString()); + flinkConfiguration.setString("remote-shuffle.job.memory-per-partition", "8m"); + flinkConfiguration.setString("remote-shuffle.job.concurrent-readings-per-gate", "5"); + } + + private static void asyncSetupShuffleClusterOnYarn() throws Exception { + long start = System.currentTimeMillis(); + CompletableFuture<Boolean> setupFuture = + CompletableFuture.supplyAsync( + () -> { + try { + RemoteShuffleOnYarnTestCluster.setup(); + return true; + } catch (Exception e) { + LOG.error("Failed to set up the shuffle cluster on Yarn, ", e); + return false; + } + }); + List<Boolean> results = + waitAllCompleted(Collections.singletonList(setupFuture), 600, TimeUnit.SECONDS); + long duration = System.currentTimeMillis() - start; + LOG.info("The process of setting up the shuffle cluster on Yarn took " + duration + " ms"); + assertEquals(1, results.size()); + assertTrue(results.get(0)); + } + + @Override + void shutdown() { +
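// Tear down the shared Yarn mini cluster and Zookeeper started in setup(). +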
RemoteShuffleOnYarnTestCluster.shutdown(); + } + + @Test(timeout = 600000L) + public void testWordCount() throws Exception { + StreamExecutionEnvironment env = + StreamExecutionEnvironment.getExecutionEnvironment(flinkConfiguration); + + int parallelism = numTaskManagers * numSlotsPerTaskManager; + env.getConfig().setExecutionMode(ExecutionMode.BATCH); + env.getConfig().setParallelism(parallelism); + env.getConfig().setDefaultInputDependencyConstraint(InputDependencyConstraint.ALL); + env.disableOperatorChaining(); + + DataStream<Tuple2<String, Long>> words = + env.fromSequence(0, NUM_WORDS) + .broadcast() + .map(new WordsMapper()) + .flatMap(new WordsFlatMapper(WORD_COUNT)); + words.keyBy(value -> value.f0) + .sum(1) + .map((MapFunction<Tuple2<String, Long>, Long>) wordCount -> wordCount.f1) + .addSink(new VerifySink(parallelism * WORD_COUNT)); + + StreamGraph streamGraph = env.getStreamGraph(); + streamGraph.setGlobalStreamExchangeMode(GlobalStreamExchangeMode.ALL_EDGES_BLOCKING); + streamGraph.setJobType(JobType.BATCH); + JobGraph jobGraph = StreamingJobGraphGenerator.createJobGraph(streamGraph); + + JobID jobID = flinkCluster.submitJob(jobGraph).get().getJobID(); + JobResult jobResult = flinkCluster.requestJobResult(jobID).get(); + if (jobResult.getSerializedThrowable().isPresent()) { + throw new AssertionError(jobResult.getSerializedThrowable().get()); + } + } + + private static class WordsMapper implements MapFunction<Long, String> { + + private static final long serialVersionUID = 5666190363617738047L; + + private static final String WORD_SUFFIX_1K = getWordSuffix1k(); + + private static String getWordSuffix1k() { + StringBuilder builder = new StringBuilder(); + builder.append("-"); + for (int i = 0; i < 1024; ++i) { + builder.append("0"); + } + return builder.toString(); + } + + @Override + public String map(Long value) { + return "WORD-" + value + WORD_SUFFIX_1K; + } + } + + private static class WordsFlatMapper implements FlatMapFunction<String, Tuple2<String, Long>> { + + private static final long serialVersionUID = -1503963599349890992L; + + private final int wordsCount; + + public WordsFlatMapper(int wordsCount) { + checkArgument(wordsCount > 0, "Must be positive."); + this.wordsCount = wordsCount; + } + + @Override + public void flatMap(String word, Collector<Tuple2<String, Long>> collector) { + for (int i = 0; i < wordsCount; ++i) { + collector.collect(new Tuple2<>(word, 1L)); + } + } + } + + private static class VerifySink implements SinkFunction<Long> { + + private static final long serialVersionUID = -1504978632259778200L; + + private final Long wordCount; + + public VerifySink(long wordCount) { + this.wordCount = wordCount; + } + + @Override + public void invoke(Long value, SinkFunction.Context context) { + assertEquals(wordCount, value); + } + } +} diff --git a/shuffle-yarn/src/test/java/com/alibaba/flink/shuffle/yarn/YarnTestBase.java b/shuffle-yarn/src/test/java/com/alibaba/flink/shuffle/yarn/YarnTestBase.java new file mode 100644 index 00000000..541093c7 --- /dev/null +++ b/shuffle-yarn/src/test/java/com/alibaba/flink/shuffle/yarn/YarnTestBase.java @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.yarn; + +import com.alibaba.flink.shuffle.core.utils.TestLogger; +import com.alibaba.flink.shuffle.yarn.zk.ZookeeperTestServer; + +import org.apache.flink.shaded.curator4.org.apache.curator.framework.CuratorFramework; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hdfs.MiniDFSCluster; +import org.apache.hadoop.service.Service; +import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.server.MiniYARNCluster; +import org.junit.After; +import org.junit.Assert; +import org.junit.ClassRule; +import org.junit.rules.TemporaryFolder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.util.UUID; + +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +/** + * The cluster is re-used for all tests based on the Yarn framework. + * + *
<p>
The main goal of this class is to start {@link MiniYARNCluster} and {@link MiniDFSCluster}. + * Users can run the Yarn test case on the mini clusters. + */ +public abstract class YarnTestBase extends TestLogger { + private static final Logger LOG = LoggerFactory.getLogger(YarnTestBase.class); + + private static final int YARN_CLUSTER_START_RETRY_TIMES = 15; + + private static final int YARN_CLUSTER_START_RETRY_INTERVAL_MS = 20000; + + protected static ZookeeperTestServer zookeeperTestServer = new ZookeeperTestServer(); + + protected static CuratorFramework client; + + protected static final String TEST_CLUSTER_NAME_KEY = + "flink-remote-shuffle-yarn-minicluster-name"; + + protected static final int NUM_NODEMANAGERS = 1; + + // Temp directory for mini hdfs cluster + @ClassRule public static TemporaryFolder tmpHDFS = new TemporaryFolder(); + + protected static MiniYARNCluster yarnCluster = null; + + protected static MiniDFSCluster miniDFSCluster = null; + + protected static final YarnConfiguration YARN_CONFIGURATION; + + protected static File yarnSiteXML = null; + + protected static File hdfsSiteXML = null; + + static { + try { + tmpHDFS.create(); + } catch (Exception e) { + LOG.error("Create temporary folder failed, ", e); + } + YARN_CONFIGURATION = new YarnConfiguration(); + YARN_CONFIGURATION.setInt(YarnConfiguration.RM_SCHEDULER_MINIMUM_ALLOCATION_MB, 32); + YARN_CONFIGURATION.setInt( + YarnConfiguration.RM_SCHEDULER_MAXIMUM_ALLOCATION_MB, + 2048); // 2048 is the available memory anyways + YARN_CONFIGURATION.setBoolean( + YarnConfiguration.RM_SCHEDULER_INCLUDE_PORT_IN_NODE_NAME, true); + YARN_CONFIGURATION.setInt(YarnConfiguration.RM_AM_MAX_ATTEMPTS, 2); + YARN_CONFIGURATION.setInt(YarnConfiguration.RM_MAX_COMPLETED_APPLICATIONS, 2); + YARN_CONFIGURATION.setInt(YarnConfiguration.RM_SCHEDULER_MAXIMUM_ALLOCATION_VCORES, 4); + YARN_CONFIGURATION.setInt(YarnConfiguration.DEBUG_NM_DELETE_DELAY_SEC, 3600); + YARN_CONFIGURATION.setBoolean(YarnConfiguration.LOG_AGGREGATION_ENABLED, false); + YARN_CONFIGURATION.setInt( + YarnConfiguration.NM_VCORES, 666); // memory is overwritten in the MiniYARNCluster. + // so we have to change the number of cores for testing. + YARN_CONFIGURATION.setInt( + YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, + 20000); // 20 seconds expiry (to ensure we properly heartbeat with YARN). 
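+ // Let NodeManager disks fill up to 99% before being marked unhealthy,
+ // since test machines often run with nearly-full disks.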
+ YARN_CONFIGURATION.setFloat( + YarnConfiguration.NM_MAX_PER_DISK_UTILIZATION_PERCENTAGE, 99.0F); + } + + @After + public static void shutdown() { + zookeeperTestServer.afterOperations(); + if (yarnCluster != null) { + yarnCluster.stop(); + } + if (miniDFSCluster != null) { + miniDFSCluster.shutdown(); + } + } + + public static void startYARNWithRetries(YarnConfiguration conf, boolean withDFS) { + for (int i = 0; i < YARN_CLUSTER_START_RETRY_TIMES; i++) { + LOG.info("Waiting for the mini yarn cluster to start, attempt " + i); + boolean started = startYARNWithConfig(conf, withDFS); + if (started) { + LOG.info("Started yarn mini cluster successfully"); + return; + } + } + LOG.error("Failed to start the yarn mini cluster"); + Assert.fail(); + } + + private static boolean startYARNWithConfig(YarnConfiguration conf, boolean withDFS) { + long deadline = System.currentTimeMillis() + YARN_CLUSTER_START_RETRY_INTERVAL_MS; + try { + LOG.info("Starting up MiniYARNCluster"); + if (yarnCluster == null) { + setupMiniYarnCluster(conf); + } + + File targetTestClassesFolder = new File("target/test-classes"); + writeYarnSiteConfigXML(conf, targetTestClassesFolder); + + if (withDFS) { + LOG.info("Starting up MiniDFSCluster"); + setupMiniDFSCluster(targetTestClassesFolder); + } + + assertSame(Service.STATE.STARTED, yarnCluster.getServiceState()); + + // wait for the nodeManagers to connect + boolean needWait = true; + boolean started = false; + while (needWait) { + started = yarnCluster.waitForNodeManagersToConnect(500); + LOG.info("Waiting for node managers to connect"); + needWait = !started && (System.currentTimeMillis() < deadline); + } + + if (!started) { + yarnCluster.stop(); + yarnCluster = null; + } + return started; + } catch (Exception ex) { + LOG.error("setup failure", ex); + Assert.fail(); + } + + assertTrue(yarnCluster.getResourceManager().toString().endsWith("STARTED")); + return true; + } + + private static void setupMiniYarnCluster(YarnConfiguration conf) { + final String testName = conf.get(YarnTestBase.TEST_CLUSTER_NAME_KEY); + yarnCluster = + new MiniYARNCluster( + testName == null ?
"YarnTest_" + UUID.randomUUID() : testName, + NUM_NODEMANAGERS, + 1, + 1); + + yarnCluster.init(conf); + yarnCluster.start(); + } + + private static void setupMiniDFSCluster(File targetTestClassesFolder) throws Exception { + if (miniDFSCluster == null) { + Configuration hdfsConfiguration = new Configuration(); + hdfsConfiguration.set( + MiniDFSCluster.HDFS_MINIDFS_BASEDIR, tmpHDFS.getRoot().getAbsolutePath()); + miniDFSCluster = + new MiniDFSCluster.Builder(hdfsConfiguration) + .numDataNodes(1) + .waitSafeMode(false) + .build(); + miniDFSCluster.waitClusterUp(); + + hdfsConfiguration = miniDFSCluster.getConfiguration(0); + writeHDFSSiteConfigXML(hdfsConfiguration, targetTestClassesFolder); + YARN_CONFIGURATION.addResource(hdfsConfiguration); + } + } + + // write yarn-site.xml to target/test-classes + public static void writeYarnSiteConfigXML(Configuration yarnConf, File targetFolder) + throws IOException { + yarnSiteXML = new File(targetFolder, "/yarn-site.xml"); + try (FileWriter writer = new FileWriter(yarnSiteXML)) { + yarnConf.writeXml(writer); + writer.flush(); + } + } + + // write hdfs-site.xml to target/test-classes + private static void writeHDFSSiteConfigXML(Configuration coreSite, File targetFolder) + throws IOException { + hdfsSiteXML = new File(targetFolder, "/hdfs-site.xml"); + try (FileWriter writer = new FileWriter(hdfsSiteXML)) { + coreSite.writeXml(writer); + writer.flush(); + } + } +} diff --git a/shuffle-yarn/src/test/java/com/alibaba/flink/shuffle/yarn/entry/manager/AppClientEnvsTest.java b/shuffle-yarn/src/test/java/com/alibaba/flink/shuffle/yarn/entry/manager/AppClientEnvsTest.java new file mode 100644 index 00000000..d170f652 --- /dev/null +++ b/shuffle-yarn/src/test/java/com/alibaba/flink/shuffle/yarn/entry/manager/AppClientEnvsTest.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.yarn.entry.manager; + +import org.apache.commons.cli.ParseException; +import org.junit.Test; + +import java.io.IOException; + +import static com.alibaba.flink.shuffle.yarn.utils.YarnConstants.MANAGER_AM_MEMORY_OVERHEAD_SIZE_KEY; +import static com.alibaba.flink.shuffle.yarn.utils.YarnConstants.MANAGER_AM_MEMORY_SIZE_KEY; +import static com.alibaba.flink.shuffle.yarn.utils.YarnConstants.MIN_VALID_AM_MEMORY_SIZE_MB; +import static org.junit.Assert.assertEquals; + +/** Unit tests for {@link AppClientEnvs}. 
*/ +public class AppClientEnvsTest { + @Test + public void testShuffleManagerArgs() throws IOException, ParseException { + String args = + "-D remote-shuffle.yarn.manager-home-dir=aa " + + "-D ab1=cd1 " + + "-D ab2=cd2 " + + "-D " + + MANAGER_AM_MEMORY_SIZE_KEY + + "=10 " + + "-D " + + MANAGER_AM_MEMORY_OVERHEAD_SIZE_KEY + + "=10"; + String[] splitArgs = args.split(" "); + AppClientEnvs envs = + new AppClientEnvs(new org.apache.hadoop.conf.Configuration(), splitArgs); + String shuffleManagerArgString = envs.getShuffleManagerArgs(); + + assertContainsKeyVal(shuffleManagerArgString, "remote-shuffle.yarn.manager-home-dir", "aa"); + assertContainsKeyVal(shuffleManagerArgString, "ab1", "cd1"); + assertContainsKeyVal(shuffleManagerArgString, "ab2", "cd2"); + assertContainsKeyVal( + shuffleManagerArgString, + MANAGER_AM_MEMORY_SIZE_KEY, + String.valueOf(MIN_VALID_AM_MEMORY_SIZE_MB)); + assertContainsKeyVal( + shuffleManagerArgString, + MANAGER_AM_MEMORY_OVERHEAD_SIZE_KEY, + String.valueOf(MIN_VALID_AM_MEMORY_SIZE_MB)); + assertEquals(6, shuffleManagerArgString.split("=").length); + } + + private static void assertContainsKeyVal(String argString, String key, String val) { + assertEquals(val, argString.split("-D " + key + "=")[1].split(" ")[0]); + } +} diff --git a/shuffle-yarn/src/test/java/com/alibaba/flink/shuffle/yarn/utils/DeployOnYarnUtilsTest.java b/shuffle-yarn/src/test/java/com/alibaba/flink/shuffle/yarn/utils/DeployOnYarnUtilsTest.java new file mode 100644 index 00000000..b0d58a43 --- /dev/null +++ b/shuffle-yarn/src/test/java/com/alibaba/flink/shuffle/yarn/utils/DeployOnYarnUtilsTest.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.yarn.utils; + +import com.alibaba.flink.shuffle.common.config.Configuration; + +import org.apache.commons.cli.ParseException; +import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.junit.Test; + +import java.util.Map; + +import static org.junit.Assert.assertEquals; + +/** Unit tests for {@link DeployOnYarnUtils}.
*/ +public class DeployOnYarnUtilsTest { + + @Test + public void testHadoopConfToMaps() { + YarnConfiguration hadoopConf = new YarnConfiguration(); + String testKey = "a.b.test.key"; + String testVal = "a.b.test.val"; + hadoopConf.set(testKey, testVal); + Map<String, String> confMaps = DeployOnYarnUtils.hadoopConfToMaps(hadoopConf); + assertEquals(confMaps.get(testKey), testVal); + } + + @Test + public void testParseParameters() throws ParseException { + String[] inputArgs = + "-D a.k1=a.v1 -Da.k2=a.v2 -D remote-shuffle.yarn.worker-stop-on-failure=true" + .split(" "); + Configuration conf = DeployOnYarnUtils.parseParameters(inputArgs); + assertEquals(conf.getString("a.k1"), "a.v1"); + assertEquals(conf.getString("a.k2"), "a.v2"); + assertEquals(conf.getBoolean(YarnOptions.WORKER_STOP_ON_FAILURE), true); + } +} diff --git a/shuffle-yarn/src/test/java/com/alibaba/flink/shuffle/yarn/utils/TestTimeoutUtils.java b/shuffle-yarn/src/test/java/com/alibaba/flink/shuffle/yarn/utils/TestTimeoutUtils.java new file mode 100644 index 00000000..41476f1b --- /dev/null +++ b/shuffle-yarn/src/test/java/com/alibaba/flink/shuffle/yarn/utils/TestTimeoutUtils.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.yarn.utils; + +import org.junit.Test; + +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +/** Utils to deal with timeout cases for tests.
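+ *
+ * <p>A usage sketch (values illustrative):
+ *
+ * <pre>{@code
+ * CompletableFuture<Boolean> future = CompletableFuture.supplyAsync(() -> true);
+ * List<Boolean> results =
+ *         TestTimeoutUtils.waitAllCompleted(
+ *                 Collections.singletonList(future), 10, TimeUnit.SECONDS);
+ * }</pre>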
*/ +public class TestTimeoutUtils { + public static <T> List<T> waitAllCompleted( + List<CompletableFuture<T>> futuresList, long timeout, TimeUnit unit) throws Exception { + CompletableFuture<Void> futureResult = + CompletableFuture.allOf(futuresList.toArray(new CompletableFuture[0])); + futureResult.get(timeout, unit); + return futuresList.stream() + .filter(future -> future.isDone() && !future.isCompletedExceptionally()) + .map(CompletableFuture::join) + .collect(Collectors.toList()); + } + + @Test(expected = Exception.class) + public void testTimeoutWorkNormal() throws Exception { + CompletableFuture<Boolean> setupFuture = + CompletableFuture.supplyAsync( + () -> { + try { + Thread.sleep(2000); + return true; + } catch (InterruptedException e) { + e.printStackTrace(); + } + return false; + }); + waitAllCompleted(Collections.singletonList(setupFuture), 1, TimeUnit.SECONDS); + } +} diff --git a/shuffle-yarn/src/test/java/com/alibaba/flink/shuffle/yarn/zk/ZookeeperTestServer.java b/shuffle-yarn/src/test/java/com/alibaba/flink/shuffle/yarn/zk/ZookeeperTestServer.java new file mode 100644 index 00000000..81775891 --- /dev/null +++ b/shuffle-yarn/src/test/java/com/alibaba/flink/shuffle/yarn/zk/ZookeeperTestServer.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.flink.shuffle.yarn.zk; + +import com.alibaba.flink.shuffle.common.utils.CommonUtils; + +import org.apache.curator.test.TestingServer; +import org.junit.rules.ExternalResource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.io.IOException; + +/** A class to start a Zookeeper {@link TestingServer}.
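+ *
+ * <p>Typical lifecycle (method names as defined below):
+ *
+ * <pre>{@code
+ * ZookeeperTestServer zk = new ZookeeperTestServer();
+ * String connectString = zk.getConnectString(); // lazily starts the server
+ * // ... run tests against connectString ...
+ * zk.afterOperations(); // stops the server
+ * }</pre>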
*/ +public class ZookeeperTestServer extends ExternalResource { + private static final Logger LOG = LoggerFactory.getLogger(ZookeeperTestServer.class); + + @Nullable private TestingServer zooKeeperServer; + + public String getConnectString() throws Exception { + initTestingServer(); + verifyIsRunning(); + return zooKeeperServer.getConnectString(); + } + + private void verifyIsRunning() { + CommonUtils.checkState(zooKeeperServer != null); + } + + private void initTestingServer() throws Exception { + if (zooKeeperServer == null) { + zooKeeperServer = new TestingServer(true); + } + } + + private void terminateZooKeeperServer() throws IOException { + if (zooKeeperServer != null) { + zooKeeperServer.stop(); + zooKeeperServer = null; + } + } + + public void afterOperations() { + try { + terminateZooKeeperServer(); + } catch (IOException e) { + LOG.warn("Could not terminate the zookeeper server properly, ", e); + } + } +} diff --git a/shuffle-yarn/src/test/resources/log4j2-test.properties b/shuffle-yarn/src/test/resources/log4j2-test.properties new file mode 100644 index 00000000..337c65fc --- /dev/null +++ b/shuffle-yarn/src/test/resources/log4j2-test.properties @@ -0,0 +1,26 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +# Set root logger level to OFF to not flood build logs +# set manually to INFO for debugging purposes +rootLogger.level=OFF +rootLogger.appenderRef.test.ref=TestLogger +appender.testlogger.name=TestLogger +appender.testlogger.type=CONSOLE +appender.testlogger.target=SYSTEM_ERR +appender.testlogger.layout.type=PatternLayout +appender.testlogger.layout.pattern=%d{ISO8601} %-4r [%t] %-5p %c %x - %m%n diff --git a/tools/build_docker_image.sh b/tools/build_docker_image.sh new file mode 100755 index 00000000..8d83bac1 --- /dev/null +++ b/tools/build_docker_image.sh @@ -0,0 +1,36 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +bin=`dirname "$0"` +cd $bin/../ + +mvn clean install -DskipTests + +if [[ $? != 0 ]]; then + echo "Compile error" + exit 1; +fi + +REMOTE_SHUFFLE_VERSION=`xmllint --xpath "//*[local-name()='project']/*[local-name()='version']/text()" ./pom.xml` + +REPOSITORY='flink-remote-shuffle' + +LOCAL_IMAGE="${REPOSITORY}:${REMOTE_SHUFFLE_VERSION}" + +docker rmi ${LOCAL_IMAGE} > /dev/null 2>&1 +docker build --build-arg REMOTE_SHUFFLE_VERSION=${REMOTE_SHUFFLE_VERSION} -t ${LOCAL_IMAGE} . diff --git a/tools/change_version.sh b/tools/change_version.sh new file mode 100755 index 00000000..a4e12d9c --- /dev/null +++ b/tools/change_version.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +OLD_VERSION="0.1.0-SNAPSHOT" +NEW_VERSION="1.0-SNAPSHOT" + +bin=`dirname "$0"` +cd $bin/../ + +# change version in all pom files (run from the repository root set above) +find .
-name 'pom.xml' -type f -exec perl -pi -e 's#'"$OLD_VERSION"'#'"$NEW_VERSION"'#' {} \; diff --git a/tools/maven/checkstyle.xml b/tools/maven/checkstyle.xml new file mode 100644 index 00000000..00e72ee0 --- /dev/null +++ b/tools/maven/checkstyle.xml @@ -0,0 +1,516 @@ diff --git a/tools/maven/suppressions.xml b/tools/maven/suppressions.xml new file mode 100644 index 00000000..db648ed9 --- /dev/null +++ b/tools/maven/suppressions.xml @@ -0,0 +1,30 @@ diff --git a/tools/publish_docker_image.sh b/tools/publish_docker_image.sh new file mode 100755 index 00000000..0db08784 --- /dev/null +++ b/tools/publish_docker_image.sh @@ -0,0 +1,52 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +bin=`dirname "$0"` +cd $bin/../ + +REGISTRY='docker.io' +NAMESPACE='flinkremoteshuffle' +REPOSITORY='flink-remote-shuffle' + +if [[ "$1" ]] ; then + REGISTRY="$1" +fi + +if [[ "$2" ]] ; then + NAMESPACE="$2" +fi + +if [[ "$3" ]] ; then + REPOSITORY="$3" +fi + +REMOTE_SHUFFLE_VERSION=`xmllint --xpath "//*[local-name()='project']/*[local-name()='version']/text()" ./pom.xml` +LOCAL_IMAGE="${REPOSITORY}:${REMOTE_SHUFFLE_VERSION}" +REMOTE_IMAGE="${REGISTRY}/${NAMESPACE}/${REPOSITORY}:${REMOTE_SHUFFLE_VERSION}" +echo $REMOTE_IMAGE + +sh ./tools/build_docker_image.sh + +if [[ $? != 0 ]]; then + echo "Build docker image error" + exit 1; +fi + +docker rmi ${REMOTE_IMAGE} > /dev/null 2>&1 +docker tag ${LOCAL_IMAGE} ${REMOTE_IMAGE} +docker push ${REMOTE_IMAGE}
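+
+# Usage sketch (arguments are optional; defaults are defined above):
+#   sh ./tools/publish_docker_image.sh [registry] [namespace] [repository]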