Bug fix of SENet when run with MirroredStrategy (#443)
* fix bug of SENet when run with tf.distribute.MirroredStrategy
* add LayerNormalization Layer
yangxudong authored Dec 28, 2023
1 parent e3aa70d commit cb806b2
Showing 15 changed files with 1,479 additions and 175 deletions.
7 changes: 5 additions & 2 deletions docs/source/component/backbone.md
@@ -1091,6 +1091,7 @@ MovieLens-1M dataset results:

- Highway Network: [highway network](../models/highway.md)
- Cross Decoupling Network: [CDN](../models/cdn.md)
- DLRM+SENet: [dlrm_senet_on_criteo.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/dlrm_senet_on_criteo.config)

# Component Library Overview

@@ -1178,8 +1179,10 @@ def call(self, inputs, training=None, **kwargs):
[Optional] To add custom protobuf message parameters, first add the definition of the parameter message in `easy_rec/python/protos/layer.proto`,
then register the parameter into the `KerasLayer.params` message defined in `easy_rec/python/protos/keras_layer.proto`.

The `reuse` argument of the `__init__` method indicates whether the weights of this Layer object need to be reused. Implement the Layer with reusability in mind; strictly following the keras layer conventions is recommended.
Declare the dependent keras layer objects in the `__init__` method whenever possible; use `tf.layers.*` functions only when necessary, and pass the reuse argument to them.
The `reuse` argument of the `__init__` method indicates whether the weights of this Layer object need to be reused.
Implement the Layer with reusability in mind; strictly following the keras layer conventions is recommended.
Declare the dependent keras layer objects in the `__init__` method;
**using `tf.layers.*` functions is strongly discouraged, as they can fail under `DistributeStrategy`**; if you must use them, pass the reuse argument.
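A minimal sketch of this convention (a hypothetical layer for illustration only, not EasyRec's actual SENet code): every dependent keras layer is declared once in `__init__`, so its weights are created once and shared across calls, with no need for `tf.layers.*` or manual `reuse` handling; this is what keeps the layer safe under `MirroredStrategy`.

```python
import tensorflow as tf

class GateLayer(tf.keras.layers.Layer):
  """Hypothetical squeeze-and-excite style gate; illustration only."""

  def __init__(self, units, name='gate_layer', **kwargs):
    super(GateLayer, self).__init__(name=name, **kwargs)
    # Declare all dependent keras layers up front, never inside call():
    # their weights are created once and reused on every invocation.
    self.squeeze = tf.keras.layers.Dense(units // 2, activation='relu')
    self.excite = tf.keras.layers.Dense(units, activation='sigmoid')

  def call(self, inputs, training=None):
    gate = self.excite(self.squeeze(inputs))
    return inputs * gate

x = tf.ones([2, 8])
layer = GateLayer(8)
y = layer(x)   # first call builds the weights
y2 = layer(x)  # second call reuses the same weights
print(len(layer.trainable_weights))  # 4: kernel + bias for each Dense
```

By contrast, calling `tf.layers.*` functions inside `call` goes through the variable-scope mechanism, which is where the reuse argument and the `DistributeStrategy` failures mentioned above come from.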

```{tips}
Tip: when implementing a Layer object, prefer native tf.keras.layers.* objects, and declare all of them in advance in the __init__ method.
19 changes: 19 additions & 0 deletions docs/source/develop.md
@@ -36,6 +36,25 @@ pre-commit install
pre-commit run -a
```

### Adding a New Run Command (cmd)

Adding a new run command requires changes to the `xflow` configuration and scripts, located at:

- Internal (Alibaba) users: `pai_jobs/easy_rec_flow`
- Public cloud users: `pai_jobs/easy_rec_flow_ex`

An upgraded xflow must be tested rigorously before it is released, since the impact surface is large: it affects all users.

The recommended approach is not to add a new run command at all: run the new feature through the `cmd=custom` command, specify the feature's entry script with the `entryFile` parameter,
and pass any additional arguments via the `extra_params` parameter. For example:

```
pai -name easy_rec_ext
-Dcmd='custom'
-DentryFile='easy_rec/python/tools/feature_selection.py'
-Dextra_params='--topk 1000'
```
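On the script side, `extra_params` is forwarded to the entry script as command-line arguments. A hypothetical entry script could consume them with `argparse`; the `--topk` flag here mirrors the example above, but the parsing code itself is an assumption, not EasyRec's actual `feature_selection.py`:

```python
import argparse

def parse_args(argv=None):
    # Hypothetical flag; real entry scripts define their own arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument('--topk', type=int, default=100,
                        help='number of features to keep')
    return parser.parse_args(argv)

args = parse_args(['--topk', '1000'])
print(args.topk)  # 1000
```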

### Testing

#### Unit Tests
4 changes: 3 additions & 1 deletion docs/source/intro.md
@@ -61,4 +61,6 @@ EasyRec implements state of the art machine learning models used in common recom

- Run the [`knn algorithm`](vector_retrieve.md) on vectors in a distributed environment

Welcome to join the EasyRec recommendation algorithm discussion group, DingTalk group number: 32260796
### Contact

- DingTalk Group: 37930014162, click [this url](https://qr.dingtalk.com/action/joingroup?code=v1,k1,oHNqtNObbu+xUClHh77gCuKdGGH8AYoQ8AjKU23zTg4=&_dt_no_comment=1&origin=11) or scan the QR code to join: ![new_group.jpg](../images/qrcode/new_group.jpg)
2 changes: 1 addition & 1 deletion easy_rec/python/compat/early_stopping.py
@@ -21,9 +21,9 @@
import os
import threading
import time
from distutils.version import LooseVersion

import tensorflow as tf
from distutils.version import LooseVersion
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
23 changes: 14 additions & 9 deletions easy_rec/python/layers/backbone.py
@@ -48,7 +48,7 @@ def __init__(self, config, features, input_layer, l2_reg=None):
self._block_outputs = {}
self._package_input = None
reuse = None if config.name == 'backbone' else tf.AUTO_REUSE
input_feature_groups = set()
input_feature_groups = {}

for block in config.blocks:
if len(block.inputs) == 0:
@@ -71,9 +71,9 @@ def __init__(self, config, features, input_layer, l2_reg=None):
if group in input_feature_groups:
logging.warning('input `%s` already exists in other block' % group)
else:
input_feature_groups.add(group)
input_fn = EnhancedInputLayer(self._input_layer, self._features,
group, reuse)
input_feature_groups[group] = input_fn
self._name_to_layer[block.name] = input_fn
else:
self.define_layers(layer, block, block.name, reuse)
@@ -116,7 +116,7 @@ def __init__(self, config, features, input_layer, l2_reg=None):
if iname in self._name_to_blocks:
assert iname != name, 'input name can not equal to block name:' + iname
self._dag.add_edge(iname, name)
elif iname not in input_feature_groups:
else:
is_fea_group = input_type == 'feature_group_name'
if is_fea_group and input_layer.has_group(iname):
logging.info('adding an input_layer block: ' + iname)
@@ -129,8 +129,11 @@ def __init__(self, config, features, input_layer, l2_reg=None):
self._name_to_blocks[iname] = new_block
self._dag.add_node(iname)
self._dag.add_edge(iname, name)
input_feature_groups.add(iname)
fn = EnhancedInputLayer(self._input_layer, self._features, iname)
if iname in input_feature_groups:
fn = input_feature_groups[iname]
else:
fn = EnhancedInputLayer(self._input_layer, self._features, iname)
input_feature_groups[iname] = fn
self._name_to_layer[iname] = fn
elif Package.has_backbone_block(iname):
backbone = Package.__packages['backbone']
@@ -225,8 +228,9 @@ def block_input(self, config, block_outputs, training=None, **kwargs):
fn = eval('lambda x: x' + input_node.input_slice.strip())
input_feature = fn(input_feature)
if input_node.HasField('input_fn'):
fn = eval(input_node.input_fn)
input_feature = fn(input_feature)
with tf.name_scope(config.name):
fn = eval(input_node.input_fn)
input_feature = fn(input_feature)
inputs.append(input_feature)

if config.merge_inputs_into_list:
@@ -372,8 +376,9 @@ def call_layer(self, inputs, config, name, training, **kwargs):
fn = eval('lambda x, i: x' + conf.input_slice.strip())
ly_inputs = fn(ly_inputs, i)
if conf.HasField('input_fn'):
fn = eval(conf.input_fn)
ly_inputs = fn(ly_inputs, i)
with tf.name_scope(config.name):
fn = eval(conf.input_fn)
ly_inputs = fn(ly_inputs, i)
output = self.call_keras_layer(ly_inputs, name_i, training, **kwargs)
outputs.append(output)
if len(outputs) == 1:
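The `set()` to `{}` change in `backbone.py` appears to be the core of the fix: the dict caches one `EnhancedInputLayer` per feature group, so every block referencing that group receives the same object (and thus the same variables) instead of a freshly constructed layer. A stand-alone sketch of the pattern, using a stub class rather than the real `EnhancedInputLayer`:

```python
class InputLayerStub(object):
    """Stub standing in for EnhancedInputLayer; counts constructions."""
    instances = 0

    def __init__(self, group):
        self.group = group
        InputLayerStub.instances += 1

def get_input_layer(cache, group):
    # Build the layer for a feature group at most once; later lookups
    # return the cached object so its weights are shared, not duplicated.
    if group not in cache:
        cache[group] = InputLayerStub(group)
    return cache[group]

cache = {}
a = get_input_layer(cache, 'user')
b = get_input_layer(cache, 'user')
print(a is b, InputLayerStub.instances)  # True 1
```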
