Skip to content

Commit

Permalink
Merge pull request beetbox#5063 from Maxr1998/fix-advancedrewrite-sim…
Browse files Browse the repository at this point in the history
…ple-rules

advancedrewrite: Fix simple rules being overwritten by advanced rules
  • Loading branch information
Serene-Arc authored Mar 1, 2024
2 parents fa8b120 + b1d9169 commit 8720d64
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 34 deletions.
51 changes: 17 additions & 34 deletions beetsplug/advancedrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,37 +27,22 @@
from beets.ui import UserError


def simple_rewriter(field, rules):
def rewriter(field, simple_rules, advanced_rules):
"""Template field function factory.
Create a template field function that rewrites the given field
with the given rewriting rules.
``rules`` must be a list of (pattern, replacement) pairs.
``simple_rules`` must be a list of (pattern, replacement) pairs.
``advanced_rules`` must be a list of (query, replacement) pairs.
"""

def fieldfunc(item):
value = item._values_fixed[field]
for pattern, replacement in rules:
for pattern, replacement in simple_rules:
if pattern.match(value.lower()):
# Rewrite activated.
return replacement
# Not activated; return original value.
return value

return fieldfunc


def advanced_rewriter(field, rules):
"""Template field function factory.
Create a template field function that rewrites the given field
with the given rewriting rules.
``rules`` must be a list of (query, replacement) pairs.
"""

def fieldfunc(item):
value = item._values_fixed[field]
for query, replacement in rules:
for query, replacement in advanced_rules:
if query.match(item):
# Rewrite activated.
return replacement
Expand Down Expand Up @@ -97,8 +82,12 @@ def __init__(self):
}

# Gather all the rewrite rules for each field.
simple_rules = defaultdict(list)
advanced_rules = defaultdict(list)
class RulesContainer:
def __init__(self):
self.simple = []
self.advanced = []

rules = defaultdict(RulesContainer)
for rule in self.config.get(template):
if "match" not in rule:
# Simple syntax
Expand All @@ -124,12 +113,12 @@ def __init__(self):
f"for field {fieldname}"
)
pattern = re.compile(pattern.lower())
simple_rules[fieldname].append((pattern, value))
rules[fieldname].simple.append((pattern, value))

# Apply the same rewrite to the corresponding album field.
if fieldname in corresponding_album_fields:
album_fieldname = corresponding_album_fields[fieldname]
simple_rules[album_fieldname].append((pattern, value))
rules[album_fieldname].simple.append((pattern, value))
else:
# Advanced syntax
match = rule["match"]
Expand Down Expand Up @@ -168,24 +157,18 @@ def __init__(self):
f"for field {fieldname}"
)

advanced_rules[fieldname].append((query, replacement))
rules[fieldname].advanced.append((query, replacement))

# Apply the same rewrite to the corresponding album field.
if fieldname in corresponding_album_fields:
album_fieldname = corresponding_album_fields[fieldname]
advanced_rules[album_fieldname].append(
rules[album_fieldname].advanced.append(
(query, replacement)
)

# Replace each template field with the new rewriter function.
for fieldname, fieldrules in simple_rules.items():
getter = simple_rewriter(fieldname, fieldrules)
self.template_fields[fieldname] = getter
if fieldname in Album._fields:
self.album_template_fields[fieldname] = getter

for fieldname, fieldrules in advanced_rules.items():
getter = advanced_rewriter(fieldname, fieldrules)
for fieldname, fieldrules in rules.items():
getter = rewriter(fieldname, fieldrules.simple, fieldrules.advanced)
self.template_fields[fieldname] = getter
if fieldname in Album._fields:
self.album_template_fields[fieldname] = getter
25 changes: 25 additions & 0 deletions test/plugins/test_advancedrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,31 @@ def test_fail_when_rewriting_single_valued_field_with_list(self):
):
self.load_plugins(PLUGIN_NAME)

def test_combined_rewrite_example(self):
self.config[PLUGIN_NAME] = [
{"artist A": "B"},
{
"match": "album:'C'",
"replacements": {
"artist": "D",
},
},
]
self.load_plugins(PLUGIN_NAME)

item = self.add_item(
artist="A",
albumartist="A",
)
self.assertEqual(item.artist, "B")

item = self.add_item(
artist="C",
albumartist="C",
album="C",
)
self.assertEqual(item.artist, "D")


def suite():
return unittest.TestLoader().loadTestsFromName(__name__)
Expand Down

0 comments on commit 8720d64

Please sign in to comment.