diff --git a/easy_rec/python/compat/feature_column/feature_column_v2.py b/easy_rec/python/compat/feature_column/feature_column_v2.py
index 23757669c..844028d2b 100644
--- a/easy_rec/python/compat/feature_column/feature_column_v2.py
+++ b/easy_rec/python/compat/feature_column/feature_column_v2.py
@@ -3595,6 +3595,10 @@ def _old_get_dense_tensor_internal(self, sparse_tensors, weight_collections,
     layer_utils.append_tensor_to_collection(
         compat_ops.GraphKeys.RANK_SERVICE_EMBEDDING, embedding_attrs['name'],
         'input', sparse_tensors.id_tensor)
+    if sparse_tensors.weight_tensor is not None:
+      layer_utils.append_tensor_to_collection(
+          compat_ops.GraphKeys.RANK_SERVICE_EMBEDDING, embedding_attrs['name'],
+          'weighted_input', sparse_tensors.weight_tensor)
 
     return predictions
 
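Note: the new branch registers the weight tensor alongside the id tensor. A minimal TF1-style sketch of why both halves must be recorded (the collection keys and tensor values here are hypothetical, not the actual encoding used by `layer_utils.append_tensor_to_collection`):

```python
import tensorflow as tf

# A weighted tag feature is a pair of parallel sparse tensors: ids and
# per-id weights. Registering only the ids would let the rank service
# replay the embedding lookup but silently drop the weights.
ids = tf.SparseTensor(
    indices=[[0, 0], [0, 1]], values=[3, 7], dense_shape=[1, 2])
weights = tf.SparseTensor(
    indices=[[0, 0], [0, 1]], values=[0.5, 1.5], dense_shape=[1, 2])
tf.add_to_collection('rank_service_embedding/my_emb/input', ids)
tf.add_to_collection('rank_service_embedding/my_emb/weighted_input', weights)
```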
diff --git a/easy_rec/python/core/sampler.py b/easy_rec/python/core/sampler.py
index 9ff25aa5d..480799c35 100644
--- a/easy_rec/python/core/sampler.py
+++ b/easy_rec/python/core/sampler.py
@@ -282,7 +282,7 @@ def get(self, ids):
     sampled_values = tf.py_func(self._get_impl, [ids], self._attr_tf_types)
     result_dict = {}
     for k, t, v in zip(self._attr_names, self._attr_tf_types, sampled_values):
-      v.set_shape([self._num_sample])
+      v.set_shape([None])
       result_dict[k] = v
     return result_dict
 
@@ -508,7 +508,7 @@ def get(self, src_ids, dst_ids):
                                 self._attr_tf_types)
     result_dict = {}
     for k, t, v in zip(self._attr_names, self._attr_tf_types, sampled_values):
-      v.set_shape([self._num_sample])
+      v.set_shape([None])
       result_dict[k] = v
     return result_dict
 
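Note: the `set_shape` relaxation matters because the sampler output length is no longer a compile-time constant; `set_eval_num_sample()` can change the sample count between train and eval. A standalone sketch, with a hypothetical sampling function, of why only the rank can be pinned on a `tf.py_func` output:

```python
import numpy as np
import tensorflow as tf

def _sample(n):
  # Hypothetical stand-in for the sampler's _get_impl: the number of
  # returned rows varies, e.g. train and eval use different sample counts.
  return np.arange(int(n), dtype=np.int64)

num = tf.placeholder(tf.int32, shape=[])
sampled = tf.py_func(_sample, [num], tf.int64)
# set_shape([self._num_sample]) would pin a count that eval may violate;
# only the rank is statically known, hence set_shape([None]).
sampled.set_shape([None])
```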
diff --git a/easy_rec/python/input/input.py b/easy_rec/python/input/input.py
index 966ec6cf5..536422a5a 100644
--- a/easy_rec/python/input/input.py
+++ b/easy_rec/python/input/input.py
@@ -262,100 +262,100 @@ def _get_labels(self, fields):
     ])
 
   def _preprocess(self, field_dict):
-    """Preprocess the feature columns.
-
-    preprocess some feature columns, such as TagFeature or LookupFeature,
-    it is expected to handle batch inputs and single input,
-    it could be customized in subclasses
+    """Preprocess the feature columns with negative sampling."""
+    parsed_dict = {}
+    neg_samples = self._maybe_negative_sample(field_dict)
+    if neg_samples:
+      for k, v in neg_samples.items():
+        if k in field_dict:
+          field_dict[k] = tf.concat([field_dict[k], v], axis=0)
+        else:
+          print('appended fields: %s' % k)
+          parsed_dict[k] = v
+          self._appended_fields.append(k)
+    for k, v in self._preprocess_without_negative_sample(field_dict).items():
+      parsed_dict[k] = v
+    return parsed_dict
 
-    Args:
-      field_dict: string to tensor, tensors are dense,
-          could be of shape [batch_size], [batch_size, None], or of shape []
+  def _maybe_negative_sample(self, field_dict):
+    """Negative sampling.
 
     Returns:
-      output_dict: some of the tensors are transformed into sparse tensors,
-          such as input tensors of tag features and lookup features
+      output_dict: if negative sampling is enabled, a dict of the sampled
+          fields is returned; otherwise None is returned.
     """
-    parsed_dict = {}
-
     if self._sampler is not None and self._mode != tf.estimator.ModeKeys.PREDICT:
       if self._mode != tf.estimator.ModeKeys.TRAIN:
         self._sampler.set_eval_num_sample()
       sampler_type = self._data_config.WhichOneof('sampler')
       sampler_config = getattr(self._data_config, sampler_type)
-      item_ids = field_dict[sampler_config.item_id_field]
+      item_ids = self._maybe_squeeze_input(
+          field_dict[sampler_config.item_id_field], name='item_id')
       if sampler_type in ['negative_sampler', 'negative_sampler_in_memory']:
         sampled = self._sampler.get(item_ids)
       elif sampler_type == 'negative_sampler_v2':
-        user_ids = field_dict[sampler_config.user_id_field]
+        user_ids = self._maybe_squeeze_input(
+            field_dict[sampler_config.user_id_field], name='user_id')
         sampled = self._sampler.get(user_ids, item_ids)
       elif sampler_type.startswith('hard_negative_sampler'):
-        user_ids = field_dict[sampler_config.user_id_field]
+        user_ids = self._maybe_squeeze_input(
+            field_dict[sampler_config.user_id_field], name='user_id')
         sampled = self._sampler.get(user_ids, item_ids)
       else:
         raise ValueError('Unknown sampler %s' % sampler_type)
-      for k, v in sampled.items():
-        if k in field_dict:
-          field_dict[k] = tf.concat([field_dict[k], v], axis=0)
-        else:
-          print('appended fields: %s' % k)
-          parsed_dict[k] = v
-          self._appended_fields.append(k)
+      return sampled
+    else:
+      return None
+
+  def _preprocess_without_negative_sample(self,
+                                          field_dict,
+                                          ignore_absent_fields=False):
+    """Preprocess the feature columns.
+
+    Preprocess some feature columns, such as TagFeature or LookupFeature.
+    It is expected to handle both batch inputs and single inputs, and may be
+    customized in subclasses.
+
+    Args:
+      field_dict: string to tensor, tensors are dense,
+          could be of shape [batch_size], [batch_size, None], or of shape []
+
+    Returns:
+      output_dict: some of the tensors are transformed into sparse tensors,
+          such as input tensors of tag features and lookup features
+    """
+    parsed_dict = {}
 
     for fc in self._feature_configs:
       feature_name = fc.feature_name
       feature_type = fc.feature_type
+      absent_input_names = []
+      for input_name in fc.input_names:
+        if input_name not in field_dict:
+          absent_input_names.append(input_name)
+      if absent_input_names:
+        if ignore_absent_fields:
+          continue
+        else:
+          raise KeyError('feature [{}] lacks input [{}]'.format(
+              feature_name, ', '.join(absent_input_names)))
       input_0 = fc.input_names[0]
       if feature_type == fc.TagFeature:
         input_0 = fc.input_names[0]
         field = field_dict[input_0]
-        # Construct the output of TagFeature according to the dimension of field_dict.
-        # When the input field exceeds 2 dimensions, convert TagFeature to 2D output.
+
+        if fc.HasField('kv_separator') and len(fc.input_names) > 1:
+          assert False, 'Tag Feature Error, ' \
+                        'Cannot set kv_separator and multi input_names in one feature config. Feature: %s.' % input_0
+
         if len(field.get_shape()) < 2 or field.get_shape()[-1] == 1:
+          # Construct the output of TagFeature according to the dimension of field_dict.
+          # When the input field exceeds 2 dimensions, convert TagFeature to 2D output.
           if len(field.get_shape()) == 0:
             field = tf.expand_dims(field, axis=0)
           elif len(field.get_shape()) == 2:
             field = tf.squeeze(field, axis=-1)
-          if fc.HasField('kv_separator') and len(fc.input_names) > 1:
-            assert False, 'Tag Feature Error, ' \
-                          'Cannot set kv_separator and multi input_names in one feature config. Feature: %s.' % input_0
           parsed_dict[input_0] = tf.string_split(field, fc.separator)
-          if fc.HasField('kv_separator'):
-            indices = parsed_dict[input_0].indices
-            tmp_kvs = parsed_dict[input_0].values
-            tmp_kvs = tf.string_split(
-                tmp_kvs, fc.kv_separator, skip_empty=False)
-            tmp_kvs = tf.reshape(tmp_kvs.values, [-1, 2])
-            tmp_ks, tmp_vs = tmp_kvs[:, 0], tmp_kvs[:, 1]
-
-            check_list = [
-                tf.py_func(
-                    check_string_to_number, [tmp_vs, input_0], Tout=tf.bool)
-            ] if self._check_mode else []
-            with tf.control_dependencies(check_list):
-              tmp_vs = tf.string_to_number(
-                  tmp_vs, tf.float32, name='kv_tag_wgt_str_2_flt_%s' % input_0)
-            parsed_dict[input_0] = tf.sparse.SparseTensor(
-                indices, tmp_ks, parsed_dict[input_0].dense_shape)
-            input_wgt = input_0 + '_WEIGHT'
-            parsed_dict[input_wgt] = tf.sparse.SparseTensor(
-                indices, tmp_vs, parsed_dict[input_0].dense_shape)
-            self._appended_fields.append(input_wgt)
-          if not fc.HasField('hash_bucket_size'):
-            check_list = [
-                tf.py_func(
-                    check_string_to_number,
-                    [parsed_dict[input_0].values, input_0],
-                    Tout=tf.bool)
-            ] if self._check_mode else []
-            with tf.control_dependencies(check_list):
-              vals = tf.string_to_number(
-                  parsed_dict[input_0].values,
-                  tf.int32,
-                  name='tag_fea_%s' % input_0)
-            parsed_dict[input_0] = tf.sparse.SparseTensor(
-                parsed_dict[input_0].indices, vals,
-                parsed_dict[input_0].dense_shape)
           if len(fc.input_names) > 1:
             input_1 = fc.input_names[1]
             field = field_dict[input_1]
@@ -382,11 +382,62 @@ def _preprocess(self, field_dict):
                                              tf.identity(field_vals),
                                              field.dense_shape)
             parsed_dict[input_1] = field
+        elif isinstance(field, tf.SparseTensor):
+          # filter out empty values
+          nonempty_selection = tf.where(tf.not_equal(field.values, ''))[:, 0]
+          parsed_dict[input_0] = tf.sparse.SparseTensor(
+              indices=tf.gather(field.indices, nonempty_selection),
+              values=tf.gather(field.values, nonempty_selection),
+              dense_shape=field.dense_shape)
+          if len(fc.input_names) > 1:
+            input_1 = fc.input_names[1]
+            parsed_dict[input_1] = tf.sparse.SparseTensor(
+                indices=tf.gather(field_dict[input_1].indices,
+                                  nonempty_selection),
+                values=tf.gather(field_dict[input_1].values,
+                                 nonempty_selection),
+                dense_shape=field_dict[input_1].dense_shape)
         else:
-          parsed_dict[input_0] = field_dict[input_0]
+          parsed_dict[input_0] = field
           if len(fc.input_names) > 1:
             input_1 = fc.input_names[1]
             parsed_dict[input_1] = field_dict[input_1]
+
+        if fc.HasField('kv_separator'):
+          indices = parsed_dict[input_0].indices
+          tmp_kvs = parsed_dict[input_0].values
+          # split into keys and values
+          tmp_kvs = tf.string_split(tmp_kvs, fc.kv_separator, skip_empty=False)
+          tmp_kvs = tf.reshape(tmp_kvs.values, [-1, 2])
+          tmp_ks, tmp_vs = tmp_kvs[:, 0], tmp_kvs[:, 1]
+          check_list = [
+              tf.py_func(
+                  check_string_to_number, [tmp_vs, input_0], Tout=tf.bool)
+          ] if self._check_mode else []
+          with tf.control_dependencies(check_list):
+            tmp_vs = tf.string_to_number(
+                tmp_vs, tf.float32, name='kv_tag_wgt_str_2_flt_%s' % input_0)
+          parsed_dict[input_0] = tf.sparse.SparseTensor(
+              indices, tmp_ks, parsed_dict[input_0].dense_shape)
+          input_wgt = input_0 + '_WEIGHT'
+          parsed_dict[input_wgt] = tf.sparse.SparseTensor(
+              indices, tmp_vs, parsed_dict[input_0].dense_shape)
+          self._appended_fields.append(input_wgt)
+        if not fc.HasField('hash_bucket_size'):
+          check_list = [
+              tf.py_func(
+                  check_string_to_number,
+                  [parsed_dict[input_0].values, input_0],
+                  Tout=tf.bool)
+          ] if self._check_mode else []
+          with tf.control_dependencies(check_list):
+            vals = tf.string_to_number(
+                parsed_dict[input_0].values,
+                tf.int32,
+                name='tag_fea_%s' % input_0)
+          parsed_dict[input_0] = tf.sparse.SparseTensor(
+              parsed_dict[input_0].indices, vals,
+              parsed_dict[input_0].dense_shape)
       elif feature_type == fc.LookupFeature:
         assert feature_name is not None and feature_name != ''
         assert len(fc.input_names) == 2
@@ -708,8 +759,12 @@ def _preprocess(self, field_dict):
 
     if self._data_config.HasField('sample_weight'):
       if self._mode != tf.estimator.ModeKeys.PREDICT:
-        parsed_dict[constant.SAMPLE_WEIGHT] = field_dict[
-            self._data_config.sample_weight]
+        if self._data_config.sample_weight in field_dict:
+          parsed_dict[constant.SAMPLE_WEIGHT] = field_dict[
+              self._data_config.sample_weight]
+        elif not ignore_absent_fields:
+          raise KeyError('sample weight field [{}] is absent'.format(
+              self._data_config.sample_weight))
     return parsed_dict
 
   def _lookup_preprocess(self, fc, field_dict):
@@ -829,3 +884,35 @@ def _input_fn(mode=None, params=None, config=None):
 
     _input_fn.input_creator = self
     return _input_fn
+
+  def _maybe_squeeze_input(self, tensor, name=None):
+    default_value = None
+    if isinstance(tensor, tf.SparseTensor):
+      if tensor.dtype == tf.string:
+        default_value = ''
+      elif tensor.dtype.is_integer:
+        default_value = -1
+      else:
+        default_value = tensor.dtype.as_numpy_dtype()
+    with tf.name_scope('squeeze_input/{}'.format(name)):
+      rank = len(tensor.get_shape())
+      if rank != 1:
+        tensor_shape = tf.shape(tensor, out_type=tf.int64)
+        check_list = [
+            tf.assert_equal(
+                tf.reduce_prod(tensor_shape[1:]),
+                tf.constant(1, dtype=tensor_shape.dtype),
+                message='{} must not have multi values'.format(name))
+        ]
+        with tf.control_dependencies(check_list):
+          if isinstance(tensor, tf.SparseTensor):
+            return tf.sparse_to_dense(
+                tensor.indices[:, :1], [tensor_shape[0]],
+                tensor.values,
+                default_value=default_value)
+          else:
+            return tf.reshape(tensor, [tensor_shape[0]])
+      elif isinstance(tensor, tf.SparseTensor):
+        return tf.sparse.to_dense(tensor, default_value=default_value)
+      else:
+        return tensor
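Note: `_maybe_squeeze_input` normalizes an id column to a flat `[batch_size]` vector before sampling, asserting that each row carries exactly one value. A hand-worked sketch of the sparse case it handles (tensor values are hypothetical):

```python
import tensorflow as tf

# Hypothetical [batch_size, 1] item-id column as parsed from the input.
item_id = tf.SparseTensor(
    indices=[[0, 0], [1, 0], [2, 0]],
    values=['i1', 'i2', 'i3'],
    dense_shape=[3, 1])
# Equivalent of the rank != 1 sparse branch: keep only the batch index
# and densify to shape [batch_size], with '' as the string default value.
flat = tf.sparse_to_dense(
    item_id.indices[:, :1], [3], item_id.values, default_value='')
```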
diff --git a/easy_rec/python/input/odps_rtp_input_v2.py b/easy_rec/python/input/odps_rtp_input_v2.py
index c5a0e8079..096e9de10 100644
--- a/easy_rec/python/input/odps_rtp_input_v2.py
+++ b/easy_rec/python/input/odps_rtp_input_v2.py
@@ -2,10 +2,18 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import json
 import logging
+from enum import Enum
 
 import tensorflow as tf
 
 from easy_rec.python.input.odps_rtp_input import OdpsRTPInput
+from easy_rec.python.utils.input_utils import concat_parsed_features
+
+if tf.__version__ >= '2.0':
+  from tensorflow import argsort as tf_argsort
+  tf = tf.compat.v1
+else:
+  from tensorflow.contrib.framework import argsort as tf_argsort
 
 try:
   import pai
@@ -15,6 +23,47 @@
   rtp_fg = None
 
 
+class RtpFeatureType(Enum):
+  RAW_FEATURE = 'raw_feature'
+  ID_FEATURE = 'id_feature'
+  COMBO_FEATURE = 'combo_feature'
+  LOOKUP_FEATURE = 'lookup_feature'
+  MATCH_FEATURE = 'match_feature'
+
+
+class RtpFeatureConfig:
+
+  def __init__(self, fc_dict):
+    self.feature_name = str(fc_dict.get('feature_name'))
+    self.feature_type = RtpFeatureType(fc_dict.get('feature_type'))
+    self.value_dimension = int(fc_dict.get('value_dimension', 0))
+
+
+class RtpSequenceConfig:
+
+  def __init__(self, fc_dict):
+    self.sequence_name = str(fc_dict.get('sequence_name'))
+    self.sequence_length = int(fc_dict.get('sequence_length'))
+    if self.sequence_length <= 0:
+      raise ValueError(
+          'sequence feature [{}] has illegal sequence length [{}]'.format(
+              self.sequence_name, self.sequence_length))
+    self.features = [
+        RtpFeatureConfig(feature_dict)
+        for feature_dict in fc_dict.get('features')
+    ]
+
+
+def parse_rtp_feature_config(fg_config_dict):
+  feature_configs = []
+  for fc_dict in fg_config_dict.get('features'):
+    if fc_dict.get('sequence_name'):
+      feature_configs.append(RtpSequenceConfig(fc_dict))
+    else:
+      feature_configs.append(RtpFeatureConfig(fc_dict))
+  return feature_configs
+
+
 class OdpsRTPInputV2(OdpsRTPInput):
   """RTPInput for parsing rtp fg new input format on odps.
 
@@ -46,6 +95,32 @@ def __init__(self,
       raise ValueError('fg_json_path is not set')
     with tf.gfile.GFile(self._fg_config_path, 'r') as f:
       self._fg_config = json.load(f)
+    self._rtp_features = parse_rtp_feature_config(self._fg_config)
+
+  def _preprocess(self, field_dict):
+    parsed_dict = {}
+    neg_samples = self._maybe_negative_sample(field_dict)
+    neg_parsed_dict = {}
+    if neg_samples:
+      neg_field_dict = {}
+      for k, v in neg_samples.items():
+        if k in field_dict:
+          neg_field_dict[k] = v
+        else:
+          print('appended fields: %s' % k)
+          parsed_dict[k] = v
+          self._appended_fields.append(k)
+      neg_parsed_dict = self._preprocess_without_negative_sample(
+          neg_field_dict, ignore_absent_fields=True)
+    for k, v in self._preprocess_without_negative_sample(field_dict).items():
+      if k in neg_parsed_dict:
+        try:
+          v = concat_parsed_features([v, neg_parsed_dict[k]], name=k)
+        except Exception as e:  # NOQA
+          logging.error('failed to concat parsed features [{}]'.format(k))
+          raise
+      parsed_dict[k] = v
+    return parsed_dict
 
   def _parse_table(self, *fields):
     self.check_rtp()
@@ -56,16 +131,138 @@ def _parse_table(self, *fields):
     # assume that the last field is the generated feature column
     features = rtp_fg.parse_genreated_fg(self._fg_config, fields[-1])
 
-    field_keys = [x for x in self._input_fields if x not in self._label_fields]
-    for feature_key in features:
-      if feature_key not in field_keys or feature_key not in self._effective_fields:
-        del features[feature_key]
-    inputs = {x: features[x] for x in features.keys()}
+    inputs = self._transform_features(features)
 
     for x in range(len(self._label_fields)):
       inputs[self._label_fields[x]] = labels[x]
+
     return inputs
 
+  def _transform_features(self, rtp_features):
+    """Transform features from RTP format into EasyRec format."""
+    features = {}
+    for fc in self._rtp_features:
+      if isinstance(fc, RtpSequenceConfig):
+        for sfc in fc.features:
+          sub_feature_name = '{}__{}'.format(fc.sequence_name, sfc.feature_name)
+          with tf.name_scope(
+              'sequence_feature_transform/{}'.format(sub_feature_name)):
+            shape_0_list = []
+            shape_2_list = []
+            indices_0_list = []
+            indices_1_list = []
+            indices_2_list = []
+            values_list = []
+            if sfc.feature_type == RtpFeatureType.ID_FEATURE:
+              for i in range(fc.sequence_length):
+                sub_feature_name_rtp = '{}_{}_{}'.format(
+                    fc.sequence_name, i, sfc.feature_name)
+                if sub_feature_name_rtp not in rtp_features:
+                  raise ValueError(
+                      'sequence sub feature [{}] is missing'.format(
+                          sub_feature_name_rtp))
+                sub_feature_tensor = rtp_features[sub_feature_name_rtp]
+                assert isinstance(sub_feature_tensor, tf.SparseTensor), \
+                    'sequence sub feature [{}] must be sparse'.format(
+                        sub_feature_name_rtp)
+                values_list.append(sub_feature_tensor.values)
+                shape_0_list.append(sub_feature_tensor.dense_shape[0])
+                shape_2_list.append(sub_feature_tensor.dense_shape[1])
+                indices_0_item = sub_feature_tensor.indices[:, 0]
+                indices_1_item = tf.tile(
+                    tf.constant([i], dtype=indices_0_item.dtype),
+                    tf.shape(indices_0_item))
+                indices_2_item = sub_feature_tensor.indices[:, 1]
+                indices_0_list.append(indices_0_item)
+                indices_1_list.append(indices_1_item)
+                indices_2_list.append(indices_2_item)
+            elif sfc.feature_type == RtpFeatureType.RAW_FEATURE:
+              for i in range(fc.sequence_length):
+                sub_feature_name_rtp = '{}_{}_{}'.format(
+                    fc.sequence_name, i, sfc.feature_name)
+                if sub_feature_name_rtp not in rtp_features:
+                  raise ValueError(
+                      'sequence sub feature [{}] is missing'.format(
+                          sub_feature_name_rtp))
+                sub_feature_tensor = rtp_features[sub_feature_name_rtp]
+                assert isinstance(sub_feature_tensor, tf.Tensor), \
+                    'sequence sub feature [{}] must be dense'.format(sub_feature_name_rtp)
+                values_list.append(sub_feature_tensor)
+                assert len(sub_feature_tensor.get_shape()) == 2, \
+                    'sequence sub feature [{}] must be 2-dimensional'.format(sub_feature_name_rtp)
+                sub_feature_shape = tf.shape(
+                    sub_feature_tensor, out_type=tf.int64)
+                sub_feature_shape_0 = sub_feature_shape[0]
+                sub_feature_shape_1 = sub_feature_shape[1]
+                shape_0_list.append(sub_feature_shape_0)
+                shape_2_list.append(sub_feature_shape_1)
+                indices_2_item, indices_0_item = tf.meshgrid(
+                    tf.range(0, sub_feature_shape_1),
+                    tf.range(0, sub_feature_shape_0))
+                num_elements = tf.reduce_prod(sub_feature_shape)
+                indices_0_item = tf.reshape(indices_0_item, [num_elements])
+                indices_1_item = tf.tile(
+                    tf.constant([i], dtype=indices_0_item.dtype),
+                    tf.reshape(num_elements, [1]))
+                indices_2_item = tf.reshape(indices_2_item, [num_elements])
+                indices_0_list.append(indices_0_item)
+                indices_1_list.append(indices_1_item)
+                indices_2_list.append(indices_2_item)
+            else:
+              raise ValueError(
+                  'sequence sub feature [{}] illegal type [{}]'.format(
+                      sub_feature_name, sfc.feature_type))
+            # note that, as the first dimension is batch size, all values in shape_0_list should be the same
+            indices_0 = tf.concat(indices_0_list, axis=0, name='indices_0')
+            shape_0 = tf.reduce_max(shape_0_list, name='shape_0')
+            # the second dimension is the sequence length
+            indices_1 = tf.concat(indices_1_list, axis=0, name='indices_1')
+            shape_1 = tf.maximum(
+                tf.add(tf.reduce_max(indices_1), 1), 0, name='shape_1')
+            # shape_2 is the max number of multi-values of a single feature value
+            indices_2 = tf.concat(indices_2_list, axis=0, name='indices_2')
+            shape_2 = tf.reduce_max(shape_2_list, name='shape_2')
+            # values
+            values = tf.concat(values_list, axis=0, name='values')
+            # sort the values along the first dimension indices
+            sorting = tf_argsort(indices_0, name='argsort_after_concat')
+            is_single_sample = tf.equal(
+                shape_0,
+                tf.constant(1, dtype=shape_0.dtype),
+                name='is_single_sample')
+            indices_0 = tf.cond(
+                is_single_sample,
+                lambda: indices_0,
+                lambda: tf.gather(indices_0, sorting, name='indices_0_sorted'),
+                name='indices_0_optional')
+            indices_1 = tf.cond(
+                is_single_sample,
+                lambda: indices_1,
+                lambda: tf.gather(indices_1, sorting, name='indices_1_sorted'),
+                name='indices_1_optional')
+            indices_2 = tf.cond(
+                is_single_sample,
+                lambda: indices_2,
+                lambda: tf.gather(indices_2, sorting, name='indices_2_sorted'),
+                name='indices_2_optional')
+            values = tf.cond(
+                is_single_sample,
+                lambda: values,
+                lambda: tf.gather(values, sorting, name='values_sorted'),
+                name='values_optional')
+            # construct the 3-dimensional sparse tensor
+            features[sub_feature_name] = tf.SparseTensor(
+                dense_shape=tf.stack([shape_0, shape_1, shape_2],
+                                     axis=0,
+                                     name='shape'),
+                indices=tf.stack([indices_0, indices_1, indices_2],
+                                 axis=1,
+                                 name='indices'),
+                values=values)
+      elif isinstance(fc, RtpFeatureConfig):
+        features[fc.feature_name] = rtp_features[fc.feature_name]
+      else:
+        raise TypeError('illegal feature config type {}'.format(type(fc)))
+    return features
+
   def create_placeholders(self, *args, **kwargs):
     """Create serving placeholders with rtp_fg."""
     self.check_rtp()
@@ -74,6 +271,7 @@ def create_placeholders(self, *args, **kwargs):
     print('[OdpsRTPInputV2] building placeholders.')
     print('[OdpsRTPInputV2] fg_config: {}'.format(self._fg_config))
     features = rtp_fg.parse_genreated_fg(self._fg_config, inputs_placeholder)
+    features = self._transform_features(features)
     print('[OdpsRTPInputV2] built features: {}'.format(features.keys()))
     features = self._preprocess(features)
     print('[OdpsRTPInputV2] processed features: {}'.format(features.keys()))
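Note: `_transform_features` folds the per-step tensors that RTP emits (keyed `<sequence>_<step>_<feature>`) into one 3-D sparse tensor per sub-feature, keyed `<sequence>__<feature>`. A hand-worked sketch of the target layout for a sequence of length 2 (names and values hypothetical):

```python
import tensorflow as tf

# Per-step tensors as RTP would emit them, e.g. 'clicks_0_item_id' and
# 'clicks_1_item_id', each of shape [batch, multival]:
step0 = tf.SparseTensor([[0, 0]], ['a'], [1, 1])
step1 = tf.SparseTensor([[0, 0], [0, 1]], ['b', 'c'], [1, 2])
# The merged tensor produced for 'clicks__item_id': indices become
# (batch, step, value) triples, and dense_shape is
# [batch, sequence_length, max multival across steps].
merged = tf.SparseTensor(
    indices=[[0, 0, 0], [0, 1, 0], [0, 1, 1]],
    values=['a', 'b', 'c'],
    dense_shape=[1, 2, 2])
```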
diff --git a/easy_rec/python/model/dssm.py b/easy_rec/python/model/dssm.py
index ce0592e89..d40ae9b6b 100644
--- a/easy_rec/python/model/dssm.py
+++ b/easy_rec/python/model/dssm.py
@@ -122,20 +122,17 @@ def build_output_dict(self):
 
   def build_rtp_output_dict(self):
     output_dict = super(DSSM, self).build_rtp_output_dict()
-    if 'user_tower_emb' not in self._prediction_dict:
-      raise ValueError(
-          'User tower embedding does not exist. Please checking predict graph.')
-    output_dict['user_embedding_output'] = tf.identity(
-        self._prediction_dict['user_tower_emb'], name='user_embedding_output')
-    if 'item_tower_emb' not in self._prediction_dict:
+    if self._loss_type in (LossType.CLASSIFICATION,
+                           LossType.SOFTMAX_CROSS_ENTROPY):
+      rank_predict_source = 'probs'
+    elif self._loss_type == LossType.L2_LOSS:
+      rank_predict_source = 'y'
+    else:
+      raise ValueError('invalid loss type: %s' % str(self._loss_type))
+    if rank_predict_source not in self._prediction_dict:
       raise ValueError(
-          'Item tower embedding does not exist. Please checking predict graph.')
-    output_dict['item_embedding_output'] = tf.identity(
-        self._prediction_dict['item_tower_emb'], name='item_embedding_output')
-    if self._loss_type == LossType.CLASSIFICATION:
-      if 'probs' not in self._prediction_dict:
-        raise ValueError(
-            'Probs output does not exist. Please checking predict graph.')
-      output_dict['rank_predict'] = tf.identity(
-          self._prediction_dict['probs'], name='rank_predict')
+          'Rank prediction source node [{}] does not exist. '
+          'Please check the predict graph.'.format(rank_predict_source))
+    output_dict['rank_predict'] = tf.identity(
+        self._prediction_dict[rank_predict_source], name='rank_predict')
     return output_dict
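Note: the rewritten method derives `rank_predict` from the loss type instead of hard-coding `probs`. The selection rule, in sketch form (assuming the `LossType` enum from `easy_rec.python.protos.loss_pb2`):

```python
from easy_rec.python.protos.loss_pb2 import LossType

# Classification-style losses expose probabilities; L2 regression exposes
# the raw prediction 'y'. Any other loss type raises a ValueError.
RANK_PREDICT_SOURCE = {
    LossType.CLASSIFICATION: 'probs',
    LossType.SOFTMAX_CROSS_ENTROPY: 'probs',
    LossType.L2_LOSS: 'y',
}
```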
diff --git a/easy_rec/python/model/easy_rec_estimator.py b/easy_rec/python/model/easy_rec_estimator.py
index ef8932dd0..13a780871 100644
--- a/easy_rec/python/model/easy_rec_estimator.py
+++ b/easy_rec/python/model/easy_rec_estimator.py
@@ -17,7 +17,6 @@
 from tensorflow.python.platform import gfile
 from tensorflow.python.saved_model import signature_constants
 from tensorflow.python.training import basic_session_run_hooks
-from tensorflow.python.training import saver
 
 from easy_rec.python.builders import optimizer_builder
 from easy_rec.python.compat import optimizers
@@ -643,21 +642,63 @@ def export_checkpoint(self,
                         serving_input_receiver_fn=None,
                         checkpoint_path=None,
                         mode=tf.estimator.ModeKeys.PREDICT):
-    with context.graph_mode():
-      if not checkpoint_path:
-        # Locate the latest checkpoint
-        checkpoint_path = estimator_utils.latest_checkpoint(self._model_dir)
-      if not checkpoint_path:
-        raise ValueError("Couldn't find trained model at %s." % self._model_dir)
-      with ops.Graph().as_default():
+    server_target = None
+    if 'TF_CONFIG' in os.environ:
+      tf_config = estimator_utils.chief_to_master()
+      from tensorflow.python.training import server_lib
+      if tf_config['task']['type'] == 'ps':
+        cluster = tf.train.ClusterSpec(tf_config['cluster'])
+        server = server_lib.Server(
+            cluster, job_name='ps', task_index=tf_config['task']['index'])
+        server.join()
+      elif tf_config['task']['type'] == 'master':
+        if 'ps' in tf_config['cluster']:
+          cluster = tf.train.ClusterSpec(tf_config['cluster'])
+          server = server_lib.Server(cluster, job_name='master', task_index=0)
+          server_target = server.target
+          print('server_target = %s' % server_target)
+
+    if not checkpoint_path:
+      # Locate the latest checkpoint
+      checkpoint_path = estimator_utils.latest_checkpoint(self._model_dir)
+    if not checkpoint_path:
+      raise ValueError("Couldn't find trained model at %s." % self._model_dir)
+
+    if server_target:
+      from tensorflow.python.training.device_setter import replica_device_setter
+      from tensorflow.python.framework.ops import device
+      from tensorflow.python.training.monitored_session import MonitoredSession
+      from tensorflow.python.training.monitored_session import ChiefSessionCreator
+      with device(
+          replica_device_setter(
+              worker_device='/job:master/task:0', cluster=cluster)):
         input_receiver = serving_input_receiver_fn()
         estimator_spec = self._call_model_fn(
             features=input_receiver.features,
             labels=getattr(input_receiver, 'labels', None),
             mode=mode,
             config=self.config)
-        with tf_session.Session(config=self._session_config) as session:
-          graph_saver = estimator_spec.scaffold.saver or saver.Saver(
-              sharded=True)
-          graph_saver.restore(session, checkpoint_path)
-          graph_saver.save(session, export_path)
+      graph_saver = tf.train.Saver(sharded=True)
+      chief_sess_creator = ChiefSessionCreator(
+          master=server_target,
+          scaffold=tf.train.Scaffold(saver=graph_saver),
+          checkpoint_filename_with_path=checkpoint_path)
+      with MonitoredSession(
+          session_creator=chief_sess_creator,
+          hooks=None,
+          stop_grace_period_secs=120) as sess:
+        graph_saver.save(sess._tf_sess(), export_path)
+    else:
+      with context.graph_mode():
+        with ops.Graph().as_default():
+          input_receiver = serving_input_receiver_fn()
+          estimator_spec = self._call_model_fn(
+              features=input_receiver.features,
+              labels=getattr(input_receiver, 'labels', None),
+              mode=mode,
+              config=self.config)
+          with tf_session.Session(config=self._session_config) as session:
+            graph_saver = estimator_spec.scaffold.saver or tf.train.Saver(
+                sharded=True)
+            graph_saver.restore(session, checkpoint_path)
+            graph_saver.save(session, export_path)
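Note: `export_checkpoint` now supports distributed export. When `TF_CONFIG` describes a cluster with a `ps` job, ps tasks block in `server.join()` while the master builds the graph under `replica_device_setter` and saves through a `MonitoredSession` pointed at `server.target`. A hypothetical two-node layout (addresses are placeholders) that takes the distributed path:

```python
import json
import os

# With a 'ps' job present and task type 'master', server_target is set and
# the MonitoredSession branch is used; without 'ps', the local
# Graph/Session branch runs as before.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {'ps': ['ps0:2222'], 'master': ['master0:2222']},
    'task': {'type': 'master', 'index': 0},
})
```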
diff --git a/easy_rec/python/model/match_model.py b/easy_rec/python/model/match_model.py
index 475ae6def..a41930c41 100644
--- a/easy_rec/python/model/match_model.py
+++ b/easy_rec/python/model/match_model.py
@@ -153,12 +153,12 @@ def _build_list_wise_loss_graph(self):
           tf.log(hit_prob + 1e-12) * tf.squeeze(self._sample_weight))
       logging.info('softmax cross entropy loss is used')
 
-      user_features = self._prediction_dict['user_tower_emb']
-      pos_item_emb = self._prediction_dict['item_tower_emb'][:batch_size]
-      pos_simi = tf.reduce_sum(user_features * pos_item_emb, axis=1)
-      # if pos_simi < 0, produce loss
-      reg_pos_loss = tf.nn.relu(-pos_simi)
-      self._loss_dict['reg_pos_loss'] = tf.reduce_mean(reg_pos_loss)
+      # user_features = self._prediction_dict['user_tower_emb']
+      # pos_item_emb = self._prediction_dict['item_tower_emb'][:batch_size]
+      # pos_simi = tf.reduce_sum(user_features * pos_item_emb, axis=1)
+      # # if pos_simi < 0, produce loss
+      # reg_pos_loss = tf.nn.relu(-pos_simi)
+      # self._loss_dict['reg_pos_loss'] = tf.reduce_mean(reg_pos_loss)
     else:
       raise ValueError('invalid loss type: %s' % str(self._loss_type))
     return self._loss_dict
diff --git a/easy_rec/python/utils/input_utils.py b/easy_rec/python/utils/input_utils.py
index d42127bc3..86ff73cd4 100644
--- a/easy_rec/python/utils/input_utils.py
+++ b/easy_rec/python/utils/input_utils.py
@@ -72,3 +72,123 @@ def string_to_number(field, ftype, default_value, name=''):
   else:
     assert False, 'invalid types: %s' % str(ftype)
   return tmp_field
+
+
+def _calculate_concat_shape(shapes):
+  for shape in shapes:
+    assert len(shape.get_shape()) == 1
+  shapes_stack = tf.stack(shapes, axis=0)
+  batch_size = tf.reduce_sum(shapes_stack[:, :1], axis=0)
+  other_sizes = tf.reduce_max(shapes_stack[:, 1:], axis=0)
+  return tf.cond(
+      tf.equal(
+          tf.shape(other_sizes, out_type=tf.int32)[0],
+          tf.constant(0, dtype=tf.int32)),
+      lambda: batch_size,
+      lambda: tf.concat([batch_size, other_sizes], axis=0))
+
+
+def _accumulate_concat_indices(indices_list, shape_list):
+  with tf.name_scope('accumulate_concat_indices'):
+    assert len(indices_list) != 0
+    indices_shape = indices_list[0].get_shape()
+    assert len(indices_shape) == 2
+    rank = indices_shape[1].value
+    assert rank is not None and rank > 0
+    indices_0_list = [indices_list[0][:, :1]]
+    offset = shape_list[0][0]
+    for i in range(1, len(indices_list)):
+      indices_0_list.append(tf.add(indices_list[i][:, :1], offset))
+      if i == len(indices_list) - 1:
+        break
+      offset = tf.add(offset, shape_list[i][0])
+    if rank == 1:
+      return indices_0_list
+    else:
+      return [
+          tf.concat([indices_0, indices[:, 1:]], axis=1)
+          for indices_0, indices in zip(indices_0_list, indices_list)
+      ]
+
+
+def _dense_to_sparse(dense_tensor):
+  with tf.name_scope('dense_to_sparse'):
+    shape = tf.shape(dense_tensor, out_type=tf.int64, name='sparse_shape')
+    nelems = tf.size(dense_tensor, out_type=tf.int64, name='num_elements')
+    indices = tf.transpose(
+        tf.unravel_index(tf.range(nelems, dtype=tf.int64), shape),
+        name='sparse_indices')
+    values = tf.reshape(dense_tensor, [nelems], name='sparse_values')
+    return tf.SparseTensor(indices, values, shape)
+
+
+def _concat_parsed_features_impl(features, name):
+  is_sparse = False
+  for feature in features:
+    if isinstance(feature, tf.SparseTensor):
+      is_sparse = True
+      break
+  feature_ranks = [len(feature.get_shape()) for feature in features]
+  max_rank = max(feature_ranks)
+  if is_sparse:
+    concat_indices = []
+    concat_values = []
+    concat_shapes = []
+    for i in range(len(features)):
+      with tf.name_scope('sparse_preprocess_{}'.format(i)):
+        feature = features[i]
+        if isinstance(feature, tf.Tensor):
+          feature = _dense_to_sparse(feature)
+        feature_rank = feature_ranks[i]
+        if feature_rank < max_rank:
+          # expand dimensions
+          feature = tf.SparseTensor(
+              tf.pad(feature.indices,
+                     [[0, 0], [0, max_rank - feature_rank]],
+                     constant_values=0,
+                     name='indices_expanded'),
+              feature.values,
+              tf.pad(feature.dense_shape,
+                     [[0, max_rank - feature_rank]],
+                     constant_values=1,
+                     name='shape_expanded'))
+        concat_indices.append(feature.indices)
+        concat_values.append(feature.values)
+        concat_shapes.append(feature.dense_shape)
+    with tf.name_scope('sparse_indices'):
+      concat_indices = _accumulate_concat_indices(concat_indices, concat_shapes)
+      sparse_indices = tf.concat(concat_indices, axis=0)
+    with tf.name_scope('sparse_values'):
+      sparse_values = tf.concat(concat_values, axis=0)
+    with tf.name_scope('sparse_shape'):
+      sparse_shape = _calculate_concat_shape(concat_shapes)
+    return tf.SparseTensor(sparse_indices, sparse_values, sparse_shape)
+  else:
+    # expand dimensions
+    for i in range(len(features)):
+      with tf.name_scope('dense_preprocess_{}'.format(i)):
+        feature_rank = feature_ranks[i]
+        if feature_rank < max_rank:
+          new_shape = tf.pad(tf.shape(features[i]),
+                             [[0, max_rank - feature_rank]],
+                             constant_values=1,
+                             name='shape_expanded')
+          features[i] = tf.reshape(
+              features[i], new_shape, name='dense_expanded')
+    # assumes that dense tensors are of the same shape
+    return tf.concat(features, axis=0, name='dense_concat')
+
+
+def concat_parsed_features(features, name=None):
+  if name:
+    with tf.name_scope('concat_parsed_features__{}'.format(name)):
+      return _concat_parsed_features_impl(features, name)
+  else:
+    return _concat_parsed_features_impl(features, '<unknown>')
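Note: `concat_parsed_features` is what `OdpsRTPInputV2._preprocess` uses to splice the negative-sample batch onto the positive batch; it accepts mixed dense/sparse inputs and mismatched ranks. A usage sketch with hypothetical values:

```python
import tensorflow as tf

from easy_rec.python.utils.input_utils import concat_parsed_features

# A dense rank-1 positive batch and a sparse rank-2 negative batch: the
# dense tensor is converted to sparse and padded to rank 2, batch sizes
# are summed and the remaining dimensions take the elementwise max, so
# the result is a SparseTensor of dense_shape [5, 2].
pos = tf.constant([1.0, 2.0])
neg = tf.SparseTensor([[0, 0], [2, 1]], [3.0, 4.0], [3, 2])
merged = concat_parsed_features([pos, neg], name='my_feature')
```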
diff --git a/easy_rec/version.py b/easy_rec/version.py
index 6cf8729cb..66aa9f0e3 100644
--- a/easy_rec/version.py
+++ b/easy_rec/version.py
@@ -1,3 +1,3 @@
 # -*- encoding:utf-8 -*-
 # Copyright (c) Alibaba, Inc. and its affiliates.
-__version__ = '0.5.6'
+__version__ = '0.5.7'
diff --git a/setup.cfg b/setup.cfg
index b180b9fb1..b5b966faa 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -10,7 +10,7 @@ multi_line_output = 7
 force_single_line = true
 known_standard_library = setuptools
 known_first_party = easy_rec
-known_third_party = absl,common_io,docutils,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml
+known_third_party = absl,common_io,distutils,docutils,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml
 no_lines_before = LOCALFOLDER
 default_section = THIRDPARTY
 skip = easy_rec/python/protos