diff --git a/deploy/squeezesegv3/python/infer.py b/deploy/squeezesegv3/python/infer.py
new file mode 100644
index 00000000..5a6e5d84
--- /dev/null
+++ b/deploy/squeezesegv3/python/infer.py
@@ -0,0 +1,164 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+
+import cv2
+import numpy as np
+import paddle
+from paddle.inference import Config, create_predictor
+
+from paddle3d import transforms as T
+from paddle3d.sample import Sample
+from paddle3d.transforms.normalize import NormalizeRangeImage
+from paddle3d.transforms.reader import LoadSemanticKITTIRange
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model_file",
+        type=str,
+        help="Model filename, Specify this when your model is a combined model.",
+        required=True)
+    parser.add_argument(
+        "--params_file",
+        type=str,
+        help=
+        "Parameter filename, Specify this when your model is a combined model.",
+        required=True)
+    parser.add_argument(
+        '--lidar_file', type=str, help='The lidar path.', required=True)
+    parser.add_argument(
+        '--img_mean',
+        type=str,
+        help='The mean value of range-view image.',
+        required=True)
+    parser.add_argument(
+        '--img_std',
+        type=str,
+        help='The variance value of range-view image.',
+        required=True)
+    parser.add_argument("--gpu_id", type=int, default=0, help="GPU card id.")
+    parser.add_argument(
+        "--use_trt",
+        type=int,
+        default=0,
+        help="Whether to use tensorrt to accelerate when using gpu.")
+    parser.add_argument(
+        "--trt_precision",
+        type=int,
+        default=0,
+        help="Precision type of tensorrt, 0: kFloat32, 1: kHalf.")
+    parser.add_argument(
+        "--trt_use_static",
+        type=int,
+        default=0,
+        help="Whether to load the tensorrt graph optimization from a disk path."
+    )
+    parser.add_argument(
+        "--trt_static_dir",
+        type=str,
+        help="Path of a tensorrt graph optimization directory.")
+
+    return parser.parse_args()
+
+
+def preprocess(file_path, img_mean, img_std):
+    if isinstance(img_mean, str):
+        img_mean = eval(img_mean)
+    if isinstance(img_std, str):
+        img_std = eval(img_std)
+
+    sample = Sample(path=file_path, modality="lidar")
+
+    transforms = T.Compose([
+        LoadSemanticKITTIRange(project_label=False),
+        NormalizeRangeImage(mean=img_mean, std=img_std)
+    ])
+
+    sample = transforms(sample)
+
+    if "proj_mask" in sample.meta:
+        sample.data *= sample.meta.pop("proj_mask")
+    return np.expand_dims(sample.data,
+                          0), sample.meta.proj_x, sample.meta.proj_y
+
+
+def init_predictor(model_file,
+                   params_file,
+                   gpu_id=0,
+                   use_trt=False,
+                   trt_precision=0,
+                   trt_use_static=False,
+                   trt_static_dir=None):
+    config = Config(model_file, params_file)
+    config.enable_memory_optim()
+    config.enable_use_gpu(1000, gpu_id)
+    if use_trt:
+        precision_mode = paddle.inference.PrecisionType.Float32
+        if trt_precision == 1:
+            precision_mode = paddle.inference.PrecisionType.Half
+        config.enable_tensorrt_engine(
+            workspace_size=1 << 20,
+            max_batch_size=1,
+            min_subgraph_size=3,
+            precision_mode=precision_mode,
+            use_static=trt_use_static,
+            use_calib_mode=False)
+        if trt_use_static:
+            config.set_optim_cache_dir(trt_static_dir)
+
+    predictor = create_predictor(config)
+    return predictor
+
+
+def run(predictor, points):
+    # copy img data to input tensor
+    input_names = predictor.get_input_names()
+    input_tensor = predictor.get_input_handle(input_names[0])
+    input_tensor.reshape(points.shape)
+    input_tensor.copy_from_cpu(points.copy())
+
+    # do the inference
+    predictor.run()
+
+    results = []
+    # get out data from output tensor
+    output_names = predictor.get_output_names()
+    output_tensor = predictor.get_output_handle(output_names[0])
+    pred_label = output_tensor.copy_to_cpu()
+
+    return pred_label[0]
+
+
+def postprocess(pred_img_label, proj_x, proj_y):
+    return pred_img_label[proj_y, proj_x]
+
+
+def main(args):
+    predictor = init_predictor(args.model_file, args.params_file, args.gpu_id,
+                               args.use_trt, args.trt_precision,
+                               args.trt_use_static, args.trt_static_dir)
+    range_img, proj_x, proj_y = preprocess(args.lidar_file, args.img_mean,
+                                           args.img_std)
+    pred_img_label = run(predictor, range_img)
+    pred_point_label = postprocess(pred_img_label, proj_x, proj_y)
+    return pred_point_label
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    main(args)
diff --git a/docs/models/squeezesegv3/README.md b/docs/models/squeezesegv3/README.md
index 7b72fecd..433f28f4 100644
--- a/docs/models/squeezesegv3/README.md
+++ b/docs/models/squeezesegv3/README.md
@@ -10,6 +10,7 @@
   * [训练](#h3-id52h3)
   * [评估](#h3-id53h3)
   * [模型导出](#h3-id54h3)
+  * [模型部署](#h3-id55h3)
 
 ## <h2 id="1">引用</h2>
 
@@ -125,3 +126,42 @@ python tools/export.py \
 | model       | 待导出模型参数`model.pdparams`路径                                                                                    | 是     | -        |
 | input_shape | 指定模型的输入尺寸,支持`N, C, H, W`或`H, W`格式                                                                            | 是     | -        |
 | save_dir    | 保存导出模型的路径,`save_dir`下将会生成三个文件:`squeezesegv3.pdiparams `、`squeezesegv3.pdiparams.info`和`squeezesegv3.pdmodel` | 否     | `deploy` |
+
+
+
+### <h3 id="55">模型部署</h3>
+
+#### C++部署
+
+Coming soon...
+
+#### Python部署
+
+命令参数说明如下:
+
+| 参数 | 说明 |
+| -- | -- |
+| model_file | 导出模型的结构文件`squeezesegv3.pdmodel`所在路径 |
+| params_file | 导出模型的参数文件`squeezesegv3.pdiparams`所在路径 |
+| lidar_file | 待预测的点云文件所在路径 |
+| img_mean | 点云投影到range-view后所成图像的均值,例如为`12.12,10.88,0.23,-1.04,0.21` |
+| img_std | 点云投影到range-view后所成图像的方差,例如为`12.32,11.47,6.91,0.86,0.16` |
+| use_trt | 是否使用TensorRT进行加速,默认0|
+| trt_precision | 当use_trt设置为1时,模型精度可设置0或1,0表示fp32, 1表示fp16。默认0 |
+| trt_use_static | 当trt_use_static设置为1时,**在首次运行程序的时候会将TensorRT的优化信息进行序列化到磁盘上,下次运行时直接加载优化的序列化信息而不需要重新生成**。默认0 |
+| trt_static_dir | 当trt_use_static设置为1时,保存优化信息的路径 |
+
+
+运行以下命令,执行预测:
+
+```
+python infer.py --model_file /path/to/squeezesegv3.pdmodel --params_file /path/to/squeezesegv3.pdiparams --lidar_file /path/to/lidar.pcd.bin --img_mean 12.12,10.88,0.23,-1.04,0.21 --img_std 12.32,11.47,6.91,0.86,0.16
+```
+
+如果要开启TensorRT的话,请卸载掉原有的`paddlepaddel_gpu`,至[Paddle官网](https://paddleinference.paddlepaddle.org.cn/user_guides/download_lib.html#python)下载与TensorRT连编的预编译Paddle Inferece安装包,选择符合本地环境CUDA/cuDNN/TensorRT版本的安装包完成安装即可。
+
+运行以下命令,开启TensorRT加速模型预测:
+
+```
+python infer.py --model_file /path/to/squeezesegv3.pdmodel --params_file /path/to/squeezesegv3.pdiparams --lidar_file /path/to/lidar.pcd.bin --img_mean 12.12,10.88,0.23,-1.04,0.21 --img_std 12.32,11.47,6.91,0.86,0.16 --use_trt 1
+```