Skip to content

Commit

Permalink
add custom op
Browse files Browse the repository at this point in the history
  • Loading branch information
yangxudong committed Jul 26, 2024
1 parent ffda1da commit f2f5b74
Show file tree
Hide file tree
Showing 16 changed files with 217 additions and 42 deletions.
2 changes: 2 additions & 0 deletions docs/source/feature/pai_rec_callback_conf.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# PAI-REC 全埋点配置

## PAI-Rec引擎的callback服务文档

- [文档](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pairec/docs/pairec/html/intro/callback_api.html)

## 模板
Expand Down
1 change: 1 addition & 0 deletions docs/source/intro.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,6 @@ EasyRec implements state of the art machine learning models used in common recom
- Run [`knn algorithm`](vector_retrieve.md) of vectors in distribute environment

### Contact

- DingDing Group: 32260796. (EasyRec usage general discussion.)
- DingDing Group: 37930014162, click [this url](https://qr.dingtalk.com/action/joingroup?code=v1,k1,oHNqtNObbu+xUClHh77gCuKdGGH8AYoQ8AjKU23zTg4=&_dt_no_comment=1&origin=11) or scan QrCode to join![new_group.jpg](../images/qrcode/new_group.jpg)
7 changes: 6 additions & 1 deletion easy_rec/python/inference/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,19 @@
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import signature_constants

import easy_rec
from easy_rec.python.utils import numpy_utils
from easy_rec.python.utils.config_util import get_configs_from_pipeline_file
from easy_rec.python.utils.config_util import get_input_name_from_fg_json
from easy_rec.python.utils.config_util import search_fg_json
from easy_rec.python.utils.input_utils import get_type_defaults
from easy_rec.python.utils.load_class import get_register_class_meta

try:
tf.load_op_library(os.path.join(easy_rec.ops_dir, 'libcustom_ops.so'))
except Exception as ex:
logging.warning('exception: %s' % str(ex))

if tf.__version__ >= '2.0':
tf = tf.compat.v1

Expand Down
13 changes: 9 additions & 4 deletions easy_rec/python/layers/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, config, features, input_layer, l2_reg=None):
self._dag.add_node(block.name)
self._name_to_blocks[block.name] = block
layer = block.WhichOneof('layer')
if layer == 'input_layer':
if layer in {'input_layer', 'raw_input'}:
if len(block.inputs) != 1:
raise ValueError('input layer `%s` takes only one input' % block.name)
one_input = block.inputs[0]
Expand All @@ -71,8 +71,11 @@ def __init__(self, config, features, input_layer, l2_reg=None):
if group in input_feature_groups:
logging.warning('input `%s` already exists in other block' % group)
else:
input_fn = EnhancedInputLayer(self._input_layer, self._features,
group, reuse)
if layer == 'input_layer':
input_fn = EnhancedInputLayer(self._input_layer, self._features,
group, reuse)
else:
input_fn = self._input_layer.get_raw_features(self._features, group)
input_feature_groups[group] = input_fn
self._name_to_layer[block.name] = input_fn
else:
Expand All @@ -91,7 +94,7 @@ def __init__(self, config, features, input_layer, l2_reg=None):
num_pkg_input = 0
for block in config.blocks:
layer = block.WhichOneof('layer')
if layer == 'input_layer':
if layer in {'input_layer', 'raw_input'}:
continue
name = block.name
if name in input_feature_groups:
Expand Down Expand Up @@ -270,6 +273,8 @@ def call(self, is_training, **kwargs):
if layer is None: # identity layer
output = self.block_input(config, block_outputs, is_training, **kwargs)
block_outputs[block] = output
elif layer == 'raw_input':
block_outputs[block] = self._name_to_layer[block]
elif layer == 'input_layer':
input_fn = self._name_to_layer[block]
input_config = config.input_layer
Expand Down
3 changes: 3 additions & 0 deletions easy_rec/python/layers/keras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from .blocks import TextCNN
from .bst import BST
from .custom_ops import EditDistance
from .custom_ops import MappedDotProduct
from .custom_ops import OverlapFeature
from .custom_ops import SeqAugmentOps
from .custom_ops import TextNormalize
from .data_augment import SeqAugment
from .din import DIN
from .fibinet import BiLinear
Expand Down
199 changes: 166 additions & 33 deletions easy_rec/python/layers/keras/custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""Convenience blocks for using custom ops."""
import logging
import os

import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.keras.layers import Layer
Expand All @@ -23,8 +24,19 @@
else:
ops_dir = None

if tf.__version__ >= '2.0':
tf = tf.compat.v1
logging.info('ops_dir is %s' % ops_dir)
custom_op_path = os.path.join(ops_dir, 'libcustom_ops.so')
try:
custom_ops = tf.load_op_library(custom_op_path)
logging.info('load custom op from %s succeed' % custom_op_path)
except Exception as ex:
logging.warning('load custom op from %s failed: %s' %
(custom_op_path, str(ex)))
custom_ops = None


# if tf.__version__ >= '2.0':
# tf = tf.compat.v1


class SeqAugmentOps(Layer):
Expand All @@ -34,26 +46,16 @@ def __init__(self, params, name='sequence_aug', reuse=None, **kwargs):
super(SeqAugmentOps, self).__init__(name, **kwargs)
self.reuse = reuse
self.seq_aug_params = params.get_pb_config()
logging.info("ops_dir is %s" % ops_dir)
custom_op_path = os.path.join(ops_dir, 'libcustom_ops.so')
try:
custom_ops = tf.load_op_library(custom_op_path)
logging.info('load edit_distance op from %s succeed' % custom_op_path)
except Exception as ex:
logging.warning('load edit_distance op from %s failed: %s' %
(custom_op_path, str(ex)))
custom_ops = None
self.seq_augment = custom_ops.my_seq_augment

def build(self, input_shape):
assert len(input_shape) >= 2, 'MaskBlock must has at least two inputs'
assert len(input_shape) >= 2, 'SeqAugmentOps must has at least two inputs'
embed_dim = int(input_shape[0][-1])
self.mask_emb = self.add_weight(
shape=(embed_dim,),
initializer="glorot_uniform",
trainable=True,
name="mask"
)
shape=(embed_dim,),
initializer='glorot_uniform',
trainable=True,
name='mask')

def call(self, inputs, training=None, **kwargs):
assert isinstance(inputs, (list, tuple))
Expand All @@ -66,21 +68,152 @@ def call(self, inputs, training=None, **kwargs):
return x


class EditDistance(tf.keras.layers.Layer):
class TextNormalize(Layer):

def __init__(self, params, name='text_normalize', reuse=None, **kwargs):
super(TextNormalize, self).__init__(name, **kwargs)
self.txt_normalizer = custom_ops.text_normalize_op
self.norm_parameter = params.get_or_default('norm_parameter', 0)
self.remove_space = params.get_or_default('remove_space', False)

def call(self, inputs, training=None, **kwargs):
inputs = inputs if type(inputs) in (tuple, list) else [inputs]
result = [
self.txt_normalizer(
txt, parameter=self.norm_parameter, remove_space=self.remove_space)
for txt in inputs
]
if len(result) == 1:
return result[0]
return result


class MappedDotProduct(Layer):

def __init__(self, params, name='mapped_dot_product', reuse=None, **kwargs):
super(MappedDotProduct, self).__init__(name, **kwargs)
self.mapped_dot_product = custom_ops.mapped_dot_product
self.bucketize = custom_ops.my_bucketize
self.default_value = params.get_or_default('default_value', 0)
self.separator = params.get_or_default('separator', '\035')
self.norm_fn = params.get_or_default('normalize_fn', None)
self.boundaries = list(params.get_or_default('boundaries', []))
self.emb_dim = params.get_or_default('embedding_dim', 0)
self.print_first_n = params.get_or_default('print_first_n', 0)
self.summarize = params.get_or_default('summarize', None)
if self.emb_dim > 0:
vocab_size = len(self.boundaries) + 1
with tf.variable_scope(self.name, reuse=reuse):
self.embedding_table = tf.get_variable(
name='dot_product_emb_table',
shape=[vocab_size, self.emb_dim],
dtype=tf.float32)

def call(self, inputs, training=None, **kwargs):
query, doc = inputs[:2]
with ops.device('/CPU:0'):
feature = self.mapped_dot_product(
query=query,
document=doc,
feature_name=self.name,
separator=self.separator,
default_value=self.default_value)
tf.summary.scalar(self.name, tf.reduce_mean(feature))
if self.print_first_n:
encode_q = tf.regex_replace(query, self.separator, ' ')
encode_t = tf.regex_replace(query, self.separator, ' ')
feature = tf.Print(
feature, [encode_q, encode_t, feature],
message=self.name,
first_n=self.print_first_n,
summarize=self.summarize)
if self.norm_fn is not None:
fn = eval(self.norm_fn)
feature = fn(feature)
tf.summary.scalar('normalized_%s' % self.name, tf.reduce_mean(feature))
if self.print_first_n:
feature = tf.Print(
feature, [feature],
message='normalized %s' % self.name,
first_n=self.print_first_n,
summarize=self.summarize)
if self.boundaries:
feature = self.bucketize(feature, boundaries=self.boundaries)
tf.summary.histogram('bucketized_%s' % self.name, feature)
if self.emb_dim > 0 and self.boundaries:
vocab_size = len(self.boundaries) + 1
one_hot_input_ids = tf.one_hot(feature, depth=vocab_size)
return tf.matmul(one_hot_input_ids, self.embedding_table)
return tf.expand_dims(feature, axis=-1)


class OverlapFeature(Layer):

def __init__(self, params, name='overlap_feature', reuse=None, **kwargs):
super(OverlapFeature, self).__init__(name, **kwargs)
self.overlap_feature = custom_ops.overlap_fg_op
self.bucketize = custom_ops.my_bucketize
self.method = params.get_or_default('method', 'is_contain')
self.norm_fn = params.get_or_default('normalize_fn', None)
self.boundaries = list(params.get_or_default('boundaries', []))
self.separator = params.get_or_default('separator', '\035')
self.default_value = params.get_or_default('default_value', '-1')
self.emb_dim = params.get_or_default('embedding_dim', 0)
self.print_first_n = params.get_or_default('print_first_n', 0)
self.summarize = params.get_or_default('summarize', None)
if self.emb_dim > 0:
vocab_size = len(self.boundaries) + 1
with tf.variable_scope(self.name, reuse=reuse):
self.embedding_table = tf.get_variable(
name='overlap_emb_table',
shape=[vocab_size, self.emb_dim],
dtype=tf.float32)

def call(self, inputs, training=None, **kwargs):
query, title = inputs[:2]
fea_name = '%s_%s' % (self.name, self.method)
with ops.device('/CPU:0'):
feature = self.overlap_feature(
query=query,
title=title,
feature_name=fea_name,
separator=self.separator,
default_value=self.default_value,
method=self.method)
tf.summary.scalar(fea_name, tf.reduce_mean(feature))
if self.print_first_n:
encode_q = tf.regex_replace(query, self.separator, ' ')
encode_t = tf.regex_replace(query, self.separator, ' ')
feature = tf.Print(
feature, [encode_q, encode_t, feature],
message='%s %s' % (self.name, self.method),
first_n=self.print_first_n,
summarize=self.summarize)
if self.norm_fn is not None:
fn = eval(self.norm_fn)
feature = fn(feature)
tf.summary.scalar('normalized_' + fea_name, tf.reduce_mean(feature))
if self.print_first_n:
feature = tf.Print(
feature, [feature],
message='normalized_%s' % fea_name,
first_n=self.print_first_n,
summarize=self.summarize)
if self.boundaries:
feature = self.bucketize(feature, boundaries=self.boundaries)
tf.summary.histogram('bucketized_%s' % fea_name, feature)
if self.emb_dim > 0 and self.boundaries:
vocab_size = len(self.boundaries) + 1
one_hot_input_ids = tf.one_hot(feature, depth=vocab_size)
return tf.matmul(one_hot_input_ids, self.embedding_table)
return tf.expand_dims(feature, axis=-1)


class EditDistance(Layer):

def __init__(self, params, name='edit_distance', reuse=None, **kwargs):
super(EditDistance, self).__init__(name, **kwargs)
logging.info("ops_dir is %s" % ops_dir)
custom_op_path = os.path.join(ops_dir, 'libedit_distance.so')
try:
custom_ops = tf.load_op_library(custom_op_path)
logging.info('load edit_distance op from %s succeed' % custom_op_path)
except Exception as ex:
logging.warning('load edit_distance op from %s failed: %s' %
(custom_op_path, str(ex)))
custom_ops = None
self.edit_distance = custom_ops.my_edit_distance

self.txt_encoding = params.get_or_default('text_encoding', 'utf-8')
self.emb_size = params.get_or_default('embedding_size', 512)
emb_dim = params.get_or_default('embedding_dim', 4)
Expand All @@ -93,11 +226,11 @@ def call(self, inputs, training=None, **kwargs):
input1, input2 = inputs[:2]
with ops.device('/CPU:0'):
dist = self.edit_distance(
input1,
input2,
normalize=False,
dtype=tf.int32,
encoding=self.txt_encoding)
input1,
input2,
normalize=False,
dtype=tf.int32,
encoding=self.txt_encoding)
ids = tf.clip_by_value(dist, 0, self.emb_size - 1)
embed = tf.nn.embedding_lookup(self.embedding_table, ids)
return embed
2 changes: 1 addition & 1 deletion easy_rec/python/layers/keras/mask_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, params, name='mask_block', reuse=None, **kwargs):

def build(self, input_shape):
assert type(input_shape) in (tuple, list) and len(input_shape) >= 2,\
'MaskBlock must has at least two inputs'
'MaskBlock must has at least two inputs'
input_dim = int(input_shape[0][-1])
mask_input_dim = int(input_shape[1][-1])
if self.config.HasField('reduction_factor'):
Expand Down
5 changes: 4 additions & 1 deletion easy_rec/python/model/rank_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def _output_to_prediction_impl(self,
else:
probs = tf.nn.softmax(output, axis=1)
prediction_dict['logits' + suffix] = output
prediction_dict['logits' + suffix + "_1"] = output[:, 1]
prediction_dict['probs' + suffix] = probs
prediction_dict['probs' + suffix + "_1"] = probs[:, 1]
prediction_dict['logits' + suffix + '_y'] = math_ops.reduce_max(
output, axis=1)
prediction_dict['probs' + suffix + '_y'] = math_ops.reduce_max(
Expand Down Expand Up @@ -416,7 +418,8 @@ def _get_outputs_impl(self, loss_type, num_class=1, suffix=''):
else:
return [
'y' + suffix, 'probs' + suffix, 'logits' + suffix,
'probs' + suffix + '_y', 'logits' + suffix + '_y'
'probs' + suffix + '_y', 'logits' + suffix + '_y',
'probs' + suffix + '_1', 'logits' + suffix + '_1'
]
elif loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
return ['y' + suffix]
Expand Down
Binary file removed easy_rec/python/ops/1.12/libedit_distance.so
Binary file not shown.
Binary file modified easy_rec/python/ops/1.12_pai/libcustom_ops.so
Binary file not shown.
Binary file removed easy_rec/python/ops/1.12_pai/libedit_distance.so
Binary file not shown.
Binary file removed easy_rec/python/ops/1.15/libedit_distance.so
Binary file not shown.
2 changes: 2 additions & 0 deletions easy_rec/python/protos/keras_layer.proto
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,7 @@ message KerasLayer {
PPNet ppnet = 16;
TextCNN text_cnn = 17;
HighWayTower highway = 18;
OverlapFeature overlap = 19;
MappedDotProduct dot_product = 20;
}
}
21 changes: 21 additions & 0 deletions easy_rec/python/protos/layer.proto
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,24 @@ message TextCNN {
optional string activation = 4 [default = 'relu'];
optional MLP mlp = 5;
}

message OverlapFeature {
optional string separator = 1;
optional string default_value = 2;
required string method = 3 [default = 'is_contain'];
optional string normalize_fn = 4;
repeated float boundaries = 5;
optional int32 embedding_dim = 6;
optional int32 print_first_n = 7 [default = 0];
optional int32 summarize = 8;
}

message MappedDotProduct {
optional string separator = 1;
optional float default_value = 2;
optional string normalize_fn = 3;
repeated float boundaries = 4;
optional int32 embedding_dim = 5;
optional int32 print_first_n = 6 [default = 0];
optional int32 summarize = 7;
}
2 changes: 1 addition & 1 deletion easy_rec/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

__version__ = '0.8.1'
__version__ = '0.8.2'
Loading

0 comments on commit f2f5b74

Please sign in to comment.