Skip to content

Commit

Permalink
Change the "batched" kwarg to "suffix" which have default behavior sa…
Browse files Browse the repository at this point in the history
…me as the previous version. Add detailed documentation for the "suffix" kwarg.
  • Loading branch information
lukewys committed Nov 16, 2021
1 parent 3562a19 commit 70def03
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions ddsp/training/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def __init__(self,
base_dir,
instrument_key='tpt',
split='train',
batched='batched'):
suffix=None):
"""URMP dataset for either a specific instrument or all instruments.
Args:
Expand All @@ -424,22 +424,30 @@ def __init__(self,
['all', 'bn', 'cl', 'db', 'fl', 'hn', 'ob', 'sax', 'tba', 'tbn',
'tpt', 'va', 'vc', 'vn'].
split: Choices include ['train', 'test'].
batched: Choices include ['batched', 'unbatched'].
suffix: Choices include [None, 'batched', 'unbatched'], but broadly
applies to any suffix adding to the file pattern.
When suffix is not None, will add "_suffix" to the file pattern.
This option is used in gs://magentadata/datasets/urmp/urmp_20210324.
With the "batched" suffix, the dataloader will load tfrecords
containing segmented audio samples in 4 seconds. With the "unbatched"
suffix, the dataloader will load tfrecords containing unsegmented
samples which could be used for learning note sequence in URMP dataset.
"""
self.instrument_key = instrument_key
self.split = split
self.base_dir = base_dir
self.batched = batched
self.suffix = suffix if suffix is None else '_' + suffix
super().__init__()

@property
def default_file_pattern(self):
if self.instrument_key == 'all':
file_pattern = 'all_instruments_{}_{}.tfrecord*'.format(
self.split, self.batched)
file_pattern = 'all_instruments_{}{}.tfrecord*'.format(
self.split, self.suffix)
else:
file_pattern = 'urmp_{}_solo_ddsp_conditioning_{}_{}.tfrecord*'.format(
self.instrument_key, self.split, self.batched)
file_pattern = 'urmp_{}_solo_ddsp_conditioning_{}{}.tfrecord*'.format(
self.instrument_key, self.split, self.suffix)

return os.path.join(self.base_dir, file_pattern)

Expand Down

0 comments on commit 70def03

Please sign in to comment.