diff --git a/configs/modules/service/smpl_stream_service.py b/configs/modules/service/smpl_stream_service.py new file mode 100644 index 0000000..360731d --- /dev/null +++ b/configs/modules/service/smpl_stream_service.py @@ -0,0 +1,9 @@ +type = 'SMPLStreamService' +name = 'smpl_stream_service' +work_dir = f'temp/{name}' +body_model_dir = 'xrmocap_data/body_models' +device = 'cuda:0' +enable_bytes = True +enable_cors = True +port = 29091 +max_http_buffer_size = 128 * 1024 * 1024 diff --git a/dockerfiles/service_ubt18/Dockerfile b/dockerfiles/service_ubt18/Dockerfile new file mode 100644 index 0000000..e8d3f57 --- /dev/null +++ b/dockerfiles/service_ubt18/Dockerfile @@ -0,0 +1,8 @@ +ARG INPUT_TAG +FROM $INPUT_TAG + +# Install test requirements +RUN . /opt/miniconda/etc/profile.d/conda.sh && \ + conda activate openxrlab && \ + pip install -r https://raw.githubusercontent.com/openxrlab/xrmocap/main/requirements/service.txt && \ + pip cache purge diff --git a/dockerfiles/service_ubt18/build_runtime_docker.sh b/dockerfiles/service_ubt18/build_runtime_docker.sh new file mode 100755 index 0000000..2c0d6fc --- /dev/null +++ b/dockerfiles/service_ubt18/build_runtime_docker.sh @@ -0,0 +1,17 @@ +#!/bin/bash +CUDA_VER=11.6 +PY_VER=3.8 +MMCV_VER=1.6.1 +TORCH_VER=1.12.1 +TORCHV_VER=0.13.1 +CUDA_VER_DIGIT=${CUDA_VER//./} +PY_VER_DIGIT=${PY_VER//./} +MMCV_VER_DIGIT=${MMCV_VER//./} +TORCH_VER_DIGIT=${TORCH_VER//./} +INPUT_TAG="openxrlab/xrmocap_runtime:ubuntu1804_x64_cuda${CUDA_VER_DIGIT}_py${PY_VER_DIGIT}_torch${TORCH_VER_DIGIT}_mmcv${MMCV_VER_DIGIT}" +FINAL_TAG="${INPUT_TAG}_service" +echo "tag to build: $FINAL_TAG" +BUILD_ARGS="--build-arg CUDA_VER=${CUDA_VER} --build-arg PY_VER=${PY_VER} --build-arg MMCV_VER=${MMCV_VER} --build-arg TORCH_VER=${TORCH_VER} --build-arg TORCHV_VER=${TORCHV_VER} --build-arg INPUT_TAG=${INPUT_TAG}" +# build according to Dockerfile +docker build -t $FINAL_TAG -f dockerfiles/service_ubt18/Dockerfile $BUILD_ARGS --progress=plain . +echo "Successfully tagged $FINAL_TAG" diff --git a/docs/en/installation.md b/docs/en/installation.md index eab107f..daa9781 100644 --- a/docs/en/installation.md +++ b/docs/en/installation.md @@ -5,6 +5,7 @@ - [Prepare environment](#prepare-environment) - [Run with docker image](#run-with-docker-image) - [Test environment](#test-environment) +- [Client only](#client-only) - [Frequently Asked Questions](#frequently-asked-questions) ## Requirements @@ -21,14 +22,15 @@ Optional: -| Name | When it is required | What's important | -| :------------------------------------------------------- | :----------------------------- | :----------------------------------------------------------- | -| [MMPose](https://github.com/open-mmlab/mmpose) | Keypoints 2D estimation. | Install `mmcv-full`, instead of `mmcv`. | -| [MMDetection](https://github.com/open-mmlab/mmdetection) | Bbox 2D estimation. | Install `mmcv-full`, instead of `mmcv`. | -| [MMTracking](https://github.com/open-mmlab/mmtracking) | Multiple object tracking. | Install `mmcv-full`, instead of `mmcv`. | -| [MMDeploy](https://github.com/open-mmlab/mmdeploy) | Faster mmdet+mmpose inference. | Install `mmcv-full`, `cudnn` and `TensorRT`. | -| [Aniposelib](https://github.com/google/aistplusplus_api) | Triangulation. | Install from [github](https://github.com/liruilong940607/aniposelib), instead of pypi. | -| [Minimal Pytorch Rasterizer](https://github.com/rmbashirov/minimal_pytorch_rasterizer) | SMPL mesh fast visualization. | Tested on torch-1.12.0. | +| Name | When it is required | What's important | +| :----------------------------------------------------------- | :-------------------------------------- | :----------------------------------------------------------- | +| [MMPose](https://github.com/open-mmlab/mmpose) | Keypoints 2D estimation. | Install `mmcv-full`, instead of `mmcv`. | +| [MMDetection](https://github.com/open-mmlab/mmdetection) | Bbox 2D estimation. | Install `mmcv-full`, instead of `mmcv`. | +| [MMTracking](https://github.com/open-mmlab/mmtracking) | Multiple object tracking. | Install `mmcv-full`, instead of `mmcv`. | +| [MMDeploy](https://github.com/open-mmlab/mmdeploy) | Faster mmdet+mmpose inference. | Install `mmcv-full`, `cudnn` and `TensorRT`. | +| [Aniposelib](https://github.com/google/aistplusplus_api) | Triangulation. | Install from [github](https://github.com/liruilong940607/aniposelib), instead of pypi. | +| [Minimal Pytorch Rasterizer](https://github.com/rmbashirov/minimal_pytorch_rasterizer) | SMPL mesh fast visualization. | Tested on torch-1.12.0. | +| [Flask](https://flask.palletsprojects.com/en/2.3.x/) | Starting an http or a websocket server. | | ## A from-scratch setup script @@ -92,6 +94,8 @@ cd xrmocap pip install -r requirements/build.txt # install requirements for runtime pip install -r requirements/runtime.txt +# install requirements for services +pip install -r requirements/service.txt # install xrmocap rm -rf .eggs && pip install -e . @@ -196,7 +200,15 @@ cd /opt && \ **Note3:** We've only tested mmdeploy 0.12.0, other version may not work as expectation. -#### g. Run unittests or demos +#### g. Install requirements for service + +You will only need this when you are going to start a server defined in `xrmocap.service`. + +```bash +pip install -r requirements/service.txt +``` + +#### h. Run unittests or demos If everything goes well, try to [run unittest](#test-environment) or go back to [run demos](./getting_started.md#inference) @@ -225,6 +237,15 @@ sh scripts/run_docker.sh To test whether the environment is well installed, please refer to [test doc](./test.md). +### Client only + +If you only need to use the client provided by XRMoCap, the installation process will be much simpler. We have increased the compatibility of the client by reducing dependencies, and you only need to execute the commands below. + +```bash +pip install numpy tqdm flask-socketio requests websocket-client +pip install . --no-deps +``` + ### Frequently Asked Questions If your environment fails, check our [FAQ](./faq.md) first, it might be helpful to some typical questions. diff --git a/docs/en/tools/start_service.md b/docs/en/tools/start_service.md new file mode 100644 index 0000000..55e862e --- /dev/null +++ b/docs/en/tools/start_service.md @@ -0,0 +1,40 @@ +# Tool start_service + +- [Overview](#overview) +- [Argument: config_path](#argument-config_path) +- [Argument: disable_log_file](#argument-disable_log_file) +- [Example](#example) + +### Overview + +This tool starts a server in the current console according to the configuration file, and sets up a logger. The logger displays information of no less than `INFO` level in the console, and write information of no less than `DEBUG` level in the log file under the `logs/` directory. + +For services that use the `work_dir` parameter, please make sure that the target path can be created correctly. Generally speaking, running `mkdir temp` in advance can ensure that the default configuration file in the repository can be successfully used. + +### Argument: config_path + +`config_path` is the path to a configuration file for server. Please ensure that all parameters required by `SomeService.__init__()` are specified in the configuration file. An example is provided below. For more details, see the docstring in [code](../../../xrmocap/service/base_flask_service.py). + +```python +type = 'SMPLStreamService' +name = 'smpl_stream_service' +work_dir = f'temp/{name}' +body_model_dir = 'xrmocap_data/body_models' +device = 'cuda:0' +enable_cors = True +port = 29091 +``` + +Also, you can find our prepared config files in `configs/modules/service/smpl_stream_service.py`. + +### Argument: disable_log_file + +By default, `disable_log_file` is False and two log files under `logs/f'{service_name}_{time_str}'` will be written. Add `--disable_log_file` makes it True and the tool will only print log to console. + +### Example + +Run the tool with explicit paths. + +```bash +python tools/start_service.py --config_path configs/modules/service/smpl_stream_service.py +``` diff --git a/requirements/service.txt b/requirements/service.txt new file mode 100644 index 0000000..8e82848 --- /dev/null +++ b/requirements/service.txt @@ -0,0 +1,6 @@ +flask +Flask-Caching +flask-socketio +flask_api +flask_cors +simple-websocket diff --git a/scripts/start_service_docker.sh b/scripts/start_service_docker.sh new file mode 100644 index 0000000..330967c --- /dev/null +++ b/scripts/start_service_docker.sh @@ -0,0 +1,15 @@ +#!/usr/bin/env bash +TAG=openxrlab/xrmocap_runtime:ubuntu1804_x64_cuda116_py38_torch1121_mmcv161_service +CONFIG_PATH=$1 +PORT=$(grep 'port =' ${CONFIG_PATH} | cut -d "=" -f 2 | tr -d ' ') +echo "Starting service on port $PORT" +PORTS="-p $PORT:$PORT" +WORKSPACE_VOLUMES="-v $PWD:/workspace/xrmocap" +WORKDIR="-w /workspace/xrmocap" +MEMORY="--memory=20g" +docker run --runtime=nvidia -it --rm --entrypoint=/bin/bash $PORTS $WORKSPACE_VOLUMES $WORKDIR $MEMORY $TAG -c " + source /opt/miniconda/etc/profile.d/conda.sh + conda activate openxrlab + pip install . + python tools/start_service.py --config_path $CONFIG_PATH +" diff --git a/setup.cfg b/setup.cfg index e2764f6..387717e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,6 +12,6 @@ multi_line_output = 5 include_trailing_comma = true known_standard_library = pkg_resources,setuptools known_first_party = xrmocap -known_third_party =PIL,cv2,dateutil,filterpy,matplotlib,mmcv,mmhuman3d,numpy,prettytable,pytest,pytorch3d,scipy,smplx,sphinx_rtd_theme,torch,torchvision,tqdm,xrprimer +known_third_party =PIL,cv2,dateutil,filterpy,flask,flask_socketio,matplotlib,mmcv,mmhuman3d,numpy,prettytable,pytest,pytorch3d,scipy,smplx,socketio,sphinx_rtd_theme,torch,torchvision,tqdm,xrprimer no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/tools/clients/smpl_verts_client.py b/tools/clients/smpl_verts_client.py new file mode 100644 index 0000000..8b456ab --- /dev/null +++ b/tools/clients/smpl_verts_client.py @@ -0,0 +1,76 @@ +# yapf: disable +import argparse +import logging +import numpy as np +import os +import sys +import time +from tqdm import tqdm + +from xrmocap.client.smpl_stream_client import SMPLStreamClient + +# yapf: enable + + +def main(args) -> int: + name = os.path.basename(__file__).split('.')[0] + logger = logging.getLogger(name) + if args.verbose: + logger.setLevel(logging.INFO) + else: + logger.setLevel(logging.WARNING) + if args.smpl_data_path is None: + logger.error('Please specify smpl_data_path.') + raise ValueError + client = SMPLStreamClient( + server_ip=args.server_ip, server_port=args.server_port, logger=logger) + n_frames = client.upload_smpl_data(args.smpl_data_path) + logger.info(f'Motion of {n_frames} frames uploaded.') + faces = client.get_faces() + faces_np = np.array(faces) + logger.info(f'Get faces: {faces_np.shape}') + start_time = time.time() + for frame_idx in tqdm(range(n_frames)): + verts = client.forward(frame_idx) + if frame_idx == 0: + verts_np = np.array(verts) + logger.info(f'Get verts for first frame: {verts_np.shape}') + loop_time = time.time() - start_time + fps = n_frames / loop_time + logger.info(f'Get verts for all frames, average fps: {fps:.2f}') + client.close() + return 0 + + +def setup_parser(): + parser = argparse.ArgumentParser( + description='Send a smpl data file to ' + + 'SMPLStreamServer and receive faces and verts.') + parser.add_argument( + '--smpl_data_path', + help='Path to a SMPL(X)Data file.', + type=str, + ) + parser.add_argument( + '--server_ip', + help='IP address of the server.', + type=str, + default='127.0.0.1') + parser.add_argument( + '--server_port', + help='Port number of the server.', + type=int, + default=29091) + parser.add_argument( + '--verbose', + action='store_true', + help='If True, INFO level log will be shown.', + default=False) + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = setup_parser() + ret_val = main(args) + sys.exit(ret_val) diff --git a/tools/start_service.py b/tools/start_service.py new file mode 100644 index 0000000..4bad967 --- /dev/null +++ b/tools/start_service.py @@ -0,0 +1,71 @@ +# yapf: disable +import argparse +import json +import mmcv +import os +from xrprimer.utils.log_utils import logging, setup_logger + +from xrmocap.service.builder import build_service +from xrmocap.utils.date_utils import get_datetime_local, get_str_from_datetime + +# yapf: enable + + +def main(args): + # load config + service_config = dict(mmcv.Config.fromfile(args.config_path)) + service_name = service_config['name'] + # setup logger + if not args.disable_log_file: + datetime = get_datetime_local() + time_str = get_str_from_datetime(datetime) + log_dir = os.path.join('logs', f'{service_name}_{time_str}') + os.makedirs(log_dir) + main_logger_path = None \ + if args.disable_log_file\ + else os.path.join(log_dir, f'{service_name}_log.txt') + flask_logger_path = None \ + if args.disable_log_file\ + else os.path.join(log_dir, 'flask_log.txt') + logger = setup_logger( + logger_name=service_name, + file_level=logging.DEBUG, + console_level=logging.INFO, + logger_path=main_logger_path) + # logger for Flask + flask_logger = setup_logger( + logger_name='werkzeug', + file_level=logging.DEBUG, + console_level=logging.INFO, + logger_path=flask_logger_path) + logger.info('Main logger starts.') + flask_logger.info('Flask logger starts.') + # build service + service_config_str = json.dumps(service_config, indent=4) + logger.debug(f'\nservice_config:\n{service_config_str}') + service_config['logger'] = logger + service = build_service(service_config) + service.run() + + +def setup_parser(): + parser = argparse.ArgumentParser() + # input args + parser.add_argument( + '--config_path', + type=str, + help='Path to service config file.', + default='configs/modules/service/base_service.py') + # log args + parser.add_argument( + '--disable_log_file', + action='store_true', + help='If checked, log will not be written as file.', + default=False) + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = setup_parser() + main(args) diff --git a/xrmocap/client/__init__.py b/xrmocap/client/__init__.py new file mode 100644 index 0000000..ef2f6a2 --- /dev/null +++ b/xrmocap/client/__init__.py @@ -0,0 +1,7 @@ +import logging + +# client does not require xrprimer.utils.log_utils +# logger's level is set to INFO by default +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') diff --git a/xrmocap/client/smpl_stream_client.py b/xrmocap/client/smpl_stream_client.py new file mode 100644 index 0000000..f4beb13 --- /dev/null +++ b/xrmocap/client/smpl_stream_client.py @@ -0,0 +1,180 @@ +# yapf: disable +import gzip +import logging +import numpy as np +import socketio +from enum import Enum +from typing import List, Union + +# yapf: enable + + +class SMPLStreamActionsEnum(str, Enum): + UPLOAD = 'upload' + FORWARD = 'forward' + GET_FACES = 'get_faces' + + +class SMPLStreamClient: + """Client of the XRMocap SMPL Stream server.""" + + def __init__(self, + server_ip: str = '127.0.0.1', + server_port: int = 29091, + enable_bytes: bool = True, + enable_gzip: bool = False, + logger: Union[None, str, logging.Logger] = None) -> None: + """Initialize the client. + + Args: + server_ip (str, optional): + IP address of the server. + Defaults to '127.0.0.1'. + server_port (int, optional): + Port of the server. + Defaults to 8376. + enable_bytes (bool, optional): + If True, the client will receive bytes from server. + Otherwise, the client will receive dict. + Defaults to True. + enable_gzip (bool, optional): + If True, the client will decompress the bytes from server. + Defaults to False. + logger (Union[None, str, logging.Logger], optional): + Logger for logging. If None, root logger will be selected. + Defaults to None. + """ + if logger is None or isinstance(logger, str): + self.logger = logging.getLogger(logger) + else: + self.logger = logger + self.server_ip = server_ip + self.server_port = server_port + self.enable_bytes = enable_bytes + self.enable_gzip = enable_gzip + if not self.enable_bytes and self.enable_gzip: + self.logger.warning('enable_gzip is set to True,' + + ' but enable_bytes is set to False. ' + 'enable_gzip will be ignored.') + # setup websocket client + self.socketio_client = socketio.Client() + self.socketio_client.connect(f'http://{server_ip}:{server_port}') + + def _parse_upload_response(self, data): + if data['status'] == 'success': + n_frames = int(data['n_frames']) + else: + msg = data['msg'] + self.logger.error( + 'Failed to upload body motion, msg from server:\n' + msg) + self.socketio_client.disconnect() + raise RuntimeError + + return n_frames + + def upload_smpl_data(self, smpl_data: Union[bytes, str]) -> int: + """Upload a body motion to the SMPL server. + + Args: + smpl_data (Union[bytes, str]): + A SMPL(X)Data file in bytes, + or a path to SMPL(X)Data file. + + Raises: + ValueError: + body_motion is None + + Returns: + int: number of frames in the body motion + """ + if isinstance(smpl_data, str): + with open(smpl_data, 'rb') as f: + smpl_data_bytes = f.read() + elif smpl_data is None: + self.logger.error('SMPL data is None.') + raise ValueError + else: + smpl_data_bytes = smpl_data + + data = {'file_name': 'body_motion', 'file_data': smpl_data_bytes} + resp_data = self.socketio_client.call(SMPLStreamActionsEnum.UPLOAD, + data) + n_frames = self._parse_upload_response(resp_data) + return n_frames + + def _parse_get_faces_response(self, data: Union[dict, + bytes]) -> List[float]: + # find out if the request is successful first + if isinstance(data, dict): + success = (data['status'] == 'success') + else: + success = True + # extract faces according to response type and self settings + if success: + if self.enable_bytes: + bin_data = data + if self.enable_gzip: + bin_data = gzip.decompress(bin_data) + faces_list = np.frombuffer( + bin_data, dtype=np.float16).reshape((-1, 3)).tolist() + else: + faces_list = data['faces'] + return faces_list + else: + msg = data['msg'] + self.logger.error(msg) + self.close() + raise RuntimeError(msg) + + def get_faces(self) -> List[int]: + """Send a request to get body face indices from the server. + + Returns: + List[int]: the requested face indices, organized as a [|F|, 3] list + """ + resp_data = self.socketio_client.call(SMPLStreamActionsEnum.GET_FACES) + faces = self._parse_get_faces_response(resp_data) + return faces + + def _parse_forward_response(self, data) -> List[List[float]]: + # find out if the request is successful first + if isinstance(data, dict): + success = (data['status'] == 'success') + else: + success = True + # extract verts according to response type and self settings + if success: + if self.enable_bytes: + bin_data = data + if self.enable_gzip: + bin_data = gzip.decompress(bin_data) + verts_list = np.frombuffer( + bin_data, dtype=np.float16).reshape((-1, 3)).tolist() + else: + verts_list = np.asarray(data['verts']).reshape(-1, 3).tolist() + return verts_list + else: + msg = data['msg'] + self.logger.error(msg) + self.close() + raise RuntimeError(msg) + + def forward(self, frame_idx: int) -> List[List[float]]: + """Send a request to get body vertices from the server. + + Args: + frame_idx (int): frame index in infer + + Returns: + List[List[float]]: + A nested list for inferred body vertices, + shape: [n_verts, 3]. + """ + resp_data = self.socketio_client.call(SMPLStreamActionsEnum.FORWARD, + {'frame_idx': frame_idx}) + verts = self._parse_forward_response(resp_data) + return verts + + def close(self): + """Close the client.""" + self.socketio_client.disconnect() diff --git a/xrmocap/service/__init__.py b/xrmocap/service/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/xrmocap/service/base_flask_service.py b/xrmocap/service/base_flask_service.py new file mode 100644 index 0000000..db20e36 --- /dev/null +++ b/xrmocap/service/base_flask_service.py @@ -0,0 +1,118 @@ +import os +import shutil +from typing import Tuple, Union +from xrprimer.utils.log_utils import get_logger, logging +from xrprimer.utils.path_utils import Existence, check_path_existence + +from xrmocap.utils.service_utils import payload_to_dict + +try: + from flask import Flask, request + from flask_api import status + from flask_cors import CORS + has_flask = True + import_exception = '' +except (ImportError, ModuleNotFoundError): + has_flask = False + import traceback + stack_str = '' + for line in traceback.format_stack(): + if 'frozen' not in line: + stack_str += line + '\n' + import_exception = traceback.format_exc() + '\n' + import_exception = stack_str + import_exception + + +class BaseFlaskService: + """Base http Flask service.""" + + def __init__(self, + name: str, + work_dir: str, + debug: bool = False, + enable_cors: bool = False, + host: str = '0.0.0.0', + port: int = 80, + logger: Union[None, str, logging.Logger] = None) -> None: + """ + Args: + name (str): Name of this service. + work_dir (str): + Path to a directory, for temp files of this server. + If empty, no temp files will be created. + debug (bool, optional): + If `debug` flag is set the server will automatically reload + for code changes and show a debugger in case + an exception happened. + Defaults to False. + enable_cors (bool, optional): + Whether to enable Cross Origin Resource Sharing (CORS). + Defaults to False. + host (str, optional): + Host IP address. 127.0.0.1 for localhost, + 0.0.0.0 for all local network interfaces. + Defaults to '0.0.0.0'. + port (int, optional): + Port for this http service. + Defaults to 80. + logger (Union[None, str, logging.Logger], optional): + Logger for logging. If None, root logger will be selected. + Defaults to None. + """ + self.logger = get_logger(logger) + # flask also has a global logger: werkzeug + if not has_flask: + self.logger.error(import_exception) + raise ImportError + self.name = name + self.debug = debug + self.host = host + self.port = port + self._set_work_dir(work_dir=work_dir) + self.app = Flask(self.name) + self.enable_cors = enable_cors + if self.enable_cors: + CORS(self.app) + self.app.add_url_rule( + '/base_method/', 'base_method', self.base_method, methods=['POST']) + + def _set_work_dir(self, work_dir: str) -> None: + if len(work_dir) <= 0: + return + existence = check_path_existence(work_dir, 'dir') + if existence == Existence.MissingParent: + self.logger.error(f'Parent of {work_dir} does not exist.') + raise FileNotFoundError + elif existence == Existence.DirectoryExistNotEmpty: + self.logger.warning('\n' + f'Work dir {work_dir} is not empty!' + + ' Please check its content carefully.' + ' Clean it and continue? Y/N') + reply = input().strip().lower() + if reply == 'y': + shutil.rmtree(work_dir) + os.mkdir(work_dir) + else: + self.logger.error('Exiting for keeping work_dir safe.') + raise FileExistsError + elif existence == Existence.DirectoryNotExist: + os.mkdir(work_dir) + self.work_dir = work_dir + + def base_method(self) -> Tuple[dict, int]: + """A base method for interface testing. + + Returns: + Tuple[dict, int]: + dict: Returned payload, or internal error message. + int: Http response, 200 for OK, 500 for internal error. + """ + req_json = request.get_json() + res_dict = payload_to_dict(req_json) + return res_dict, status.HTTP_200_OK + + def run(self): + """Run this flask service according to configuration. + + This process will be blocked. + """ + self.app.run(debug=self.debug, host=self.host, port=self.port) diff --git a/xrmocap/service/builder.py b/xrmocap/service/builder.py new file mode 100644 index 0000000..5b061fa --- /dev/null +++ b/xrmocap/service/builder.py @@ -0,0 +1,14 @@ +from mmcv.utils import Registry + +from .base_flask_service import BaseFlaskService +from .smpl_stream_service import SMPLStreamService + +SERVICES = Registry('services') + +SERVICES.register_module(name='BaseFlaskService', module=BaseFlaskService) +SERVICES.register_module(name='SMPLStreamService', module=SMPLStreamService) + + +def build_service(cfg) -> BaseFlaskService: + """Build a flask service.""" + return SERVICES.build(cfg) diff --git a/xrmocap/service/smpl_stream_service.py b/xrmocap/service/smpl_stream_service.py new file mode 100644 index 0000000..dda072a --- /dev/null +++ b/xrmocap/service/smpl_stream_service.py @@ -0,0 +1,374 @@ +# yapf: disable +import gzip +import numpy as np +import os +import time +import torch +import uuid +from flask import session +from flask_socketio import SocketIO, emit +from threading import RLock +from typing import Union +from xrprimer.utils.log_utils import logging + +from xrmocap.data_structure.body_model import auto_load_smpl_data +from xrmocap.model.body_model.builder import build_body_model +from xrmocap.utils.time_utils import Timer +from .base_flask_service import BaseFlaskService + +# yapf: enable + +_SMPL_CONFIG_TEMPLATE = dict( + type='SMPL', + gender='neutral', + num_betas=10, + keypoint_convention='smpl_45', + model_path='xrmocap_data/body_models/smpl', + batch_size=1) +_SMPLX_CONFIG_TEMPLATE = dict( + type='SMPLX', + gender='neutral', + num_betas=10, + keypoint_convention='smplx', + model_path='xrmocap_data/body_models/smplx', + batch_size=1, + use_face_contour=True, + use_pca=False, + flat_hand_mean=False) + + +class SMPLStreamService(BaseFlaskService): + """A websocket service that provides SMPL/SMPLX vertices in stream.""" + + def __init__(self, + name: str, + body_model_dir: str, + work_dir: str, + secret_key: Union[None, str] = None, + flat_hand_mean: bool = False, + enable_bytes: bool = True, + enable_gzip: bool = False, + debug: bool = False, + enable_cors: bool = False, + device: Union[torch.device, str] = 'cuda', + host: str = '0.0.0.0', + port: int = 29091, + max_http_buffer_size: int = 128 * 1024 * 1024, + logger: Union[None, str, logging.Logger] = None) -> None: + """ + Args: + name (str): Name of this service. + body_model_dir (str): + Path to the directory for SMPL(X) body models, folder `smpl` + and `smplx` are below body_model_dir. + work_dir (str): + Path to a directory, for temp files of this server. + secret_key (Union[None, str], optional): + Secret key for this service. If None, a random key will be + generated. Defaults to None. + flat_hand_mean (bool, optional): + If False, then the pose of the hand is initialized to False. + Defaults to False. + enable_bytes (bool, optional): + Whether to enable bytes response. Defaults to True. + enable_gzip (bool, optional): + Whether to enable gzip compression for the verts response. + Defaults to False. + debug (bool, optional): + If `debug` flag is set the server will automatically reload + for code changes and show a debugger in case + an exception happened. + Defaults to False. + enable_cors (bool, optional): + Whether to enable Cross Origin Resource Sharing (CORS). + Defaults to False. + host (str, optional): + Host IP address. 127.0.0.1 for localhost, + 0.0.0.0 for all local network interfaces. + Defaults to '0.0.0.0'. + port (int, optional): + Port for this http service. + Defaults to 80. + max_http_buffer_size (int): + Server's payload. + Defaults to 128MB. + logger (Union[None, str, logging.Logger], optional): + Logger for logging. If None, root logger will be selected. + Defaults to None. + """ + BaseFlaskService.__init__( + self, + name=name, + work_dir=work_dir, + debug=debug, + enable_cors=enable_cors, + host=host, + port=port, + logger=logger, + ) + self.app.config['SECRET_KEY'] = os.urandom(24) \ + if secret_key is None \ + else secret_key + # max_http_buffer_size: the maximum allowed payload + self.socketio = SocketIO( + self.app, max_http_buffer_size=max_http_buffer_size) + self.device = device + self.worker_lock = RLock() + # set body model configs for all types and genders + # stored in self.body_model_configs + self._set_body_model_config(body_model_dir, flat_hand_mean) + # set enable_bytes and enable_gzip + self.enable_bytes = enable_bytes + self.enable_gzip = enable_gzip + if not self.enable_bytes and self.enable_gzip: + self.logger.warning('enable_gzip is set to True,' + + ' but enable_bytes is set to False. ' + 'enable_gzip will be ignored.') + + self.socketio.on_event('upload', self.upload_smpl_data) + + self.socketio.on_event('forward', self.forward_body_model) + + self.socketio.on_event('get_faces', self.get_faces) + + self.socketio.on_event( + message='disconnect', + handler=self.on_disconnect, + ) + self.socketio.on_event( + message='connect', + handler=self.on_connect, + ) + self.forward_timer = Timer( + name='forward_timer', + logger=self.logger, + ) + + def run(self): + """Run this flask service according to configuration. + + This process will be blocked. + """ + self.socketio.run( + app=self.app, debug=self.debug, host=self.host, port=self.port) + + def on_connect(self) -> None: + """Connect event handler. + + Register client uuid. + """ + uuid_str = str(uuid.uuid4()) + session['uuid'] = uuid_str + self.logger.info(f'Client {uuid_str} connected.') + + def on_disconnect(self) -> None: + """Disconnect event handler. + + Args: + data (dict): Request data. uuid is required. + """ + uuid_str = session['uuid'] + self.logger.info( + f'Client {uuid_str} disconnected. Cleaning files and session.') + self._clean_files_by_uuid(uuid_str) + session.clear() + + def upload_smpl_data(self, data: dict) -> dict: + """Upload smpl data file, check whether the corresponding body model + config exists, and save it to work_dir if success. + + Args: + data (dict): smpl data file info, including + file_name and file_data. + + Returns: + dict: response info, including status, and + msg when fails. + """ + + resp_dict = dict() + uuid_str = session['uuid'] + smpl_data_in_session = session.get('smpl_data', None) + if smpl_data_in_session is not None: + warn_msg = f'Client {uuid_str} has already uploaded a file.' +\ + ' Overwriting.' + self.logger.warning(warn_msg) + resp_dict['msg'] = f'Warning: {warn_msg}' + file_name = data['file_name'] + file_data = data['file_data'] + file_path = os.path.join(self.work_dir, f'{uuid_str}_{file_name}.npz') + with open(file_path, 'wb') as file: + file.write(file_data) + # load smpl data + smpl_data, class_name = auto_load_smpl_data( + npz_path=file_path, logger=self.logger) + smpl_type = class_name.replace('Data', '').lower() + smpl_gender = smpl_data.get_gender() + # check if the body model files exist + if smpl_type not in self.body_model_configs or\ + smpl_gender not in self.body_model_configs[smpl_type]: + error_msg = f'Client {uuid_str} has smpl type {smpl_type} ' +\ + f'and smpl gender {smpl_gender}, ' +\ + 'but no corresponding body model config found.' + resp_dict['msg'] = f'Error: {error_msg}' + self.logger.error(error_msg) + emit('upload_response', resp_dict) + # build body model + body_model_cfg = self.body_model_configs[smpl_type][smpl_gender] + body_model = build_body_model(body_model_cfg).to(self.device) + # save body model to cache + session['smpl_type'] = smpl_type.replace('Data', '').lower() + session['smpl_gender'] = smpl_data.get_gender() + session['smpl_data'] = smpl_data + session['body_model'] = body_model + session['last_connect_time'] = time.time() + self.logger.info( + f'Client {uuid_str} smpl data file loaded confirmed.\n' + + f'Body model type: {smpl_type}\n' + f'Gender: {smpl_gender}') + resp_dict['n_frames'] = smpl_data.get_batch_size() + resp_dict['status'] = 'success' + + return resp_dict + + def forward_body_model(self, data: dict) -> dict: + """Call body_model.forward() to get SMPL vertices. + + Args: + data (dict): Request data, frame_idx is required. + + Returns: + dict: Response data. + If success, status is 'success' and + vertices bytes for an ndarray. + """ + resp_dict = dict() + req_dict = data + uuid_str = session['uuid'] + frame_idx = req_dict['frame_idx'] + smpl_data = session['smpl_data'] + # check if data and args are valid + failed = False + if smpl_data is None: + error_msg = f'Client {uuid_str}\'s smpl data not uploaded.' + failed = True + elif frame_idx >= smpl_data.get_batch_size(): + error_msg = f'Client {uuid_str}\'s smpl data only has ' +\ + f'{smpl_data.get_batch_size()} frames, ' +\ + f'but got frame_idx={frame_idx} in request.' + failed = True + if failed: + self.logger.error(error_msg) + resp_dict['msg'] = f'Error: {error_msg}' + resp_dict['status'] = 'fail' + return resp_dict + # no error, forward body model + else: + tensor_dict = smpl_data.to_tensor_dict( + repeat_betas=True, device=self.device) + for k, v in tensor_dict.items(): + tensor_dict[k] = v[frame_idx:frame_idx + 1] + body_model = session['body_model'] + with self.worker_lock: + self.forward_timer.start() + with torch.no_grad(): + body_model_output = body_model(**tensor_dict) + self.forward_timer.stop() + if self.forward_timer.count >= 50: + self.logger.info( + 'Average forward time per-frame:' + + f' {self.forward_timer.get_average(reset=True):.4f} s') + verts = body_model_output['vertices'] # n_batch=1, n_verts, 3 + verts_np = verts.cpu().numpy().squeeze(0).astype(np.float16) + session['last_connect_time'] = time.time() + if self.enable_bytes: + verts_bytes = verts_np.tobytes() + if self.enable_gzip: + verts_bytes = gzip.compress(verts_bytes) + return verts_bytes + else: + resp_dict['verts'] = verts_np.tolist() + resp_dict['status'] = 'success' + return resp_dict + + def get_faces(self) -> dict: + """Get body face indices. + + Returns: + dict: Response data. + If success, status is 'success' and + face indices for an ndarray. + """ + resp_dict = dict() + body_model = session['body_model'] + # check if data and args are valid + if body_model is None: + error_msg = 'Failed to get body model.' + self.logger.error(error_msg) + resp_dict['msg'] = f'Error: {error_msg}' + resp_dict['status'] = 'fail' + + return resp_dict + + session['last_connect_time'] = time.time() + + if self.enable_bytes: + faces = np.array(body_model.faces, dtype=np.int32) + faces_bytes = faces.tobytes() + if self.enable_gzip: + faces_bytes = gzip.compress(faces_bytes) + return faces_bytes + else: + resp_dict['faces'] = faces + resp_dict['status'] = 'success' + return resp_dict + + def _clean_files_by_uuid(self, uuid: str) -> None: + file_names = os.listdir(self.work_dir) + for file_name in file_names: + if file_name.startswith(uuid): + file_path = os.path.join(self.work_dir, file_name) + os.remove(file_path) + + def _set_body_model_config(self, body_model_dir: str, + flat_hand_mean: bool) -> None: + self.body_model_dir = body_model_dir + self.flat_hand_mean = flat_hand_mean + genders = ('neutral', 'female', 'male') + smpl_configs = dict() + smplx_configs = dict() + absent_models = [] + for gender in genders: + file_name = f'SMPL_{gender.upper()}.pkl' + file_path = os.path.join(body_model_dir, 'smpl', file_name) + if os.path.exists(file_path): + gender_config = _SMPL_CONFIG_TEMPLATE.copy() + gender_config['gender'] = gender + gender_config['logger'] = self.logger + gender_config['model_path'] = os.path.join( + body_model_dir, 'smpl') + smpl_configs[gender] = gender_config + else: + absent_models.append(file_name) + file_name = f'SMPLX_{gender.upper()}.npz' + file_path = os.path.join(body_model_dir, 'smplx', file_name) + if os.path.exists(file_path): + gender_config = _SMPLX_CONFIG_TEMPLATE.copy() + gender_config['gender'] = gender + gender_config['logger'] = self.logger + gender_config['flat_hand_mean'] = flat_hand_mean + gender_config['model_path'] = os.path.join( + body_model_dir, 'smplx') + smplx_configs[gender] = gender_config + else: + absent_models.append(file_name) + self.body_model_configs = dict( + smpl=smpl_configs, + smplx=smplx_configs, + ) + if len(smpl_configs) + len(smplx_configs) <= 0: + self.logger.error(f'No body_model found below {body_model_dir}.') + raise FileNotFoundError + if len(absent_models) > 0: + self.logger.warning(f'Missing {len(absent_models)} model files.' + + ' The following models cannot be used:\n' + + f'{absent_models}') diff --git a/xrmocap/utils/service_utils.py b/xrmocap/utils/service_utils.py new file mode 100644 index 0000000..1ec4218 --- /dev/null +++ b/xrmocap/utils/service_utils.py @@ -0,0 +1,19 @@ +import json +from typing import Union + + +def payload_to_dict(input_instance: Union[str, dict]) -> dict: + """Convert flask payload to python dict. + + Args: + input_instance (Union[str, dict]): + Payload get from request.get_json(). + + Returns: + dict: Payload in type dict. + """ + if isinstance(input_instance, dict): + input_dict = input_instance + else: + input_dict = json.loads(s=input_instance) + return input_dict