Skip to content

Commit

Permalink
fix bug of SeqAugment layer & and compatibility problem of Windows Op…
Browse files Browse the repository at this point in the history
…eration System (#520)


* fix bug of SeqAugment layer

* add compacity for windows users
  • Loading branch information
yangxudong authored Jan 23, 2025
1 parent bc38227 commit e33da24
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
20 changes: 18 additions & 2 deletions easy_rec/python/layers/keras/data_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,28 @@ def mask_fn():
def reorder_fn():
return item_reorder(seq, length, aug_param.reorder_rate)

method = tf.random.uniform([], minval=0, maxval=3, dtype=tf.int32)
trans_fn = []
if aug_param.crop_rate < 1.0:
trans_fn.append(crop_fn)
if aug_param.mask_rate > 0:
trans_fn.append(mask_fn)
if aug_param.reorder_rate > 0:
trans_fn.append(reorder_fn)

num_trans = len(trans_fn)
if num_trans == 0:
return seq, length

if num_trans == 1:
return trans_fn[0]()

method = tf.random.uniform([], minval=0, maxval=num_trans, dtype=tf.int32)
if num_trans == 2:
return tf.cond(tf.equal(method, 0), trans_fn[0], trans_fn[1])

aug_seq, aug_len = tf.cond(
tf.equal(method, 0), crop_fn,
lambda: tf.cond(tf.equal(method, 1), mask_fn, reorder_fn))

return aug_seq, aug_len


Expand Down
5 changes: 3 additions & 2 deletions git-lfs/git_lfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def get_yes_no(msg):
'usage: python git_lfs.py [pull] [push] [add filename] [resolve_conflict]'
)
sys.exit(1)
home_directory = os.path.expanduser("~")
with open('.git_oss_config_pub', 'r') as fin:
git_oss_data_dir = None
host = None
Expand All @@ -237,7 +238,7 @@ def get_yes_no(msg):
continue
if line_str.startswith('#'):
continue
line_str = line_str.replace('~/', os.environ['HOME'] + '/')
line_str = line_str.replace('~/', home_directory + '/')
line_str = line_str.replace('${TMPDIR}/',
os.environ.get('TMPDIR', '/tmp/'))
line_str = line_str.replace('${PROJECT_NAME}', get_proj_name())
Expand All @@ -251,7 +252,7 @@ def get_yes_no(msg):
elif line_tok[0] == 'git_oss_private_config':
git_oss_private_path = line_tok[1]
if git_oss_private_path.startswith('~/'):
git_oss_private_path = os.path.join(os.environ['HOME'],
git_oss_private_path = os.path.join(home_directory,
git_oss_private_path[2:])
elif line_tok[0] == 'git_oss_cache_dir':
git_oss_cache_dir = line_tok[1]
Expand Down

0 comments on commit e33da24

Please sign in to comment.