Commit 0d83b0b
add test case
yangxudong committed Dec 27, 2023
1 parent 39d3c05 commit 0d83b0b
Showing 3 changed files with 24 additions and 16 deletions.
docs/source/intro.md (4 changes: 3 additions & 1 deletion)
@@ -61,4 +61,6 @@ EasyRec implements state of the art machine learning models used in common recom
 
 - Run [`knn algorithm`](vector_retrieve.md) of vectors in distribute environment
 
-Welcome to join the EasyRec recommendation algorithm discussion group, DingTalk group number: 32260796
+### Contact
+
+- DingDing Group: 37930014162, click [this url](https://qr.dingtalk.com/action/joingroup?code=v1,k1,oHNqtNObbu+xUClHh77gCuKdGGH8AYoQ8AjKU23zTg4=&_dt_no_comment=1&origin=11) or scan QrCode to join![new_group.jpg](../images/qrcode/new_group.jpg)
easy_rec/python/layers/keras/fibinet.py (23 changes: 13 additions & 10 deletions)
@@ -36,7 +36,8 @@ def __init__(self, params, name='SENet', reuse=None, **kwargs):
     if tf.__version__ >= '2.0':
       self.layer_norm = tf.keras.layers.LayerNormalization()
     else:
-      self.layer_norm = lambda x: layer_norm(
+      with tf.name_scope(self.name):
+        self.layer_norm = lambda x: layer_norm(
           x, name='ln_output', reuse=self.reuse)
 
   def build(self, input_shape):
@@ -52,13 +53,14 @@ def build(self, input_shape):
     r = self.config.reduction_ratio
     field_size = len(input_shape)
     reduction_size = max(1, field_size * g * 2 // r)
+    name_scope = '' if tf.__version__ >= '2.0' else self.name + "/"
     self.reduce_layer = Dense(
         units=reduction_size,
         activation='relu',
         kernel_initializer='he_normal',
-        name='W1')
+        name=name_scope + 'W1')
     self.excite_layer = Dense(
-        units=emb_size, kernel_initializer='glorot_normal', name='W2')
+        units=emb_size, kernel_initializer='glorot_normal', name=name_scope + 'W2')
 
   def call(self, inputs, **kwargs):
     g = self.config.num_squeeze_group
@@ -137,7 +139,7 @@ def __init__(self, params, name='bilinear', reuse=None, **kwargs):
       self.func = _full_interaction
     else:
       self.func = tf.multiply
-    self.output_layer = Dense(self.output_size, name='output')
+    self.output_layer = Dense(self.output_size, name=self.name + '/output')
 
   def build(self, input_shape):
     if type(input_shape) not in (tuple, list):
@@ -158,16 +160,17 @@
     )
     dim = int(_dim)
 
+    name_scope = '' if tf.__version__ >= '2.0' else self.name + "/"
     if self.bilinear_type == 'all':
-      self.dot_layer = Dense(dim, name='all')
+      self.dot_layer = Dense(dim, name=name_scope + 'all')
     elif self.bilinear_type == 'each':
       self.dot_layers = [
-          Dense(dim, name='each_%d' % i) for i in range(field_num - 1)
+          Dense(dim, name=name_scope + 'each_%d' % i) for i in range(field_num - 1)
       ]
     else:  # interaction
       self.dot_layers = [
           Dense(
-              units=int(input_shape[j][-1]), name='interaction_%d_%d' % (i, j))
+              units=int(input_shape[j][-1]), name=name_scope+'interaction_%d_%d' % (i, j))
           for i, j in itertools.combinations(range(field_num), 2)
       ]

@@ -216,17 +219,17 @@ def __init__(self, params, name='fibinet', reuse=None, **kwargs):
     self._config = params.get_pb_config()
 
     se_params = Parameter.make_from_pb(self._config.senet)
-    self.senet_layer = SENet(se_params, name='senet', reuse=self.reuse)
+    self.senet_layer = SENet(se_params, name=self.name+'/senet', reuse=self.reuse)
 
     if self._config.HasField('bilinear'):
       bi_params = Parameter.make_from_pb(self._config.bilinear)
       self.bilinear_layer = BiLinear(
-          bi_params, name='bilinear', reuse=self.reuse)
+          bi_params, name=self.name+'/bilinear', reuse=self.reuse)
 
     if self._config.HasField('mlp'):
       p = Parameter.make_from_pb(self._config.mlp)
       p.l2_regularizer = params.l2_regularizer
-      self.final_mlp = MLP(p, name=name, reuse=reuse)
+      self.final_mlp = MLP(p, name=self.name+'/mlp', reuse=reuse)
     else:
       self.final_mlp = None

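Note on the recurring pattern in this file: under TF1 graph mode the Keras sublayers built here do not automatically inherit the enclosing layer's name, so the commit prepends `self.name + '/'` (via a `name_scope` string) to keep variable names scoped and reusable, while under TF2 the prefix is left empty. A minimal standalone sketch of that pattern, with hypothetical class and layer names not taken from EasyRec:

```python
import tensorflow as tf
from tensorflow.keras.layers import Dense


class ScopedBlock(tf.keras.layers.Layer):
  """Toy layer illustrating the version-gated name prefix used above."""

  def build(self, input_shape):
    # Under TF1, Dense variables are named from the `name` argument alone,
    # so without the prefix two ScopedBlock instances would both create
    # 'proj/kernel'; TF2 scopes sublayer names under the parent itself.
    name_scope = '' if tf.__version__ >= '2.0' else self.name + '/'
    self.proj = Dense(8, name=name_scope + 'proj')
    super(ScopedBlock, self).build(input_shape)

  def call(self, inputs):
    return self.proj(inputs)
```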
easy_rec/python/layers/keras/mask_net.py (13 changes: 8 additions & 5 deletions)
@@ -42,37 +42,40 @@ def build(self, input_shape):
       raise ValueError(
           'Need one of reduction factor or aggregation size for MaskBlock.')
 
+    name_scope = '' if tf.__version__ >= '2.0' else self.name + "/"
     self.weight_layer = Dense(
         aggregation_size,
         activation='relu',
         kernel_initializer='he_uniform',
         kernel_regularizer=self.l2_reg,
-        name='weight')
-    self.mask_layer = Dense(input_dim, name='mask')
+        name=name_scope + 'weight')
+    self.mask_layer = Dense(input_dim, name=name_scope+'mask')
     if self._projection_dim is not None:
       self.project_layer = Dense(
           self._projection_dim,
           kernel_regularizer=self.l2_reg,
           use_bias=False,
-          name='project')
+          name=name_scope + 'project')
     if self.config.input_layer_norm:
       # It is recommended to do layer norm before calling MaskBlock; otherwise reusing layer parameters creates names outside this scope.
       if tf.__version__ >= '2.0':
         self.input_layer_norm = tf.keras.layers.LayerNormalization()
         self.output_layer_norm = tf.keras.layers.LayerNormalization()
       else:
         # to share input layer norm parameters
         idx = self.name.rfind('_')
         if idx > 0 and self.name[idx + 1:].isdigit():
           input_name = self.name[:idx]
         else:
           input_name = self.name
         self.input_layer_norm = lambda x: layer_norm(
             x, name=input_name + '/input_ln', reuse=tf.AUTO_REUSE)
-        self.output_layer_norm = lambda x: layer_norm(
+        with tf.name_scope(self.name):
+          self.output_layer_norm = lambda x: layer_norm(
             x, name='ln_output', reuse=self.reuse)
     if self.config.HasField('output_size'):
       self.output_layer = Dense(
-          self.config.output_size, use_bias=False, name='output')
+          self.config.output_size, use_bias=False, name=name_scope + 'output')

   def call(self, inputs, **kwargs):
     net, mask_input = inputs
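One subtlety in the `input_layer_norm` branch above: Keras auto-numbers duplicate layer names (`mask_block`, `mask_block_1`, ...), so the `rfind('_')` check strips a trailing numeric suffix to make every numbered copy reuse the same `input_ln` variables under `tf.AUTO_REUSE`. A small illustration of just that suffix-stripping step (hypothetical helper, not part of the commit):

```python
def shared_base_name(layer_name):
  # 'mask_block_1' -> 'mask_block'; a name without a numeric suffix is
  # returned unchanged, so every numbered copy maps to the same
  # '<base>/input_ln' variable scope.
  idx = layer_name.rfind('_')
  if idx > 0 and layer_name[idx + 1:].isdigit():
    return layer_name[:idx]
  return layer_name


assert shared_base_name('mask_block_3') == 'mask_block'
assert shared_base_name('mask_block') == 'mask_block'
```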
