From d7b72419aa588b84edee9379b1791a5bfc7ff04d Mon Sep 17 00:00:00 2001 From: zhangjiajin Date: Thu, 31 Mar 2022 09:30:49 +0800 Subject: [PATCH] release 1.8.2 --- README.cn.md | 12 +- README.md | 12 +- RELEASE.md | 2 +- docs/cn/algorithms/nago.md | 5 +- docs/cn/developer/quick_start.md | 1 - docs/cn/user/security_configure.md | 45 ++-- docs/en/algorithms/nago.md | 6 +- docs/en/developer/quick_start.md | 1 - docs/en/user/security_configure.md | 63 ++--- evaluate_service/README.cn.md | 8 +- evaluate_service/README.md | 7 +- evaluate_service/RELEASE.md | 2 +- evaluate_service/evaluate_service/__init__.py | 2 +- .../hardwares/davinci/davinci.py | 7 +- .../hardwares/davinci/get_latency_from_log.sh | 7 +- .../hardwares/davinci/inference_atlas300.sh | 22 +- .../samples/atlas300/src/model_process.cpp | 3 +- evaluate_service/evaluate_service/main.py | 9 +- .../evaluate_service/run_flask.py | 1 - .../evaluate_service/security/__init__.py | 25 -- .../evaluate_service/security/args.py | 120 --------- .../evaluate_service/security/check_env.py | 25 -- .../evaluate_service/security/conf.py | 140 ----------- .../security/kmc/encrypt_key.py | 121 ---------- .../evaluate_service/security/kmc/kmc.py | 228 ------------------ .../evaluate_service/security/kmc/utils.py | 44 ---- .../evaluate_service/security/load_pickle.py | 57 ----- .../evaluate_service/security/post.py | 57 ----- .../evaluate_service/security/run_dask.py | 139 ----------- .../evaluate_service/security/utils.py | 46 ---- .../evaluate_service/security/verify_cert.py | 38 --- .../security/verify_config.py | 152 ------------ .../evaluate_service/security/zmq_op.py | 70 ------ evaluate_service/setup.py | 2 +- setup.py | 2 +- vega/__init__.py | 2 +- .../prune_ea/prune_trainer_callback.py | 4 +- vega/core/pipeline/search_pipe_step.py | 4 +- vega/datasets/common/pacs.py | 19 +- vega/datasets/conf/pacs.py | 2 - vega/evaluator/conf.py | 1 + vega/evaluator/device_evaluator.py | 12 +- vega/evaluator/tools/evaluate_davinci_bolt.py | 5 +- vega/model_zoo/model_zoo.py | 3 +- vega/modules/operators/cell.py | 10 +- .../operators/functions/mindspore_fn.py | 11 +- vega/networks/mindspore/super_network.py | 27 ++- vega/networks/pytorch/losses/decaug_loss.py | 4 +- vega/report/report_persistence.py | 26 +- vega/report/report_server.py | 7 +- vega/security/conf.py | 2 +- vega/security/kmc/encrypt_key.py | 4 +- vega/security/load_pickle.py | 2 + vega/security/verify_config.py | 2 +- vega/trainer/callbacks/model_checkpoint.py | 3 + vega/trainer/conf.py | 1 + vega/trainer/trainer_torch.py | 8 +- 57 files changed, 210 insertions(+), 1430 deletions(-) delete mode 100644 evaluate_service/evaluate_service/security/__init__.py delete mode 100644 evaluate_service/evaluate_service/security/args.py delete mode 100644 evaluate_service/evaluate_service/security/check_env.py delete mode 100644 evaluate_service/evaluate_service/security/conf.py delete mode 100644 evaluate_service/evaluate_service/security/kmc/encrypt_key.py delete mode 100644 evaluate_service/evaluate_service/security/kmc/kmc.py delete mode 100644 evaluate_service/evaluate_service/security/kmc/utils.py delete mode 100644 evaluate_service/evaluate_service/security/load_pickle.py delete mode 100644 evaluate_service/evaluate_service/security/post.py delete mode 100644 evaluate_service/evaluate_service/security/run_dask.py delete mode 100644 evaluate_service/evaluate_service/security/utils.py delete mode 100644 evaluate_service/evaluate_service/security/verify_cert.py delete mode 100644 
evaluate_service/evaluate_service/security/verify_config.py delete mode 100644 evaluate_service/evaluate_service/security/zmq_op.py diff --git a/README.cn.md b/README.cn.md index 85521c8..157c6af 100644 --- a/README.cn.md +++ b/README.cn.md @@ -9,13 +9,13 @@ --- -**Vega ver1.8.0 发布** +**Vega ver1.8.2 发布** -- 特性增强 +- 错误修正 - - - 安全增强,组件间通信支持安全协议。 - - 提供独立的评估服务安装。 - - 更新Auto-lane模型,提供auto-lane推理代码。 + - 修正文档中链接错误。 + - 评估服务支持多输入。 + - 修正在NPU下使用Apex的错误。 --- @@ -84,7 +84,7 @@ vega ./examples/nas/cars/cars.yml -s | 对象 | 参考 | | :--: | :-- | -| **用户** | [安装指导](./docs/cn/user/install.md)、[部署指导](./docs/cn/user/deployment.md)、[安全配置](./docs/cn/user/security_configure.md)、[配置指导](./docs/cn/user/config_reference.md)、[示例参考](./docs/cn/user/examples.md)、[评估服务](./evaluate_service/docs/cn/evaluate_service.md) | +| **用户** | [安装指导](./docs/cn/user/install.md)、[部署指导](./docs/cn/user/deployment.md)、[安全配置](./docs/cn/user/security_configure.md)、[配置指导](./docs/cn/user/config_reference.md)、[示例参考](./docs/cn/user/examples.md)、[评估服务](./evaluate_service/README.cn.md) | | **开发者** | [开发者指导](./docs/cn/developer/developer_guide.md)、[快速入门指导](./docs/cn/developer/quick_start.md)、[数据集指导](./docs/cn/developer/datasets.md)、[算法开发指导](./docs/cn/developer/new_algorithm.md) | ## FAQ diff --git a/README.md b/README.md index 091b3b8..1969dc3 100644 --- a/README.md +++ b/README.md @@ -8,13 +8,13 @@ --- -**Vega ver1.8.0 released** +**Vega ver1.8.2 released** -- Feature enhancement: +- Bug fixes: - - - Security enhancement: Security protocols communication. - - Provide evaluation service release package. - - Update the auto-lane model and provide auto-lane inference sample code. + - Fixed broken document links. + - The model to be evaluated supports multiple inputs. + - Fixed an error when using Apex on the NPU. --- @@ -85,7 +85,7 @@ vega ./examples/nas/cars/cars.yml -s | Reader | Refrence | | :--: | :-- | -| **User** | [Install Guide](./docs/en/user/install.md), [Deployment Guide](./docs/en/user/deployment.md), [Configuration Guide](./docs/en/user/config_reference.md), [Security Configuration](./docs/en/user/security_configure.md), [Examples](./docs/en/user/examples.md), [Evaluate Service](./evaluate_service/docs/en/evaluate_service.md) | +| **User** | [Install Guide](./docs/en/user/install.md), [Deployment Guide](./docs/en/user/deployment.md), [Configuration Guide](./docs/en/user/config_reference.md), [Security Configuration](./docs/en/user/security_configure.md), [Examples](./docs/en/user/examples.md), [Evaluate Service](./evaluate_service/README.md) | | **Developer** | [Development Reference](./docs/en/developer/developer_guide.md), [Quick Start Guide](./docs/en/developer/quick_start.md), [Dataset Guide](./docs/en/developer/datasets.md), [Algorithm Development Guide](./docs/en/developer/new_algorithm.md) | ## FAQ diff --git a/RELEASE.md b/RELEASE.md index 9564fff..9f38a25 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,4 +1,4 @@ -**Vega ver1.8.0 released:** +**Vega ver1.8.2 released:** **Introduction** diff --git a/docs/cn/algorithms/nago.md b/docs/cn/algorithms/nago.md index a8b0c8d..18f1601 100644 --- a/docs/cn/algorithms/nago.md +++ b/docs/cn/algorithms/nago.md @@ -66,15 +66,14 @@ search_space: ### 4.2 搜索策略设置 NAGO的搜索空间适用于任何贝叶斯优化算法。在VEGA中我们采用了BOHB, 所以我们需要在 `nago.yml` 文件中也设置BOHB算法的基本参数。 -例如,下面的设置会跑50个搜索循环的BOHB,并用最少30 epochs和最多120 epochs来训练和评估所生成的神经网络结构。 ```yaml search_algorithm: type: BohbHpo policy: total_epochs: -1 - repeat_times: 50 - num_samples: 350 + repeat_times: 1 + num_samples: 7 max_epochs: 120 min_epochs: 30 eta: 2 diff --git 
a/docs/cn/developer/quick_start.md b/docs/cn/developer/quick_start.md index ff2deec..bfbb6df 100644 --- a/docs/cn/developer/quick_start.md +++ b/docs/cn/developer/quick_start.md @@ -161,7 +161,6 @@ nas: lr_scheduler: type: MultiStepLR params: - warmup: False milestones: [30] gamma: 0.5 loss: diff --git a/docs/cn/user/security_configure.md b/docs/cn/user/security_configure.md index 087490f..9f387b3 100644 --- a/docs/cn/user/security_configure.md +++ b/docs/cn/user/security_configure.md @@ -32,7 +32,10 @@ openssl genrsa -out ca.key 4096 openssl req -new -x509 -key ca.key -out ca.crt -subj "/C=/ST=/L=/O=/OU=/CN=" ``` -注意:以上``、``、``、``、``、``根据实际情况填写,本文后面的配置也是同样的。并且CA的配置需要和其他的不同。 +注意: + +1. 以上``、``、``、``、``、``根据实际情况填写,去掉符号`<>`,本文后面的配置也是同样的。并且CA的配置需要和其他的不同。 +2. RSA密钥长度建议在3072位及以上,如本例中使用4096长度。 ## 3. 生成评估服务使用的证书 @@ -139,12 +142,12 @@ vega-encrypt_key --cert=client.crt --key=client.key --key_component_1=ksmaster_c ```shell mkdir ~/.vega mv * ~/.vega/ -chmod -R 600 ~/.vega +chmod 600 ~/.vega/* ``` 说明: -1. 如上的秘钥、证书、加密材料也可以放到其他目录位置,注意访问权限要设置为`600`,并在后继的配置文件中同步修改该文件的位置。 +1. 如上的秘钥、证书、加密材料也可以放到其他目录位置,注意访问权限要设置为`600`,并在后继的配置文件中同步修改该文件的位置,需要使用绝对路径。 2. 在训练集群上,需要保留`ca.crt`、`client.key`、`client.crt`、`ksmaster_client.dat`、`ksstandby_client.dat`、`server_dask.key`、`server_dask.crt`、`client_dask.key`、`client_dask.crt`,并删除其他文件。 3. 评估服务上,需要保留`ca.crt`、`server.key`、`server.crt`、`ksmaster_server.dat`、`ksstandby_server.dat`,并删除其他文件。 @@ -155,36 +158,36 @@ chmod -R 600 ~/.vega server.ini: ```ini -[security] - ca_cert=<~/.vega/car.crt> +[security] # 以下文件路径需要修改为绝对路径 + ca_cert=<~/.vega/ca.crt> server_cert_dask=<~/.vega/server_dask.crt> server_secret_key_dask=<~/.vega/server_dask.key> client_cert_dask=<~/.vega/client_dask.crt> - client_secret_key_dask=<~/.vega/ client_dask.key> + client_secret_key_dask=<~/.vega/client_dask.key> ``` client.ini: ```ini -[security] - ca_cert=<~/.vega/car.crt> +[security] # 以下文件路径需要修改为绝对路径 + ca_cert=<~/.vega/ca.crt> client_cert=<~/.vega/client.crt> client_secret_key=<~/.vega/client.key> - encrypted_password=<加密后的client端的口令> #如果使用普通证书, 此项配置为空 - key_component_1=<~/.vega/ksmaster_client.dat> #如果使用普通证书, 此项配置为空 - key_component_2=<~/.vega/ksstandby_client.dat> #如果使用普通证书, 此项配置为空 + encrypted_password=<加密后的client端的口令> # 如果使用普通证书, 此项配置为空 + key_component_1=<~/.vega/ksmaster_client.dat> # 如果使用普通证书, 此项配置为空 + key_component_2=<~/.vega/ksstandby_client.dat> # 如果使用普通证书, 此项配置为空 ``` 在评估服务器上,需要配置`~/.vega/vega.ini`: ```ini -[security] -ca_cert=<~/.vega/car.crt> -server_cert=<~/.vega/server.crt> -server_secret_key=<~/.vega/server.key> -encrypted_password=<加密后的server端的口令> #如果使用普通证书, 此项配置为空 -key_component_1=<~/.vega/ksmaster_server.dat> #如果使用普通证书, 此项配置为空 -key_component_2=<~/.vega/ksstandby_server.dat> #如果使用普通证书, 此项配置为空 +[security] # 以下文件路径需要修改为绝对路径 + ca_cert=<~/.vega/ca.crt> + server_cert=<~/.vega/server.crt> + server_secret_key=<~/.vega/server.key> + encrypted_password=<加密后的server端的口令> # 如果使用普通证书, 此项配置为空 + key_component_1=<~/.vega/ksmaster_server.dat> # 如果使用普通证书, 此项配置为空 + key_component_2=<~/.vega/ksstandby_server.dat> # 如果使用普通证书, 此项配置为空 ``` ## 7. 
配置评估服务守护服务 @@ -197,7 +200,7 @@ key_component_2=<~/.vega/ksstandby_server.dat> #如果使用普通证书, 此 vega-evaluate_service-service -i -w ``` -然后再创建一个守护服务的文件`evaluate-service`,脚本内容如下,注意替换为真实的脚本位置: +然后再创建一个守护服务的文件`evaluate-service.service`,脚本内容如下,注意替换为真实的脚本位置: ```ini [Unit] @@ -211,10 +214,10 @@ vega-evaluate_service-service -i -w WantedBy=multi-user.target ``` -然后将`evaluate-service`拷贝到目录`/usr/lib/systemd/system`中,并启动该服务: +然后将`evaluate-service.service`拷贝到目录`/usr/lib/systemd/system`中,并启动该服务: ```shell -sudo cp evaluate-service /usr/lib/systemd/system/ +sudo cp evaluate-service.service /usr/lib/systemd/system/ sudo systemctl daemon-reload sudo systemctl start evaluate-service ``` diff --git a/docs/en/algorithms/nago.md b/docs/en/algorithms/nago.md index 430cc1c..5ff9bc5 100644 --- a/docs/en/algorithms/nago.md +++ b/docs/en/algorithms/nago.md @@ -66,15 +66,15 @@ The exact code for the architecture generator (return a trainable PyTorch networ ### 4.2 Search Strategy -Our NAGO search space is amenable to any Bayesian optimisation search strategies. In this code package, we use BOHB to perform the optimisation and the configuration of BOHB needs to be specified in `nago.yml`. The example below defines a BOHB run with `eta=2` and `t=50` search iterations. The minimum and maxmimum training epochs used for evaluating a recommended configuration is 30 and 120 respectively. +Our NAGO search space is amenable to any Bayesian optimisation search strategies. In this code package, we use BOHB to perform the optimisation and the configuration of BOHB needs to be specified in `nago.yml`. ```yaml search_algorithm: type: BohbHpo policy: total_epochs: -1 - repeat_times: 50 - num_samples: 350 + repeat_times: 1 + num_samples: 7 max_epochs: 120 min_epochs: 30 eta: 2 diff --git a/docs/en/developer/quick_start.md b/docs/en/developer/quick_start.md index 4f600eb..dc05c14 100644 --- a/docs/en/developer/quick_start.md +++ b/docs/en/developer/quick_start.md @@ -161,7 +161,6 @@ nas: lr_scheduler: type: MultiStepLR params: - warmup: False milestones: [30] gamma: 0.5 loss: diff --git a/docs/en/user/security_configure.md b/docs/en/user/security_configure.md index 33df6d0..fd3b272 100644 --- a/docs/en/user/security_configure.md +++ b/docs/en/user/security_configure.md @@ -26,14 +26,17 @@ pip3 install --user pyOpenSSL==19.0.0 ## 2. Generate the CA Certificate Run the following command to generate a CA certificate: +Note: The length of the RSA key must be 3072 bits or more. The same applies to the other RSA keys configured below. ```shell openssl genrsa -out ca.key 4096 openssl req -new -x509 -key ca.key -out ca.crt -subj "/C=/ST=/L=/O=/OU=/CN=" ``` -Note: ``, ``, ``, ``, ``, and `` should be set based on the situation. The configuration in this document is the same. -In addition, the CA configuration must be different from other configurations. +Note: + +1. ``, ``, ``, ``, ``, and `` should be set based on the actual situation. Do not include the `<>` symbols in the values. In addition, the CA configuration must be different from the other configurations. +2. It is recommended that the length of the RSA key be 3072 bits or more. ## 3. Generate the Certificate for Evaluate_service @@ -138,15 +141,15 @@ vega-encrypt_key --cert=client.crt --key=client.key --key_component_1=ksmaster_c Create the `.vega` directory in the home directory of the current user, copy the generated keys, certificates, and encryption materials to this directory, and change the permission. 
```shell -mkdir -/.vega -mv * -/.vega/ -chmod -R 600 -/.vega +mkdir ~/.vega +mv * ~/.vega/ +chmod 600 ~/.vega/* ``` Description: -1. The preceding keys, certificates, and encryption materials can also be stored in other directories. The access permission must be set to 600 and the file location must be changed in subsequent configuration files. -2. In the train cluster, reserve `ca.crt`, `client.key`, `client.crt`, `ksmaster_client.dat`, `ksstandby_client.dat`, and `server_dask.key. `, `server_dask.crt`, `client_dask.key`, `client_dask.crt`, and delete other files. +1. The preceding keys, certificates, and encryption materials can also be stored in other directories. The access permission must be set to 600, and the file location must be changed to an absolute path in subsequent configuration files. +2. In the train cluster, reserve `ca.crt`, `client.key`, `client.crt`, `ksmaster_client.dat`, `ksstandby_client.dat`, `server_dask.key`, `server_dask.crt`, `client_dask.key`, and `client_dask.crt`, and delete other files. 3. In the evaluate service, reserve `ca.crt`, `server.key`, `server.crt`, `ksmaster_server.dat`, and `ksstandby_server.dat` files, and delete other files. Create `server.ini` and `client.ini` in the `~/.vega` directory. In the train cluster, configure `~/.vega/server.ini` and `~/.vega/client.ini`. server.ini: ```ini -[security] -ca_cert=<-/.vega/car.crt> -server_cert_dask=<-/.vega/server_dask.crt> -server_secret_key_dask=<-/.vega/server_dask.key> -client_cert_dask=<-/.vega/client_dask.crt> -client_secret_key_dask=<-/.vega/ client_dask.key> +[security] # The following file paths need to be changed to absolute paths. + ca_cert=<~/.vega/ca.crt> + server_cert_dask=<~/.vega/server_dask.crt> + server_secret_key_dask=<~/.vega/server_dask.key> + client_cert_dask=<~/.vega/client_dask.crt> + client_secret_key_dask=<~/.vega/client_dask.key> ``` client.ini: ```ini -[security] -ca_cert=<-/.vega/car.crt> -client_cert=<-/.vega/client.crt> -client_secret_key=<-/.vega/client.key> -encrypted_password= #If a common certificate is used, leave this parameter blank. -If the key_component_1=<~/.vega/ksmaster_client.dat> #If a common certificate is used, leave this parameter blank. -If the key_component_2=<~/.vega/ksstandby_client.dat> #If a common certificate is used, leave this parameter blank. +[security] # The following file paths need to be changed to absolute paths. + ca_cert=<~/.vega/ca.crt> + client_cert=<~/.vega/client.crt> + client_secret_key=<~/.vega/client.key> + encrypted_password= # If a common certificate is used, leave this parameter blank. + key_component_1=<~/.vega/ksmaster_client.dat> # If a common certificate is used, leave this parameter blank. + key_component_2=<~/.vega/ksstandby_client.dat> # If a common certificate is used, leave this parameter blank. ``` On the evaluation server, configure `~/.vega/vega.ini`. ```ini -[security] -ca_cert=<-/.vega/car.crt> -server_cert=<-/.vega/server.crt> -server_secret_key=<-/.vega/server.key> -encrypted_password= #If a common certificate is used, leave this parameter blank. -If the key_component_1=<~/.vega/ksmaster_server.dat> # uses a common certificate, leave this parameter blank. -If the key_component_2=<~/.vega/ksstandby_server.dat> # uses a common certificate, leave this parameter blank. +[security] # The following file paths need to be changed to absolute paths. 
+ ca_cert=<~/.vega/ca.crt> + server_cert=<~/.vega/server.crt> + server_secret_key=<~/.vega/server.key> + encrypted_password= # If a common certificate is used, leave this parameter blank. + key_component_1=<~/.vega/ksmaster_server.dat> # If a common certificate is used, leave this parameter blank. + key_component_2=<~/.vega/ksstandby_server.dat> # If a common certificate is used, leave this parameter blank. ``` ## 7. Configuring the Evaluation Service Daemon Service @@ -198,7 +201,7 @@ Create a script `run_evaluate_service.sh` for starting the evaluation service. R vega-evaluate_service-service -i -w ``` -Create a daemon service file `evaluate-service`. The script content is as follows. Replace it with the actual script location. +Create a daemon service file `evaluate-service.service`. The script content is as follows. Replace it with the actual script location. ```ini [Unit] @@ -212,10 +215,10 @@ RestartSec=60 WantedBy=multi-user.target ``` -Copy `evaluate-service` to the `/usr/lib/systemd/system` directory and start the service. +Copy `evaluate-service.service` to the `/usr/lib/systemd/system` directory and start the service. ```shell -sudo cp evaluate-service /usr/lib/systemd/system/ +sudo cp evaluate-service.service /usr/lib/systemd/system/ sudo systemctl daemon-reload sudo systemctl start evaluate-service ``` diff --git a/evaluate_service/README.cn.md b/evaluate_service/README.cn.md index 6e72cfb..781360b 100644 --- a/evaluate_service/README.cn.md +++ b/evaluate_service/README.cn.md @@ -35,7 +35,7 @@ ### 3.1 安装配置Atlas300环境 -首先需要配置Ascend 300环境,请参考[配置文档](./ascend_310.md)。 +首先需要配置Ascend 300环境,请参考[配置文档](./docs/cn/ascend_310.md)。 然后请安装评估服务,请执行如下命令安装: @@ -57,8 +57,12 @@ cd build/intermediates/host cmake ../../src -DCMAKE_CXX_COMPILER=g++ -DCMAKE_SKIP_RPATH=TRUE make && echo "[INFO] check the env sucess!" ``` +### 3.2 编译推理程序 +参考 [https://gitee.com/ascend/tools/tree/master/msame](https://gitee.com/ascend/tools/tree/master/msame), 下载代码并完成编译。 +并把编译后的可执行文件拷贝到`~/.local/lib/python3.7/site-packages/evaluate_service/hardwares/davinci/`目录下。 -### 3.2 启动评估服务 + +### 3.3 启动评估服务 使用如下命令启动评估服务: diff --git a/evaluate_service/README.md b/evaluate_service/README.md index 08342e0..69eb6ae 100644 --- a/evaluate_service/README.md +++ b/evaluate_service/README.md @@ -40,7 +40,7 @@ Configure the hardware (Atlas 200 DK, Atlas 300, or mobile phone) by following t Please contact us. ### 3.1.2 Install and configure the Atlas 300 Environment (Optional) - +Please refer to the [configuration documentation](./docs/en/ascend_310.md). For details, see the Huawei official tutorial at . Note: The preceding documents may be updated. Please follow the released updates or obtain the corresponding guide documents. After the environment is installed, you need to set environment variables. For details, see the preceding guide. To facilitate environment configuration, we provide the environment variable configuration template [env_atlas300.sh](https://github.com/huawei-noah/vega/blob/master/evaluate_service/hardwares/davinci/env/env_atlas300.sh) for your reference. The actual environment prevails. @@ -55,8 +55,11 @@ Please contact us. Please contact us. +### 3.2 Compile the inference program +Please refer to [https://gitee.com/ascend/tools/tree/master/msame](https://gitee.com/ascend/tools/tree/master/msame). +Download the code and compile it, then copy the compiled executable file to the `~/.local/lib/python3.7/site-packages/evaluate_service/hardwares/davinci/` directory. 
-### 3.2 Start the evaluation service +### 3.3 Start the evaluation service Run the following command to start the evaluate service: ```shell diff --git a/evaluate_service/RELEASE.md b/evaluate_service/RELEASE.md index 85891f2..6cf330e 100644 --- a/evaluate_service/RELEASE.md +++ b/evaluate_service/RELEASE.md @@ -1,4 +1,4 @@ -**Evaluate Service ver1.8.0 released:** +**Evaluate Service ver1.8.2 released:** **Introduction** diff --git a/evaluate_service/evaluate_service/__init__.py b/evaluate_service/evaluate_service/__init__.py index 42f0b74..f58810b 100644 --- a/evaluate_service/evaluate_service/__init__.py +++ b/evaluate_service/evaluate_service/__init__.py @@ -16,4 +16,4 @@ """Evaluate service.""" -__version__ = "1.8.0" +__version__ = "1.8.2" diff --git a/evaluate_service/evaluate_service/hardwares/davinci/davinci.py b/evaluate_service/evaluate_service/hardwares/davinci/davinci.py index 72dbda4..b146edb 100644 --- a/evaluate_service/evaluate_service/hardwares/davinci/davinci.py +++ b/evaluate_service/evaluate_service/hardwares/davinci/davinci.py @@ -76,7 +76,7 @@ def _compile_atlas300(self, save_path): except subprocess.CalledProcessError as exc: logging.error("compile failed. the return message is : {}.".format(exc)) - def inference(self, converted_model, input_data, is_last=False, cal_metric=False, **kwargs): + def inference(self, converted_model, input_data, is_last=False, cal_metric=False, muti_input=False, **kwargs): """Inference in Davinci. :param converted_model: converted model file @@ -94,11 +94,12 @@ def inference(self, converted_model, input_data, is_last=False, cal_metric=False command_line = self._get_200dk_infer_cmd(save_path=log_save_path) result_file = os.path.join(log_save_path, "result_file") else: - if not os.path.exists(os.path.join(share_dir, "main")): + if not os.path.exists(os.path.join(share_dir, "main")) and not muti_input: self._compile_atlas300(share_dir) # execute the Davinci program + command_line = ["bash", self.current_path + "/inference_atlas300.sh", - input_data, converted_model, share_dir, log_save_path] + input_data, converted_model, share_dir, log_save_path, str(muti_input)] result_file = os.path.join(log_save_path, "result.txt") try: diff --git a/evaluate_service/evaluate_service/hardwares/davinci/get_latency_from_log.sh b/evaluate_service/evaluate_service/hardwares/davinci/get_latency_from_log.sh index 534c7ab..32de61f 100644 --- a/evaluate_service/evaluate_service/hardwares/davinci/get_latency_from_log.sh +++ b/evaluate_service/evaluate_service/hardwares/davinci/get_latency_from_log.sh @@ -1,2 +1,7 @@ LOG_FILE=$1 -cat $LOG_FILE |grep costTime | awk -F ' ' '{print $NF}' +err_num=`cat $LOG_FILE | grep ERROR |wc -l` +if [[ $err_num == 0 ]];then + cat $LOG_FILE |grep "Inference time:" | awk -F ' ' '{print $NF}' |awk -F 'ms' '{print $1}' +else + echo None +fi diff --git a/evaluate_service/evaluate_service/hardwares/davinci/inference_atlas300.sh b/evaluate_service/evaluate_service/hardwares/davinci/inference_atlas300.sh index d8a91aa..63dedec 100644 --- a/evaluate_service/evaluate_service/hardwares/davinci/inference_atlas300.sh +++ b/evaluate_service/evaluate_service/hardwares/davinci/inference_atlas300.sh @@ -2,12 +2,20 @@ INPUT_DATA=$1 OM_MODEL=$2 EXECUTE_FILE_PATH=$3 LOG_SAVE_PATH=$4 +MUTI_INPUT=$5 -cp $OM_MODEL $LOG_SAVE_PATH/ -# cp $INPUT_DATA $LOG_SAVE_PATH/ -cp $EXECUTE_FILE_PATH/main $LOG_SAVE_PATH/ -cp $EXECUTE_FILE_PATH/acl.json $LOG_SAVE_PATH/ -cd $LOG_SAVE_PATH/ -#sudo env 
"LD_LIBRARY_PATH=/usr/local/Ascend/acllib/lib64:/usr/local/Ascend/add-ons:/usr/local/Ascend/driver/lib64/" ./main >$WORK_DIR/ome.log -./main >$LOG_SAVE_PATH/ome.log \ No newline at end of file +if [ $MUTI_INPUT == "True" ]; then + cp $OM_MODEL $LOG_SAVE_PATH/ + script_dir=$(cd $(dirname $0);pwd) + chmod +x $script_dir/msame + $script_dir/msame --model $LOG_SAVE_PATH/davinci_model.om --output $LOG_SAVE_PATH --outfmt TXT >$LOG_SAVE_PATH/ome.log + +else + cp $OM_MODEL $LOG_SAVE_PATH/ + cp $EXECUTE_FILE_PATH/main $LOG_SAVE_PATH/ + cp $EXECUTE_FILE_PATH/acl.json $LOG_SAVE_PATH/ + cd $LOG_SAVE_PATH/ + + ./main >$LOG_SAVE_PATH/ome.log +fi \ No newline at end of file diff --git a/evaluate_service/evaluate_service/hardwares/davinci/samples/atlas300/src/model_process.cpp b/evaluate_service/evaluate_service/hardwares/davinci/samples/atlas300/src/model_process.cpp index dd32b56..ed0b09d 100644 --- a/evaluate_service/evaluate_service/hardwares/davinci/samples/atlas300/src/model_process.cpp +++ b/evaluate_service/evaluate_service/hardwares/davinci/samples/atlas300/src/model_process.cpp @@ -328,8 +328,7 @@ Result ModelProcess::Execute() gettimeofday(&start, NULL); aclError ret = aclmdlExecute(modelId_, input_, output_); gettimeofday(&end, NULL); - cout<< "costTime "<< eplasedtime(&end, &start)/1000<(end_time-start_time)/CLOCKS_PER_SEC*1000< 0: - raise ValueError("{} contains invalid characters.".format(value)) - - -def _check_dict(dict_value, pattern): - """Check dict.""" - if not isinstance(dict_value, dict): - return - for item in dict_value: - value = dict_value[item] - if isinstance(value, dict): - _check_dict(value, pattern) - else: - _check_value(value, pattern) - - -def check_msg(msg): - """Check msg.""" - _check_dict(msg, pattern="[^_A-Za-z0-9\\s:/.~-]") - - -def check_args(args): - """Check args.""" - args_dict = vars(args) - _check_dict(args_dict, pattern="[^_A-Za-z0-9:/.~-]") - - -def check_yml(config_yaml): - """Check yml.""" - if config_yaml is None: - raise ValueError("config path can't be None or empty") - if os.stat(config_yaml).st_uid != os.getuid(): - raise ValueError(f"The file {config_yaml} not belong to the current user") - with open(config_yaml) as f: - raw_dict = yaml.safe_load(f) - _check_dict(raw_dict, pattern=r"[^_A-Za-z0-9\s\<\>=\[\]\(\),!\{\}:/.~-]") - - -def check_job_id(job_id): - """Check Job id.""" - if not isinstance(job_id, str): - raise TypeError('"job_id" must be str, not {}'.format(type(job_id))) - _check_value(job_id, pattern="[^_A-Za-z0-9]") - - -def check_input_shape(input_shape): - """Check input shape.""" - if not isinstance(input_shape, str): - raise TypeError('"input_shape" must be str, not {}'.format(type(input_shape))) - _check_value(input_shape, pattern="[^_A-Za-z0-9:,]") - - -def check_out_nodes(out_nodes): - """Check out nodes.""" - if not isinstance(out_nodes, str): - raise TypeError('"out_nodes" must be str, not {}'.format(type(out_nodes))) - _check_value(out_nodes, pattern="[^_A-Za-z0-9:/]") - - -def check_backend(backend): - """Check backend.""" - if backend not in ["tensorflow", "caffe", "onnx", "mindspore"]: - raise ValueError("The backend only support tensorflow, caffe, onnx and mindspore.") - - -def check_hardware(hardware): - """Check hardware.""" - if hardware not in ["Davinci", "Bolt", "Kirin990_npu"]: - raise ValueError("The hardware only support Davinci and Bolt.") - - -def check_precision(precision): - """Check precision.""" - if precision.upper() not in ["FP32", "FP16"]: - raise ValueError("The precision only support FP32 and FP16.") - - 
-def check_repeat_times(repeat_times): - """Check repeat times.""" - MAX_EVAL_EPOCHS = 10000 - if not isinstance(repeat_times, int): - raise TypeError('"repeat_times" must be int, not {}'.format(type(repeat_times))) - if not 0 < repeat_times <= MAX_EVAL_EPOCHS: - raise ValueError("repeat_times {} is not in valid range (1-{})".format(repeat_times, MAX_EVAL_EPOCHS)) - - -def path_verify(path): - """Verify path.""" - return re.sub(r"[^_A-Za-z0-9\/.]", "", path) diff --git a/evaluate_service/evaluate_service/security/check_env.py b/evaluate_service/evaluate_service/security/check_env.py deleted file mode 100644 index c394a02..0000000 --- a/evaluate_service/evaluate_service/security/check_env.py +++ /dev/null @@ -1,25 +0,0 @@ -# -*- coding:utf-8 -*- - -# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Check security env.""" - - -__all__ = ["check_env"] - - -def check_env(args) -> bool: - """Check security env.""" - return True diff --git a/evaluate_service/evaluate_service/security/conf.py b/evaluate_service/evaluate_service/security/conf.py deleted file mode 100644 index 4e9fa03..0000000 --- a/evaluate_service/evaluate_service/security/conf.py +++ /dev/null @@ -1,140 +0,0 @@ -# -*- coding:utf-8 -*- - -# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Security config. 
- -~/.vega/server.ini - -[security] - ca_cert=<~/.vega/car.crt> - server_cert_dask=<~/.vega/server_dask.crt> - server_secret_key_dask=<~/.vega/server_dask.key> - client_cert_dask=<~/.vega/client_dask.crt> - client_secret_key_dask=<~/.vega/ client_dask.key> - -~/.vega/client.ini - -[security] - ca_cert=<~/.vega/car.crt> - client_cert=<~/.vega/client.crt> - client_secret_key=<~/.vega/client.key> - encrypted_password= - key_component_1=<~/.vega/ksmaster_client.dat> - key_component_2=<~/.vega/ksstandby_client.dat> - -""" - -import os -import logging -import configparser -from .verify_config import check_risky_files - - -class Config(): - """Security Config.""" - - def load(self) -> bool: - """Load from config file.""" - if not check_risky_files([self.file_name]): - return False - config = configparser.ConfigParser() - try: - config.read(self.file_name) - except Exception: - logging.error(f"Failed to read setting from {self.file_name}") - return False - if "security" not in config.sections(): - return False - keys = [] - pass_check_keys = ["encrypted_password", "white_list"] - for key in config["security"]: - if key not in self.keys: - return False - setattr(self, key, config.get("security", key)) - if key not in pass_check_keys and not check_risky_files([config.get("security", key)]): - return False - keys.append(key) - if len(keys) != len(self.keys): - missing_keys = list(set(self.keys) - set(keys)) - logging.error(f"setting items {missing_keys} are missing in {self.file_name}") - return False - return True - - -class ServerConfig(Config): - """Security Config.""" - - def __init__(self): - """Initialize.""" - self.ca_cert = None - self.server_cert_dask = None - self.server_secret_key_dask = None - self.client_cert_dask = None - self.client_secret_key_dask = None - self.file_name = os.path.expanduser("~/.vega/server.ini") - self.keys = ["ca_cert", "server_cert_dask", "server_secret_key_dask", "client_cert_dask", - "client_secret_key_dask"] - - -class ClientConfig(Config): - """Security Config.""" - - def __init__(self): - """Initialize.""" - self.ca_cert = None - self.client_cert = None - self.client_secret_key = None - self.encrypted_password = None - self.key_component_1 = None - self.key_component_2 = None - self.white_list = [] - self.file_name = os.path.expanduser("~/.vega/client.ini") - self.keys = [ - "ca_cert", "client_cert", "client_secret_key", "encrypted_password", - "key_component_1", "key_component_2", "white_list"] - - -_server_config = ServerConfig() -_client_config = ClientConfig() - - -def load_config(_type: str) -> bool: - """Load security config.""" - if _type not in ["all", "server", "client"]: - logging.error(f"not support security config type: {_type}") - return False - if _type in ["server", "all"]: - global _server_config - if not _server_config.load(): - logging.error("load server security config fail.") - return False - if _type in ["client", "all"]: - global _client_config - if not _client_config.load(): - logging.error("load client security config fail.") - return False - return True - - -def get_config(_type: str) -> Config: - """Get config.""" - if _type not in ["server", "client"]: - logging.error(f"not support security config type: {_type}") - return False - if _type == "server": - return _server_config - else: - return _client_config diff --git a/evaluate_service/evaluate_service/security/kmc/encrypt_key.py b/evaluate_service/evaluate_service/security/kmc/encrypt_key.py deleted file mode 100644 index 7691c1d..0000000 --- 
a/evaluate_service/evaluate_service/security/kmc/encrypt_key.py +++ /dev/null @@ -1,121 +0,0 @@ -# -*- coding:utf-8 -*- - -# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Load the Certificate and encrypt the passwd.""" - -import argparse -import getpass -import logging -import subprocess -from OpenSSL.crypto import load_certificate, FILETYPE_PEM, load_privatekey -from . import kmc -from .utils import check_password_rule - - -def encrypt_mm(origin_mm, key_component_1, key_component_2): - """Encrypt the passwd.""" - ret = kmc.init(key_component_1, key_component_2, 9) - if ret is False: - logging.error("kmc init error.") - return "" - domain_id = 0 - result = kmc.encrypt(domain_id, origin_mm) - kmc.finalize() - return result - - -def validate_certificate(cert, key, origin_mm): - """Validate the certificate.""" - flag = True - with open(key, "r", encoding="utf-8") as f: - key_value = f.read() - try: - load_privatekey(FILETYPE_PEM, key_value, passphrase=origin_mm.encode('utf-8')) - except Exception: - flag = False - logging.error("Wrong PEM.") - return flag - - # check signature algorithm - with open(cert, "r", encoding="utf-8") as f: - cert_value = f.read() - cert_value = load_certificate(FILETYPE_PEM, cert_value) - enc_algorithm = cert_value.get_signature_algorithm() - if enc_algorithm in b'sha1WithRSAEncryption' b'md5WithRSAEncryption': - logging.warning("Insecure encryption algorithm: %s", enc_algorithm) - # check key length - - p1 = subprocess.Popen(["openssl", "x509", "-in", cert, "-text", "-noout"], - stdout=subprocess.PIPE, shell=False) - p2 = subprocess.Popen(["grep", "RSA Public-Key"], stdin=p1.stdout, stdout=subprocess.PIPE, shell=False) - p3 = subprocess.Popen(["tr", "-cd", "[0-9]"], stdin=p2.stdout, stdout=subprocess.PIPE, shell=False) - RSA_key = p3.communicate()[0] - if int(RSA_key) < 2048: - logging.warning("Insecure key length: %d", int(RSA_key)) - return flag - - -def import_certificate(args, origin_mm): - """Load the certificate.""" - # 1.validate private key and certification, if not pass, program will exit - ret = validate_certificate(args.cert, args.key, origin_mm) - if not ret: - logging.error("Validate certificate failed.") - return 0 - - # 2.encrypt private key's passwd. - encrypt = encrypt_mm(origin_mm, args.key_component_1, args.key_component_2) - if not encrypt: - logging.error("kmc encrypt private key error.") - return 0 - logging.warning(f"Encrypt sucuess. 
The encrypted of your input is {encrypt}") - logging.warning(f"The key components are {args.key_component_1} and {args.key_component_2}, please keep it safe.") - - return True - - -def args_parse(): - """Parse the input args.""" - parser = argparse.ArgumentParser(description='Certificate import') - parser.add_argument("--cert", default="./kmc/config/crt/sever.cert", type=str, - help="The path of certificate file") - parser.add_argument("--key", default='./kmc/config/crt/sever.key', type=str, - help="The path of private Key file.") - parser.add_argument("--key_component_1", default='./kmc/config/ksf/ksmaster.dat', type=str, - help="key material 1.") - parser.add_argument("--key_component_2", default='./kmc/config/ksf/ksstandby.dat', type=str, - help="key material 2.") - - args = parser.parse_args() - - return args - - -def main(): - """Run the encrypt process.""" - args = args_parse() - logging.info("process encrypt begin.") - origin_mm = getpass.getpass("Please enter the password to be encrypted: ") - if not check_password_rule(origin_mm): - logging.info("You should re-generate your server cert/key with following rules:") - logging.info("1. equals to or longer than 8 letters") - logging.info("2. contains at least one digit letter") - logging.info("3. contains at least one capital letter") - logging.info("4. contains at least one lowercase letter") - - ret = import_certificate(args, origin_mm) - if not ret: - logging.error("Encrypt failed.") diff --git a/evaluate_service/evaluate_service/security/kmc/kmc.py b/evaluate_service/evaluate_service/security/kmc/kmc.py deleted file mode 100644 index 2dcf548..0000000 --- a/evaluate_service/evaluate_service/security/kmc/kmc.py +++ /dev/null @@ -1,228 +0,0 @@ -# -*- coding:utf-8 -*- - -# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Huawei KMC library.""" - -import ctypes -import os -from ctypes.util import find_library -import logging -import platform - -__all__ = ["init", "encrypt", "decrypt", "check_and_update_mk", "update_root_key", "hmac", "hmac_verify", "finalize"] - -_kmc_dll: ctypes.CDLL = None -_libc_dll: ctypes.CDLL = None -ADVANCE_DAY = 3 - - -def hmac(domain_id: int, plain_text: str) -> str: - """Encode HMAC code.""" - p_char = ctypes.c_char_p() - hmac_len = ctypes.c_int(0) - c_plain_text = ctypes.create_string_buffer(plain_text.encode()) - _kmc_dll.KeHmacByDomain.restype = ctypes.c_int - _kmc_dll.KeHmacByDomain.argtypes = [ - ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_char_p), ctypes.POINTER(ctypes.c_int)] - ret = _kmc_dll.KeHmacByDomain( - domain_id, c_plain_text, len(plain_text), ctypes.byref(p_char), ctypes.pointer(hmac_len)) - if ret != 0: - logging.error(f"failed to call KeHmacByDomain, code={ret}") - value = p_char.value.decode() - ret = _libc_dll.free(p_char) - if ret != 0: - logging.error(f"failed to free resource, code={ret}") - return value - - -def hmac_verify(domain_id: int, plain_text: str, hmac_text: str) -> bool: - """Verify HMAC code.""" - c_plain_text = ctypes.create_string_buffer(plain_text.encode()) - c_hmac_text = ctypes.create_string_buffer(hmac_text.encode()) - _kmc_dll.KeHmacVerifyByDomain.restype = ctypes.c_int - _kmc_dll.KeHmacVerifyByDomain.argtypes = [ - ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_char_p, ctypes.c_int] - ret = _kmc_dll.KeHmacVerifyByDomain(domain_id, c_plain_text, len(plain_text), c_hmac_text, len(c_hmac_text)) - return ret - - -def encrypt(domain_id: int, plain_text: str) -> str: - """Encrypt.""" - p_char = ctypes.c_char_p() - cipher_len = ctypes.c_int(0) - c_plain_text = ctypes.create_string_buffer(plain_text.encode()) - - _kmc_dll.KeEncryptByDomain.restype = ctypes.c_int - _kmc_dll.KeEncryptByDomain.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_char_p), - ctypes.POINTER(ctypes.c_int)] - ret = _kmc_dll.KeEncryptByDomain(domain_id, c_plain_text, len(plain_text), ctypes.byref(p_char), - ctypes.pointer(cipher_len)) - if ret != 0: - logging.error("KeEncryptByDomain failed.") - return "" - value = p_char.value.decode() - ret = _libc_dll.free(p_char) - if ret != 0: - logging.error("free memory error. ret=%d" % ret) - return value - - -def _decrypt(domain_id: int, cipher_text: str): - """Decrypt.""" - p_char = ctypes.c_char_p() - plain_len = ctypes.c_int(0) - c_cipher_text = ctypes.create_string_buffer(cipher_text.encode()) - _kmc_dll.KeDecryptByDomain.restype = ctypes.c_int - _kmc_dll.KeDecryptByDomain.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_char_p), - ctypes.POINTER(ctypes.c_int)] - ret = _kmc_dll.KeDecryptByDomain(domain_id, c_cipher_text, len(cipher_text), ctypes.byref(p_char), - ctypes.pointer(plain_len)) - if ret != 0: - logging.error("KeDecryptByDomain failed.") - return "" - value = p_char.value.decode() - ret = _libc_dll.free(p_char) - if ret != 0: - logging.error("free memory error. 
ret=%d" % ret) - return value - - -def check_and_update_mk(domain_id: int, advance_day: int) -> bool: - """Check and update mk.""" - ret = _kmc_dll.KeCheckAndUpdateMk(domain_id, advance_day) - if ret != 0: - logging.error(f"failed to call KeCheckAndUpdateMk, code={ret}") - return False - return True - - -def update_root_key() -> bool: - """Update root key.""" - ret = _kmc_dll.KeUpdateRootKey() - if ret != 0: - logging.error(f"failed to call KeUpdateRootKey, code={ret}") - return False - return True - - -def finalize() -> None: - """Finalize.""" - _kmc_dll.KeFinalize.restype = ctypes.c_int - _kmc_dll.KeFinalize.argtypes = [] - _kmc_dll.KeFinalize() - - -def _get_lib_path(): - pkg_path = os.path.dirname(__file__) - if platform.processor() == "x86_64": - return os.path.join(pkg_path, "x86_64/libkmcext.so") - else: - return os.path.join(pkg_path, "aarch64/libkmcext.so") - - -def _load_dll(kmc_dll_path: str) -> None: - global _kmc_dll - if _kmc_dll: - return - global _libc_dll - if _libc_dll: - return - _libc_dll = ctypes.CDLL(find_library("c")) - _kmc_dll = ctypes.CDLL(kmc_dll_path) - - -@ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p) -def _logger(level: ctypes.c_int, msg: ctypes.c_char_p): - logging.info("level:%d, msg:%s" % (level, str(msg))) - - -def _init_log(): - _kmc_dll.KeSetLoggerCallback.restype = None - _kmc_dll.KeSetLoggerCallback.argtypes = [ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p)] - _kmc_dll.KeSetLoggerCallback(_logger) - _kmc_dll.KeSetLoggerLevel.restype = None - _kmc_dll.KeSetLoggerLevel.argtypes = [ctypes.c_int] - _kmc_dll.KeSetLoggerLevel(2) # DISABLE(0),ERROR(1),WARN(2),INFO(3),DEBUG(4),TRACE(5) - - -class KMCConfig(ctypes.Structure): - _fields_ = [ - ("primaryKeyStoreFile", ctypes.c_char * 4096), - ("standbyKeyStoreFile", ctypes.c_char * 4096), - ("domainCount", ctypes.c_int), - ("role", ctypes.c_int), - ("procLockPerm", ctypes.c_int), - ("sdpAlgId", ctypes.c_int), - ("hmacAlgId", ctypes.c_int), - ("semKey", ctypes.c_int) - ] - - -def _init_kmc_config(primary_key_store_file, standby_key_store_file, alg_id, domain_count): - config = KMCConfig() - config.primaryKeyStoreFile = primary_key_store_file.encode() - config.standbyKeyStoreFile = standby_key_store_file.encode() - config.domainCount = domain_count - config.role = 1 # Agent 0; Master 1 - config.procLockPerm = 0o0600 - config.sdpAlgId = alg_id - config.hmacAlgId = 2052 # HMAC_SHA256 2052; HMAC_SHA384 2053 HMAC_SHA512 2054 - config.semKey = 0x20161516 - _kmc_dll.KeInitialize.restype = ctypes.c_int - _kmc_dll.KeInitialize.argtypes = [ctypes.POINTER(KMCConfig)] - return _kmc_dll.KeInitialize(ctypes.byref(config)) - - -def init(primary_key_store_file: str, standby_key_store_file: str, alg_id: int, domain_count=3) -> bool: - """Initialize.""" - if alg_id not in [5, 7, 8, 9]: # AES128_CBC, AES256_CBC, AES128_GCM, AES256_GCM - logging.error(f"alg (id={alg_id}) is not legal") - return False - _load_dll(_get_lib_path()) - _init_log() - ret = _init_kmc_config(primary_key_store_file, standby_key_store_file, alg_id, domain_count) - if ret != 0: - logging.error(f"failed to call KeInitialized, code={ret}") - return False - return True - - -def decrypt(cert_pem_file, secret_key_file, key_mm, key_component_1, key_component_2): - """Decrypt the passwd.""" - sdp_alg_id = 9 - # Make sure ssl certificate file exist - ca_file_list = (cert_pem_file, secret_key_file) - for file in ca_file_list: - if file and os.path.exists(file): - continue - else: - logging.error("SSL Certificate files does not exist! 
Please check config.yaml and cert file.") - raise FileNotFoundError - - primary_keyStoreFile = key_component_1 - standby_keyStoreFile = key_component_2 - ret = init(primary_keyStoreFile, standby_keyStoreFile, sdp_alg_id) - if ret is False: - logging.error("kmc init error.") - raise Exception('ERROR: kmc init failed!') - domain_id = 0 - decrypt_mm = _decrypt(domain_id, key_mm) - if decrypt_mm == "": - logging.error("kmc init error.") - raise Exception('ERROR: kmc init failed!') - check_and_update_mk(domain_id, ADVANCE_DAY) - finalize() - return decrypt_mm diff --git a/evaluate_service/evaluate_service/security/kmc/utils.py b/evaluate_service/evaluate_service/security/kmc/utils.py deleted file mode 100644 index f99bf2f..0000000 --- a/evaluate_service/evaluate_service/security/kmc/utils.py +++ /dev/null @@ -1,44 +0,0 @@ -# -*- coding:utf-8 -*- - -# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Some tools.""" -import re -import logging - - -def check_password_rule(password): - """Check password rule.""" - digit_regex = re.compile(r'\d') - upper_regex = re.compile(r'[A-Z]') - lower_regex = re.compile(r'[a-z]') - - if len(password) < 8: - logging.warning("The length must >= 8") - return False - - if len(digit_regex.findall(password)) == 0: - logging.warning("Must contains digit letters") - return False - - if len(upper_regex.findall(password)) == 0: - logging.warning("Must contains capital letters") - return False - - if len(lower_regex.findall(password)) == 0: - logging.warning("Must contains lowercase letters") - return False - - return True diff --git a/evaluate_service/evaluate_service/security/load_pickle.py b/evaluate_service/evaluate_service/security/load_pickle.py deleted file mode 100644 index df63f23..0000000 --- a/evaluate_service/evaluate_service/security/load_pickle.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Load pickle.""" - -import pickle - -__all__ = ["restricted_loads"] - - -safe_builtins = { - 'vega', - 'torch', - 'torchvision', - 'functools', - 'timm', - 'mindspore', - 'tensorflow', - 'numpy', - 'imageio', - 'collections', -} - - -class RestrictedUnpickler(pickle.Unpickler): - """Restrict unpickler.""" - - def __init__(self, file, fix_imports, encoding, errors, security): - super(RestrictedUnpickler, self).__init__(file=file, fix_imports=fix_imports, encoding=encoding, errors=errors) - self.security = security - - def find_class(self, module, name): - """Find class.""" - _class = super().find_class(module, name) - if self.security: - if module.split('.')[0] in safe_builtins: - return _class - raise pickle.UnpicklingError(f"global '{module}' is forbidden") - else: - return _class - - -def restricted_loads(file, fix_imports=True, encoding="ASCII", errors="strict", security=False): - """Load obj.""" - return RestrictedUnpickler(file, fix_imports=fix_imports, encoding=encoding, errors=errors, - security=security).load() diff --git a/evaluate_service/evaluate_service/security/post.py b/evaluate_service/evaluate_service/security/post.py deleted file mode 100644 index a5110e1..0000000 --- a/evaluate_service/evaluate_service/security/post.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Rest post operation in security mode.""" - -import urllib -import json -import logging -import requests -from .conf import get_config -from .utils import create_context -from .args import check_msg -from .verify_cert import verify_cert - - -def post(host, files, data): - """Post a REST requstion in security mode.""" - sec_cfg = get_config('client') - - ca_file = sec_cfg.ca_cert - cert_pem_file = sec_cfg.client_cert - secret_key_file = sec_cfg.client_secret_key - encrypted_password = sec_cfg.encrypted_password - key_component_1 = sec_cfg.key_component_1 - key_component_2 = sec_cfg.key_component_2 - - if not cert_pem_file or not secret_key_file or not ca_file: - logging.error("CERT file is not existed.") - - if not verify_cert(ca_file, cert_pem_file): - logging.error(f"The cert {ca_file} and {cert_pem_file} are invalid, please check.") - - if encrypted_password == "": - context = create_context(ca_file, cert_pem_file, secret_key_file) - else: - context = create_context(ca_file, cert_pem_file, secret_key_file, encrypted_password, key_component_1, - key_component_2) - if host.lower().startswith('https') is False: - raise Exception(f'The host {host} must start with https') - prepped = requests.Request(method="POST", url=host, files=files, data=data).prepare() - request = urllib.request.Request(host, data=prepped.body, method='POST') - request.add_header("Content-Type", prepped.headers['Content-Type']) - response = urllib.request.urlopen(request, context=context) # nosec - result = json.loads(response.read().decode('utf8')) - check_msg(dict((key, value) for key, value in result.items() if key != 'error_message')) - return result diff --git a/evaluate_service/evaluate_service/security/run_dask.py b/evaluate_service/evaluate_service/security/run_dask.py deleted file mode 100644 index f403954..0000000 --- a/evaluate_service/evaluate_service/security/run_dask.py +++ /dev/null @@ -1,139 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Run dask scheduler and worker.""" -import os -import subprocess -import shutil -import logging -import socket -import random -from distributed import Client -from distributed.security import Security -from .conf import get_config -from .verify_cert import verify_cert - - -sec_cfg = get_config('server') - - -def get_client_security(address): - """Get client.""" - address = address.replace("tcp", "tls") - if not verify_cert(sec_cfg.ca_cert, sec_cfg.client_cert_dask): - logging.error(f"The cert {sec_cfg.ca_cert} and {sec_cfg.client_cert_dask} are invalid, please check.") - sec = Security(tls_ca_file=sec_cfg.ca_cert, - tls_client_cert=sec_cfg.client_cert_dask, - tls_client_key=sec_cfg.client_secret_key_dask, - require_encryption=True) - return Client(address, security=sec) - - -def get_address_security(master_host, master_port): - """Get address.""" - return "tls://{}:{}".format(master_host, master_port) - - -def run_scheduler_security(ip, port, tmp_file): - """Run scheduler.""" - if not verify_cert(sec_cfg.ca_cert, sec_cfg.server_cert_dask): - logging.error(f"The cert {sec_cfg.ca_cert} and {sec_cfg.server_cert_dask} are invalid, please check.") - return subprocess.Popen( - [ - "dask-scheduler", - "--no-dashboard", - "--no-show", - f"--tls-ca-file={sec_cfg.ca_cert}", - f"--tls-cert={sec_cfg.server_cert_dask}", - f"--tls-key={sec_cfg.server_secret_key_dask}", - f"--host={ip}", - "--protocol=tls", - f"--port={port}", - f"--scheduler-file={tmp_file}", - f"--local-directory={os.path.dirname(tmp_file)}", - ], - env=os.environ - ) - - -def _available_port(min_port, max_port) -> int: - _sock = socket.socket() - while True: - port = random.randint(min_port, max_port) - try: - _sock.bind(('', port)) - _sock.close() - return port - except Exception: - logging.debug('Failed to get available port, continue.') - continue - return None - - -def run_local_worker_security(slave_ip, address, local_dir): - """Run dask-worker on local node.""" - address = address.replace("tcp", "tls") - nanny_port = _available_port(30000, 30999) - worker_port = _available_port(29000, 29999) - pid = subprocess.Popen( - [ - "dask-worker", - address, - '--nthreads=1', - '--nprocs=1', - '--memory-limit=0', - f"--local-directory={local_dir}", - f"--tls-ca-file={sec_cfg.ca_cert}", - f"--tls-cert={sec_cfg.client_cert_dask}", - f"--tls-key={sec_cfg.client_secret_key_dask}", - "--no-dashboard", - f"--host={slave_ip}", - "--protocol=tls", - f"--nanny-port={nanny_port}", - f"--worker-port={worker_port}", - ], - env=os.environ - ) - return pid - - -def run_remote_worker_security(slave_ip, address, local_dir): - """Run dask-worker on remote node.""" - address = address.replace("tcp", "tls") - nanny_port = _available_port(30000, 30999) - worker_port = _available_port(29000, 29999) - pid = subprocess.Popen( - [ - "ssh", - slave_ip, - shutil.which("dask-worker"), - address, - '--nthreads=1', - '--nprocs=1', - '--memory-limit=0', - f"--local-directory={local_dir}", - f"--tls-ca-file={sec_cfg.ca_cert}", - f"--tls-cert={sec_cfg.client_cert_dask}", - f"--tls-key={sec_cfg.client_secret_key_dask}", - "--no-dashboard", - f"--host={slave_ip}", - "--protocol=tls", - f"--nanny-port={nanny_port}", - f"--worker-port={worker_port}", - ], - env=os.environ - ) - return pid diff --git a/evaluate_service/evaluate_service/security/utils.py b/evaluate_service/evaluate_service/security/utils.py deleted file mode 100644 index 9b6c220..0000000 --- a/evaluate_service/evaluate_service/security/utils.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (C) 2020. 
Huawei Technologies Co., Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Context utils.""" -import ssl -import sys -import logging - - -def create_context(ca_file, cert_pem_file, secret_key_file, key_mm=None, key_component_1=None, key_component_2=None): - """Create the SSL context.""" - ciphers = "ECDHE-ECDSA-AES128-CCM:ECDHE-ECDSA-AES256-CCM:ECDHE-ECDSA-AES128-GCM-SHA256" \ - ":ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-GCM-SHA384" \ - ":DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384:DHE-DSS-AES128-GCM-SHA256" \ - ":DHE-DSS-AES256-GCM-SHA384:DHE-RSA-AES128-CCM:DHE-RSA-AES256-CCM" - context = ssl.SSLContext(ssl.PROTOCOL_TLS) - context.options += ssl.OP_NO_TLSv1 - context.options += ssl.OP_NO_TLSv1_1 - if sys.version_info >= (3, 7): - context.options += ssl.OP_NO_TLSv1_2 - context.options += ssl.OP_NO_RENEGOTIATION - context.options -= ssl.OP_ALL - context.verify_mode = ssl.CERT_REQUIRED - context.set_ciphers(ciphers) - if key_mm is not None: - from .kmc.kmc import decrypt - logging.debug("Using encrypted key.") - if key_component_1 is None or key_component_2 is None: - logging.error("For encrypted key, the component must be provided.") - decrypt_mm = decrypt(cert_pem_file, secret_key_file, key_mm, key_component_1, key_component_2) - context.load_cert_chain(cert_pem_file, secret_key_file, password=decrypt_mm) - else: - context.load_cert_chain(cert_pem_file, secret_key_file) - context.load_verify_locations(ca_file) - return context diff --git a/evaluate_service/evaluate_service/security/verify_cert.py b/evaluate_service/evaluate_service/security/verify_cert.py deleted file mode 100644 index cdc7238..0000000 --- a/evaluate_service/evaluate_service/security/verify_cert.py +++ /dev/null @@ -1,38 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Verify cert.""" - -import logging - - -def verify_cert(ca_cert_file, cert_file): - """Verify the cert.""" - from OpenSSL.crypto import load_certificate, FILETYPE_PEM, X509Store, X509StoreContext, X509StoreContextError - ca_cert = load_certificate(FILETYPE_PEM, open(ca_cert_file, "r", encoding="utf-8").read()) - cert = load_certificate(FILETYPE_PEM, open(cert_file, 'r', encoding="utf-8").read()) - if ca_cert.has_expired() or cert.has_expired(): - logging.error("The cert is expired, please check.") - return False - store = X509Store() - store.add_cert(ca_cert) - ctx = X509StoreContext(store, cert) - try: - ctx.verify_certificate() - except X509StoreContextError: - logging.error("Certificate signature failure, ca cert file and cert file not match.") - return False - return True diff --git a/evaluate_service/evaluate_service/security/verify_config.py b/evaluate_service/evaluate_service/security/verify_config.py deleted file mode 100644 index f5c910e..0000000 --- a/evaluate_service/evaluate_service/security/verify_config.py +++ /dev/null @@ -1,152 +0,0 @@ -# -*- coding:utf-8 -*- - -# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Run pipeline.""" - -import configparser -import logging -import os -import stat - - -def _file_exist(path): - return os.access(path, os.F_OK) - - -def _file_belong_to_current_user(path): - return os.stat(path).st_uid == os.getuid() - - -def _file_other_writable(path): - return os.stat(path).st_mode & stat.S_IWOTH - - -def _file_is_link(path): - return os.path.islink(path) - - -def _get_risky_files_by_suffix(suffixes, path): - risky_files = [] - non_current_user_files = [] - others_writable_files = [] - link_files = [] - for suffix in suffixes: - if not path.endswith(suffix): - continue - abs_path = os.path.abspath(path) - if _file_exist(abs_path): - risky_files.append(abs_path) - if not _file_belong_to_current_user(abs_path): - non_current_user_files.append(abs_path) - if _file_other_writable(abs_path): - others_writable_files.append(abs_path) - if _file_is_link(abs_path): - link_files.append(abs_path) - - return risky_files, non_current_user_files, others_writable_files, link_files - - -def get_risky_files(config): - """Get contained risky file (.pth/.pth.tar/.onnx/.py).""" - risky_files = [] - non_current_user_files = [] - others_writable_files = [] - link_files = [] - from vega.common.config import Config - if not isinstance(config, Config): - return risky_files, non_current_user_files, others_writable_files, link_files - - for value in config.values(): - if isinstance(value, Config) and value.get("type") == "DeepLabNetWork": - value = value.get("dir").rstrip("/") + "/" + value.get("name").lstrip("/") + ".py" - if isinstance(value, str): - temp_risky_files, temp_non_current_user_files, temp_other_writable_files, temp_link_files \ - = _get_risky_files_by_suffix([".pth", ".pth.tar", ".py"], value) - risky_files.extend(temp_risky_files) - 
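Note: the verify_cert check deleted above validated the service certificate against the CA with pyOpenSSL before any TLS endpoint was started. A self-contained sketch of that check, kept for operators who still want to pre-validate certificates by hand; the function and variable names are illustrative.

    from OpenSSL.crypto import (FILETYPE_PEM, X509Store, X509StoreContext,
                                X509StoreContextError, load_certificate)

    def cert_matches_ca(ca_cert_path, cert_path):
        """True when cert_path chains to ca_cert_path and neither certificate has expired."""
        with open(ca_cert_path, encoding="utf-8") as f:
            ca = load_certificate(FILETYPE_PEM, f.read())
        with open(cert_path, encoding="utf-8") as f:
            cert = load_certificate(FILETYPE_PEM, f.read())
        if ca.has_expired() or cert.has_expired():
            return False
        store = X509Store()
        store.add_cert(ca)
        try:
            X509StoreContext(store, cert).verify_certificate()
            return True
        except X509StoreContextError:
            return False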
non_current_user_files.extend(temp_non_current_user_files) - others_writable_files.extend(temp_other_writable_files) - link_files.extend(temp_link_files) - temp_risky_files, temp_non_current_user_files, temp_other_writable_files, temp_link_files \ - = get_risky_files(value) - risky_files.extend(temp_risky_files) - non_current_user_files.extend(temp_non_current_user_files) - others_writable_files.extend(temp_other_writable_files) - link_files.extend(temp_link_files) - - return risky_files, non_current_user_files, others_writable_files, link_files - - -def check_risky_file(args, config): - """Check risky file (.pth/.pth.tar/.py).""" - if not args.security: - return True - risky_files, non_current_user_files, others_writable_files, link_files = get_risky_files(config) - if len(risky_files) == 0: - return True - - print("\033[1;33m" - "WARNING: The following executable files will be loaded:" - "\033[0m") - for file in risky_files: - print(file) - if len(non_current_user_files) > 0: - print("\033[1;33m" - "WARNING: The following executable files that will be loaded do not belong to the current user:" - "\033[0m") - for file in non_current_user_files: - print(file) - if len(others_writable_files) > 0: - print("\033[1;33m" - "WARNING: The following executable files that will be loaded have others write permission:" - "\033[0m") - for file in others_writable_files: - print(file) - if len(link_files) > 0: - print("\033[1;33m" - "WARNING: The following executable files that will be loaded is soft link file:" - "\033[0m") - for file in link_files: - print(file) - user_confirm = input("It is possible to construct malicious pickle data " - "which will execute arbitrary code during unpickling .pth/.pth.tar/.py files. " - "\nPlease ensure the safety and consistency of the loaded executable files. " - "\nDo you want to continue? (yes/no) ").strip(" ") - while user_confirm != "yes" and user_confirm != "no": - user_confirm = input("Please enter yes or no! ").strip(" ") - if user_confirm == "yes": - return True - elif user_confirm == "no": - return False - - -def check_risky_files(file_list): - """Check if cert and key file are risky.""" - res = True - for file in file_list: - if not os.path.exists(file): - logging.error(f"File <{file}> does not exist") - res = False - continue - if not _file_belong_to_current_user(file): - logging.error(f"File <{file}> is not owned by current user") - res = False - if _file_is_link(file): - logging.error(f"File <{file}> should not be soft link") - res = False - if os.stat(file).st_mode & 0o0177: - logging.error(f"File <{file}> permissions are not correct, cannot exceed 600") - res = False - return res diff --git a/evaluate_service/evaluate_service/security/zmq_op.py b/evaluate_service/evaluate_service/security/zmq_op.py deleted file mode 100644 index 29b89d5..0000000 --- a/evaluate_service/evaluate_service/security/zmq_op.py +++ /dev/null @@ -1,70 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
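Note: check_risky_files, removed just above, rejected certificate and key files that are not owned by the caller, are symlinks, or carry permissions beyond 600. The same checks folded into a single predicate, assuming a POSIX filesystem; the helper name is illustrative.

    import os
    import stat

    def key_file_is_safe(path):
        """Regular file owned by the current user, mode at most 600, not a symlink."""
        st = os.lstat(path)
        return (stat.S_ISREG(st.st_mode)            # excludes symlinks via lstat
                and st.st_uid == os.getuid()
                and not st.st_mode & 0o0177)        # no group/other bits, no owner execute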
-# See the License for the specific language governing permissions and -# limitations under the License. - -"""ZMQ operation.""" -import os -import uuid -import zmq -import zmq.auth -from zmq.auth.thread import ThreadAuthenticator - - -def listen_security(ip, min_port, max_port, max_tries, temp_path): - """Listen on server.""" - ctx = zmq.Context.instance() - # Start an authenticator for this context. - auth = ThreadAuthenticator(ctx) - auth.start() - auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY) - - socket = ctx.socket(zmq.REP) - server_secret_key = os.path.join(temp_path, "server.key_secret") - if not os.path.exists(server_secret_key): - _, server_secret_key = zmq.auth.create_certificates(temp_path, "server") - server_public, server_secret = zmq.auth.load_certificate(server_secret_key) - if os.path.exists(server_secret_key): - os.remove(server_secret_key) - socket.curve_secretkey = server_secret - socket.curve_publickey = server_public - socket.curve_server = True # must come before bind - - port = socket.bind_to_random_port( - f"tcp://{ip}", min_port=min_port, max_port=max_port, max_tries=100) - return socket, port - - -def connect_security(ip, port, temp_path): - """Connect to server.""" - ctx = zmq.Context.instance() - socket = ctx.socket(zmq.REQ) - client_name = uuid.uuid1().hex[:8] - client_secret_key = os.path.join(temp_path, "{}.key_secret".format(client_name)) - if not os.path.exists(client_secret_key): - client_public_key, client_secret_key = zmq.auth.create_certificates(temp_path, client_name) - client_public, client_secret = zmq.auth.load_certificate(client_secret_key) - socket.curve_secretkey = client_secret - socket.curve_publickey = client_public - server_public_key = os.path.join(temp_path, "server.key") - if not os.path.exists(server_public_key): - server_public_key, _ = zmq.auth.create_certificates(temp_path, "server") - server_public, _ = zmq.auth.load_certificate(server_public_key) - socket.curve_serverkey = server_public - socket.connect(f"tcp://{ip}:{port}") - if os.path.exists(client_secret_key): - os.remove(client_secret_key) - if os.path.exists(client_public_key): - os.remove(client_public_key) - return socket diff --git a/evaluate_service/setup.py b/evaluate_service/setup.py index 7791459..779185b 100644 --- a/evaluate_service/setup.py +++ b/evaluate_service/setup.py @@ -60,7 +60,7 @@ def run(self): setuptools.setup( name="evaluate-service", - version="1.8.0", + version="1.8.2", packages=["evaluate_service"], include_package_data=True, python_requires=">=3.6", diff --git a/setup.py b/setup.py index e4a8cd8..b3e77f8 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ setuptools.setup( name="noah-vega", - version="1.8.0", + version="1.8.2", packages=["vega"], include_package_data=True, python_requires=">=3.6", diff --git a/vega/__init__.py b/vega/__init__.py index f6bbb56..a65c1b4 100644 --- a/vega/__init__.py +++ b/vega/__init__.py @@ -34,7 +34,7 @@ "get_quota", ] -__version__ = "1.8.0" +__version__ = "1.8.2" import sys diff --git a/vega/algorithms/compression/prune_ea/prune_trainer_callback.py b/vega/algorithms/compression/prune_ea/prune_trainer_callback.py index ba7fe06..21e62ff 100644 --- a/vega/algorithms/compression/prune_ea/prune_trainer_callback.py +++ b/vega/algorithms/compression/prune_ea/prune_trainer_callback.py @@ -109,8 +109,8 @@ def _init_chn_node_mask(self): :rtype: array """ chn_node_mask_tmp = self.base_net_desc.backbone.chn_node_mask - chn_node_mask = [single_mask for (i, single_mask) in zip([1, 3, 3, 3], 
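Note: the deleted zmq_op.py protected the ZeroMQ channel with CURVE keys generated on the fly. A minimal sketch of the server side it implemented, using the pyzmq APIs visible in the removed code; the removed version additionally started a ThreadAuthenticator with CURVE_ALLOW_ANY, bound to a random port in a range, and deleted the secret key file after loading it.

    import zmq
    import zmq.auth

    def curve_server_socket(ctx, keys_dir, ip, port):
        """Bind a CURVE-encrypted REP socket, as the removed listen_security did."""
        public_file, secret_file = zmq.auth.create_certificates(keys_dir, "server")
        server_public, server_secret = zmq.auth.load_certificate(secret_file)
        socket = ctx.socket(zmq.REP)
        socket.curve_secretkey = server_secret
        socket.curve_publickey = server_public
        socket.curve_server = True          # must be set before bind
        socket.bind(f"tcp://{ip}:{port}")
        return socket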
chn_node_mask_tmp) - for _ in range(i)] + chn_node_mask = [single_mask for (i, single_mask) in zip([1, 3, 3, 3], chn_node_mask_tmp) + for _ in range(i)] return chn_node_mask def _generate_init_model(self): diff --git a/vega/core/pipeline/search_pipe_step.py b/vega/core/pipeline/search_pipe_step.py index 8e7d246..8e65b5e 100644 --- a/vega/core/pipeline/search_pipe_step.py +++ b/vega/core/pipeline/search_pipe_step.py @@ -98,9 +98,9 @@ def _clean_checkpoint(self): worker_parent_folder = os.path.abspath( os.path.join(TaskOps().get_local_worker_path(General.step_name, 1), "..")) patterns = [ - ".*.pkl", "*.pth", "model_*", "model.ckpt-*", "*.pb", + ".*.pkl", "*.pth", "model_*", "model.ckpt-*", "graph.*", "eval", "events*", "CKP-*", "checkpoint", ".*.log", - "*.ckpt", "*.air", "*.onnx", "*.caffemodel", + "*.ckpt", "*.caffemodel", "*.pbtxt", "*.bin", "kernel_meta", "*.prototxt", ] all_files = [] diff --git a/vega/datasets/common/pacs.py b/vega/datasets/common/pacs.py index a3b5058..88d3440 100644 --- a/vega/datasets/common/pacs.py +++ b/vega/datasets/common/pacs.py @@ -43,7 +43,7 @@ def __init__(self, **kwargs): self.args.data_path = FileOps.download_dataset(self.args.data_path) targetdomain = self.args.targetdomain domain = ['cartoon', 'art_painting', 'photo', 'sketch'] - if self.train: + if self.mode == "train": domain.remove(targetdomain) else: domain = [targetdomain] @@ -62,20 +62,9 @@ def __init__(self, **kwargs): classes.sort() class_to_idx = {classes[i]: i for i in range(len(classes))} full_label = [class_to_idx[x] for x in label_name] - if self.train: - name_train, name_val, labels_train, labels_val, concepts_train, concepts_val = \ - train_test_split(full_data, full_label, full_concept, train_size=self.args.train_portion) - if self.mode == "train": - self.data = name_train - self.label = labels_train - self.concept = concepts_train - else: - self.data = name_val - self.label = labels_val - self.concept = concepts_val - else: - self.data, self.label = full_data, full_label - self.concept = [0] * len(self.data) + self.data = full_data + self.label = full_label + self.concept = full_concept def __getitem__(self, index): """Get an item of the dataset according to the index. 
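Note: the prune_trainer_callback hunk above only re-wraps the comprehension; its behaviour is unchanged. It repeats each channel-node mask once per block of the [1, 3, 3, 3] layout, as in this tiny worked example with placeholder masks:

    repeat = [1, 3, 3, 3]
    chn_node_mask_tmp = ["m0", "m1", "m2", "m3"]    # placeholder masks
    chn_node_mask = [m for i, m in zip(repeat, chn_node_mask_tmp) for _ in range(i)]
    # -> ['m0', 'm1', 'm1', 'm1', 'm2', 'm2', 'm2', 'm3', 'm3', 'm3']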
diff --git a/vega/datasets/conf/pacs.py b/vega/datasets/conf/pacs.py index 77831e1..054c8bf 100644 --- a/vega/datasets/conf/pacs.py +++ b/vega/datasets/conf/pacs.py @@ -29,7 +29,6 @@ class PacsCommonConfig(BaseConfig): data_path = None split_path = None targetdomain = None - train_portion = 1.0 task = None @classmethod @@ -41,7 +40,6 @@ def rules(cls): "split_path": {"type": str}, "targetdomain": {"type": str}, "batch_size": {"type": int}, - "train_portion": {"type": (int, float)}, } return rules_PacsCommonConfig diff --git a/vega/evaluator/conf.py b/vega/evaluator/conf.py index 157fd34..c0124c9 100644 --- a/vega/evaluator/conf.py +++ b/vega/evaluator/conf.py @@ -58,6 +58,7 @@ class DeviceEvaluatorConfig(ConfigSerializable): save_intermediate_file = False custom = None repeat_times = 10 + muti_input = True class EvaluatorConfig(ConfigSerializable): diff --git a/vega/evaluator/device_evaluator.py b/vega/evaluator/device_evaluator.py index a66c7bf..5d3efae 100644 --- a/vega/evaluator/device_evaluator.py +++ b/vega/evaluator/device_evaluator.py @@ -56,6 +56,7 @@ def __init__(self, worker_info=None, model=None, saved_folder=None, saved_step_n self.opset_version = self.config.opset_version self.precision = self.config.precision.upper() self.calculate_metric = self.config.calculate_metric + self.muti_input = self.config.muti_input self.quantize = self.config.quantize self.model = model self.worker_info = worker_info @@ -134,7 +135,8 @@ def _torch_valid(self, test_data, job_id): cal_metric=self.calculate_metric, intermediate_format=self.intermediate_format, opset_version=self.opset_version, repeat_times=self.repeat_times, - save_intermediate_file=self.config.save_intermediate_file, custom=self.custom) + save_intermediate_file=self.config.save_intermediate_file, custom=self.custom, + muti_input=self.muti_input) latency = np.float(results.get("latency")) data_num = 1 @@ -165,7 +167,7 @@ def _torch_valid_metric(self, test_data, job_id): precision=self.precision, cal_metric=self.calculate_metric, intermediate_format=self.intermediate_format, opset_version=self.opset_version, repeat_times=self.repeat_times, - save_intermediate_file=self.config.save_intermediate_file) + save_intermediate_file=self.config.save_intermediate_file, muti_input=self.muti_input) if results.get("status") != "sucess" and error_count <= error_threshold: error_count += 1 break @@ -236,7 +238,7 @@ def _tf_valid(self, test_data, latency_sum, data_num, global_step, job_id): reuse_model=reuse_model, job_id=job_id, quantize=self.quantize, repeat_times=self.repeat_times, precision=self.precision, cal_metric=self.calculate_metric, - save_intermediate_file=self.config.save_intermediate_file) + save_intermediate_file=self.config.save_intermediate_file, muti_input=self.muti_input) if self.calculate_metric and results.get("status") != "sucess" and error_count <= error_threshold: error_count += 1 break @@ -289,7 +291,7 @@ def _ms_valid(self, test_data, job_id): model=self.model, weight=None, test_data=test_data, input_shape=data.shape, reuse_model=False, job_id=job_id, precision=self.precision, cal_metric=self.calculate_metric, repeat_times=self.repeat_times, - save_intermediate_file=self.config.save_intermediate_file, custom=self.custom) + save_intermediate_file=self.config.save_intermediate_file, custom=self.custom, muti_input=self.muti_input) latency = np.float(results.get("latency")) pfms = {} data_num = 1 @@ -311,7 +313,7 @@ def _ms_valid_metric(self, test_data, job_id): model=self.model, weight=None, test_data=test_data, 
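Note: the new muti_input flag (the spelling used throughout the code) is added to DeviceEvaluatorConfig with a default of True and is forwarded to the evaluate service so models with several input tensors can be measured. A hedged sketch of turning it off programmatically; in practice the value would normally come from the evaluator section of the pipeline YAML, and the override below simply assumes that class-level attributes on ConfigSerializable subclasses act as defaults.

    from vega.evaluator.conf import DeviceEvaluatorConfig

    # Opt out for a single-input model (assumption: class attributes are the
    # defaults picked up by DeviceEvaluator at construction time).
    DeviceEvaluatorConfig.muti_input = False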
input_shape=data.shape, reuse_model=reuse_model, job_id=job_id, precision=self.precision, cal_metric=self.calculate_metric, repeat_times=self.repeat_times, - save_intermediate_file=self.config.save_intermediate_file) + save_intermediate_file=self.config.save_intermediate_file, muti_input=self.muti_input) latency = np.float(results.get("latency")) latency_sum += latency diff --git a/vega/evaluator/tools/evaluate_davinci_bolt.py b/vega/evaluator/tools/evaluate_davinci_bolt.py index 77efe2b..203c28c 100644 --- a/vega/evaluator/tools/evaluate_davinci_bolt.py +++ b/vega/evaluator/tools/evaluate_davinci_bolt.py @@ -22,7 +22,7 @@ def evaluate(backend, hardware, remote_host, model, weight, test_data, input_shape=None, reuse_model=False, job_id=None, - quantize=False, repeat_times=10, precision='FP32', cal_metric=False, **kwargs): + quantize=False, repeat_times=10, precision='FP32', cal_metric=False, muti_input=False, **kwargs): """Evaluate interface of the EvaluateService. :param backend: the backend can be one of "tensorflow", "caffe" and "pytorch" @@ -58,7 +58,8 @@ def evaluate(backend, hardware, remote_host, model, weight, test_data, input_sha upload_data = {"data_file": data_file} evaluate_config = {"backend": backend, "hardware": hardware, "remote_host": remote_host, "reuse_model": reuse_model, - "job_id": job_id, "repeat_times": repeat_times, "precision": precision, "cal_metric": cal_metric} + "job_id": job_id, "repeat_times": repeat_times, "precision": precision, "cal_metric": cal_metric, + 'muti_input': muti_input} if backend == 'tensorflow': shape_list = [str(s) for s in input_shape] shape_cfg = {"input_shape": "Placeholder:" + ",".join(shape_list)} diff --git a/vega/model_zoo/model_zoo.py b/vega/model_zoo/model_zoo.py index 207474c..2e0ced4 100644 --- a/vega/model_zoo/model_zoo.py +++ b/vega/model_zoo/model_zoo.py @@ -226,7 +226,8 @@ def _load_ms_model(cls, model, pretrained_model_file): if file.endswith(".ckpt"): pretrained_weight = os.path.join(pretrained_model_file, file) break - load_checkpoint(pretrained_weight, net=model) + network = model if not hasattr(model, "get_ori_model") else model.get_ori_model() + load_checkpoint(pretrained_weight, net=network) return model @classmethod diff --git a/vega/modules/operators/cell.py b/vega/modules/operators/cell.py index f452076..5318f56 100644 --- a/vega/modules/operators/cell.py +++ b/vega/modules/operators/cell.py @@ -15,7 +15,7 @@ # limitations under the License. """Import all torch operators.""" -import vega + from vega.common import ClassType, ClassFactory from vega.modules.operators import ops from vega.modules.operators.mix_ops import MixedOp, OPS @@ -51,7 +51,6 @@ def __init__(self, genotype, steps, concat, reduction, reduction_prev=None, C_pr op_names, indices_out, indices_inp = zip(*self.genotype) self.build_ops(self.C, op_names, indices_out, indices_inp, self.concat, self.reduction) self.concat_size = len(self.concat) - self.torch_flag = vega.is_torch_backend() def build_ops(self, C, op_names, indices_out, indices_inp, concat, reduction): """Compile the cell. 
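Note: the model_zoo change above unwraps MindSpore models that expose get_ori_model() before restoring weights into them. A short sketch of that pattern, assuming a recent MindSpore where load_checkpoint is exported at the package top level; the helper name is illustrative.

    from mindspore import load_checkpoint

    def load_into_original(model, ckpt_path):
        """Restore a checkpoint into the wrapped network when one is present."""
        network = model.get_ori_model() if hasattr(model, "get_ori_model") else model
        load_checkpoint(ckpt_path, net=network)
        return model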
@@ -132,12 +131,7 @@ def call(self, s0, s1, weights=None, drop_path_prob=0, selected_idxs=None): else: h = self.oplist[idx + j](states[inp], None, selected_idxs[idx + j]) hlist.append(h) - if self.torch_flag: - s = sum(hlist) - else: - s = hlist[0] - for ii in range(1, len(hlist)): - s += hlist[ii] + s = ops.add_n(hlist) states.append(s) idx += len(self.out_inp_list[i]) states_list = () diff --git a/vega/modules/operators/functions/mindspore_fn.py b/vega/modules/operators/functions/mindspore_fn.py index 8ef5eb7..4f4d944 100644 --- a/vega/modules/operators/functions/mindspore_fn.py +++ b/vega/modules/operators/functions/mindspore_fn.py @@ -274,14 +274,9 @@ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=N self.out_channels = out_channels stride = tuple(stride) if isinstance(stride, list) else stride dilation = tuple(dilation) if isinstance(dilation, list) else dilation - if groups == 1: - self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, - has_bias=bias, group=groups, dilation=dilation, pad_mode=pad_mode) - self.conv2d.update_parameters_name("conv2d_" + uuid.uuid1().hex[:8] + ".") - else: - self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, - has_bias=bias, group=1, dilation=dilation, pad_mode=pad_mode) - self.conv2d.update_parameters_name("conv2d_" + uuid.uuid1().hex[:8] + ".") + self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, + has_bias=bias, group=groups, dilation=dilation, pad_mode=pad_mode) + self.conv2d.update_parameters_name("conv2d_" + uuid.uuid1().hex[:8] + ".") def construct(self, input): """Call conv2d function.""" diff --git a/vega/networks/mindspore/super_network.py b/vega/networks/mindspore/super_network.py index 40995c2..3799d20 100644 --- a/vega/networks/mindspore/super_network.py +++ b/vega/networks/mindspore/super_network.py @@ -20,7 +20,9 @@ from vega.modules.blocks import AuxiliaryHead from vega.modules.connections import Cells from vega.modules.module import Module -from vega.modules.operators import ops +from mindspore import ops +import numpy as np +from mindspore import Tensor @ClassFactory.register(ClassType.NETWORK) @@ -50,12 +52,16 @@ def __init__(self, stem, cells, head, init_channels, num_classes, auxiliary, sea if not search and auxiliary: self.auxiliary_head = AuxiliaryHead(c_aux, num_classes, aux_size) # head - self.head = ClassFactory.get_instance(ClassType.NETWORK, head, base_channel=c_prev, - num_classes=num_classes) + self.head = ClassFactory.get_instance( + ClassType.NETWORK, head, base_channel=c_prev, num_classes=num_classes) # Initialize architecture parameters - self.set_parameters('alphas_normal', 1e-3 * ops.random_normal(self.len_alpha, self.num_ops)) - self.set_parameters('alphas_reduce', 1e-3 * ops.random_normal(self.len_alpha, self.num_ops)) + self.set_parameters( + 'alphas_normal', + 1e-3 * Tensor(np.random.randn(self.len_alpha, self.num_ops).astype(np.float32))) + self.set_parameters( + 'alphas_reduce', + 1e-3 * Tensor(np.random.randn(self.len_alpha, self.num_ops).astype(np.float32))) self.cell_list = self.cells_.children() self.name_list = [] @@ -72,9 +78,11 @@ def arch_weights(self): """Get weights of alphas.""" self.alphas_normal = self.get_weights('alphas_normal') self.alphas_reduce = self.get_weights('alphas_reduce') - alphas_normal = ops.softmax(self.alphas_normal, -1) - alphas_reduce = ops.softmax(self.alphas_reduce, -1) - return 
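Note: before the mindspore_fn.py fix above, grouped convolutions were silently rebuilt with group=1; the patched code forwards the requested group count. A one-line sketch of the case that now behaves as asked, with illustrative channel counts:

    from mindspore import nn

    # Depthwise 3x3 convolution: each of the 32 channels gets its own kernel.
    conv = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1,
                     pad_mode="pad", group=32, has_bias=False)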
[ops.to_numpy(alphas_normal), ops.to_numpy(alphas_reduce)] + softmax = ops.Softmax() + alphas_normal = softmax(self.alphas_normal) + softmax = ops.Softmax() + alphas_reduce = softmax(self.alphas_reduce) + return [alphas_normal.asnumpy(), alphas_reduce.asnumpy()] def get_weight_ops(self): """Get weight ops.""" @@ -82,7 +90,8 @@ def get_weight_ops(self): def calc_alphas(self, alphas, dim=-1, **kwargs): """Calculate Alphas.""" - return ops.softmax(alphas, dim) + softmax = ops.Softmax() + return softmax(alphas) def call(self, input, alpha=None): """Forward a model that specified by alpha. diff --git a/vega/networks/pytorch/losses/decaug_loss.py b/vega/networks/pytorch/losses/decaug_loss.py index ed89cb9..195083d 100644 --- a/vega/networks/pytorch/losses/decaug_loss.py +++ b/vega/networks/pytorch/losses/decaug_loss.py @@ -36,7 +36,7 @@ def __init__(self, balance1=0.01, balance2=0.01, balanceorth=0.01, epsilon=1e-8, self.cross_entropy = ClassFactory.get_cls(ClassType.LOSS, "CrossEntropyLoss")() def forward(self, x, targets=None): - _, logits_category, logits_concept, feature_category, feature_category, model = x + _, logits_category, logits_concept, feature_category, feature_concept, model = x gt_label, gt_concept = targets loss1 = self.cross_entropy(logits_category, gt_label) loss2 = self.cross_entropy(logits_concept, gt_concept) @@ -71,7 +71,7 @@ def forward(self, x, targets=None): ratio = random.random() feature_aug = ratio * FGSM_attack - embs = ops.concat((feature_category, feature_category + feature_aug), 1) + embs = ops.concat((feature_category, feature_concept + feature_aug), 1) output = ops.matmul(embs, w_out.transpose(0, 1)) + b_out loss_class = self.cross_entropy(output, gt_label) diff --git a/vega/report/report_persistence.py b/vega/report/report_persistence.py index 78d5a12..e126823 100644 --- a/vega/report/report_persistence.py +++ b/vega/report/report_persistence.py @@ -42,16 +42,28 @@ def update_step_info(self, **kwargs): """Update step info.""" if "step_name" in kwargs: step_name = kwargs["step_name"] - if step_name not in self.steps: - self.steps[step_name] = {} - for key in kwargs: - if key in ["step_name", "start_time", "end_time", "status", "message", "num_epochs", "num_models"]: - self.steps[step_name][key] = kwargs[key] - else: - logger.warn("Invilid step info {}:{}".format(key, kwargs[key])) + + def update_each_step_info(step_name): + if step_name not in self.steps: + self.steps[step_name] = {} + for key in kwargs: + if key == "step_name": + self.steps[step_name][key] = step_name + elif key in ["start_time", "end_time", "status", + "message", "num_epochs", "num_models", "best_models"]: + self.steps[step_name][key] = kwargs[key] + else: + logger.warn("Invilid step info {}:{}".format(key, kwargs[key])) + + if isinstance(step_name, list): + for step in step_name: + update_each_step_info(step) + else: + update_each_step_info(step_name) else: logger.warn("Invilid step info: {}.".format(kwargs)) + def save_report(self, records): """Save report to `reports.json`.""" try: diff --git a/vega/report/report_server.py b/vega/report/report_server.py index 0431736..375c9bb 100644 --- a/vega/report/report_server.py +++ b/vega/report/report_server.py @@ -129,13 +129,18 @@ def get_pareto_front_records(self, step_name=None, nums=None, selected_key=None, self.old_not_finished_workers = not_finished logging.info(f"waiting for the workers {str(not_finished)} to finish") if not records: + self.update_step_info({"step_name": step_name, "best_models": []}) return [] pareto = 
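Note: the SuperNetwork hunks above replace Vega's backend-agnostic ops with MindSpore primitives; the architecture parameters are now initialised from NumPy and normalised with ops.Softmax. A small standalone sketch of exactly that sequence; the alpha dimensions are placeholders.

    import numpy as np
    from mindspore import Tensor, ops

    len_alpha, num_ops = 14, 8                     # illustrative sizes
    alphas = 1e-3 * Tensor(np.random.randn(len_alpha, num_ops).astype(np.float32))
    softmax = ops.Softmax()                        # axis defaults to -1
    weights = softmax(alphas).asnumpy()            # what arch_weights now returns

Since ops.Softmax fixes its axis at construction, the patched calc_alphas no longer uses its dim argument.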
self.pareto_front(step_name, nums, records=records) if not pareto: + self.update_step_info(**{"step_name": step_name, "best_models": []}) return [] if choice is not None: - return [random.choice(pareto)] + records = random.choice(pareto) + self.update_step_info(**{"step_name": step_name, "best_models": [record.worker_id for record in records]}) + return [records] else: + self.update_step_info(**{"step_name": step_name, "best_models": [record.worker_id for record in pareto]}) return pareto @classmethod diff --git a/vega/security/conf.py b/vega/security/conf.py index 4e9fa03..b168581 100644 --- a/vega/security/conf.py +++ b/vega/security/conf.py @@ -104,7 +104,7 @@ def __init__(self): self.file_name = os.path.expanduser("~/.vega/client.ini") self.keys = [ "ca_cert", "client_cert", "client_secret_key", "encrypted_password", - "key_component_1", "key_component_2", "white_list"] + "key_component_1", "key_component_2"] _server_config = ServerConfig() diff --git a/vega/security/kmc/encrypt_key.py b/vega/security/kmc/encrypt_key.py index 7691c1d..976ecfe 100644 --- a/vega/security/kmc/encrypt_key.py +++ b/vega/security/kmc/encrypt_key.py @@ -63,8 +63,8 @@ def validate_certificate(cert, key, origin_mm): p2 = subprocess.Popen(["grep", "RSA Public-Key"], stdin=p1.stdout, stdout=subprocess.PIPE, shell=False) p3 = subprocess.Popen(["tr", "-cd", "[0-9]"], stdin=p2.stdout, stdout=subprocess.PIPE, shell=False) RSA_key = p3.communicate()[0] - if int(RSA_key) < 2048: - logging.warning("Insecure key length: %d", int(RSA_key)) + if int(RSA_key) < 3072: + logging.warning("Insecure key length: %d. The recommended key length is at least 3072", int(RSA_key)) return flag diff --git a/vega/security/load_pickle.py b/vega/security/load_pickle.py index df63f23..2065110 100644 --- a/vega/security/load_pickle.py +++ b/vega/security/load_pickle.py @@ -30,6 +30,8 @@ 'numpy', 'imageio', 'collections', + 'apex', + 'ascend_automl' } diff --git a/vega/security/verify_config.py b/vega/security/verify_config.py index f5c910e..475d8c1 100644 --- a/vega/security/verify_config.py +++ b/vega/security/verify_config.py @@ -136,7 +136,7 @@ def check_risky_files(file_list): """Check if cert and key file are risky.""" res = True for file in file_list: - if not os.path.exists(file): + if file != '' and not os.path.exists(file): logging.error(f"File <{file}> does not exist") res = False continue diff --git a/vega/trainer/callbacks/model_checkpoint.py b/vega/trainer/callbacks/model_checkpoint.py index 481624b..a80f233 100644 --- a/vega/trainer/callbacks/model_checkpoint.py +++ b/vega/trainer/callbacks/model_checkpoint.py @@ -89,6 +89,9 @@ def _save_best_model(self): def _save_checkpoint(self, epoch): """Save checkpoint.""" + if not self.trainer.config.save_slave_model: + if not self.trainer.is_chief: + return logging.debug("Start Save Checkpoint, file_name=%s", self.trainer.checkpoint_file_name) checkpoint_file = FileOps.join_path( self.trainer.get_local_worker_path(), self.trainer.checkpoint_file_name) diff --git a/vega/trainer/conf.py b/vega/trainer/conf.py index 72820bd..4a724a1 100644 --- a/vega/trainer/conf.py +++ b/vega/trainer/conf.py @@ -62,6 +62,7 @@ class TrainerConfig(ConfigSerializable): distributed = False save_model_desc = False save_ext_model = False + save_slave_model = False report_freq = 10 seed = 0 epochs = 1 diff --git a/vega/trainer/trainer_torch.py b/vega/trainer/trainer_torch.py index d4a5f24..191bbae 100644 --- a/vega/trainer/trainer_torch.py +++ b/vega/trainer/trainer_torch.py @@ -59,10 +59,13 @@ def build(self): 
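Note: load_pickle.py above adds apex and ascend_automl to the module whitelist so checkpoints that reference those packages can still be restored under the restricted unpickler. For context, a generic sketch of that whitelist pattern; the set below is abbreviated rather than Vega's full list, and the class name is illustrative.

    import io
    import pickle

    SAFE_MODULES = {"torch", "numpy", "collections", "apex", "ascend_automl"}

    class WhitelistUnpickler(pickle.Unpickler):
        """Resolve only globals whose top-level package is whitelisted."""
        def find_class(self, module, name):
            if module.split(".")[0] not in SAFE_MODULES:
                raise pickle.UnpicklingError(f"forbidden global {module}.{name}")
            return super().find_class(module, name)

    def restricted_loads(data: bytes):
        return WhitelistUnpickler(io.BytesIO(data)).load()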
from apex import amp if not vega.is_npu_device(): self.model, self.optimizer = amp.initialize( - self.model, self.optimizer, opt_level=self.config.opt_level, loss_scale=self.config.apex_loss_scale) + self.model, self.optimizer, opt_level=self.config.opt_level, + loss_scale=self.config.apex_loss_scale) else: self.model, self.optimizer = amp.initialize( - self.model, self.optimizer, opt_level=self.config.opt_level, loss_scale=self.config.apex_loss_scale, combine_grad=self.config.apex_combine_grad) + self.model, self.optimizer, opt_level=self.config.opt_level, + loss_scale=self.config.apex_loss_scale, + combine_grad=self.config.apex_combine_grad) def set_training_settings(self): """Set trainer training setting.""" @@ -194,7 +197,6 @@ def _set_amp_loss(self, loss): else: with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() - self.optimizer.step() def _multi_train_step(self, batch): train_batch_output = None
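Note: the trainer_torch.py hunk reformats the amp.initialize calls, keeping the NPU-only combine_grad argument on its own branch, and the follow-up hunk drops a duplicate optimizer.step() after the scaled backward pass, presumably because the step is already issued elsewhere in the training loop. A sketch of the resulting initialisation, assuming Apex (the Ascend port on NPU) is installed; opt_level and loss_scale values stand in for the trainer config fields.

    import vega
    from apex import amp

    def init_amp(model, optimizer, opt_level="O1", loss_scale=64):
        """Initialise mixed precision; combine_grad is accepted by Ascend's Apex port only."""
        if vega.is_npu_device():
            return amp.initialize(model, optimizer, opt_level=opt_level,
                                  loss_scale=loss_scale, combine_grad=True)
        return amp.initialize(model, optimizer, opt_level=opt_level,
                              loss_scale=loss_scale)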