Skip to content

Commit

Permalink
Fix ImportError (#385)
Browse files Browse the repository at this point in the history
* Change BrokenBarHCollection to PolyCollection

* Remove unused import

* np.Inf -> np.inf

* Use `ann.set_clip_path`

* Set expected legend location

* Run black
  • Loading branch information
carlthome authored Aug 15, 2024
1 parent 485a425 commit 71bcb18
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 11 deletions.
8 changes: 2 additions & 6 deletions mir_eval/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from matplotlib.ticker import FuncFormatter, MultipleLocator
from matplotlib.ticker import Formatter
from matplotlib.colors import LinearSegmentedColormap, LogNorm, ColorConverter
from matplotlib.collections import BrokenBarHCollection
from matplotlib.transforms import Bbox, TransformedBbox

from .melody import freq_to_voicing
Expand Down Expand Up @@ -184,18 +183,15 @@ def segments(
seg_map[lab].pop("label", None)

if text:
bbox = Bbox.from_extents(ival[0], base, ival[1], height)
tbbox = TransformedBbox(bbox, transform)
ann = ax.annotate(
lab,
xy=(ival[0], height),
xycoords=transform,
xytext=(8, -10),
textcoords="offset points",
clip_path=rect,
clip_box=tbbox,
**text_kw
)
ann.set_clip_path(rect)

return ax

Expand Down Expand Up @@ -264,7 +260,7 @@ def labeled_intervals(
**kwargs
Additional keyword arguments to pass to
`matplotlib.collection.BrokenBarHCollection`.
`matplotlib.collection.PolyCollection`.
Returns
-------
Expand Down
2 changes: 1 addition & 1 deletion mir_eval/separation.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,7 @@ def _safe_db(num, den):
be 0.
"""
if den == 0:
return np.Inf
return np.inf
return 10 * np.log10(num / den)


Expand Down
8 changes: 4 additions & 4 deletions tests/test_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def test_display_labeled_intervals_compare_noextend():
est_int, est_labels, extend_labels=False, alpha=0.5, label="Estimate"
)

plt.legend()
plt.legend(loc="upper right")
return plt.gcf()


Expand All @@ -178,7 +178,7 @@ def test_display_labeled_intervals_compare_common():
est_int, est_labels, label_set=label_set, alpha=0.5, label="Estimate"
)

plt.legend()
plt.legend(loc="upper right")
return plt.gcf()


Expand Down Expand Up @@ -344,7 +344,7 @@ def test_display_piano_roll():
est_t, est_p, label="Estimate", alpha=0.5, facecolor="r"
)

plt.legend()
plt.legend(loc="upper right")
return plt.gcf()


Expand All @@ -367,7 +367,7 @@ def test_display_piano_roll_midi():
est_t, midi=est_midi, label="Estimate", alpha=0.5, facecolor="r"
)

plt.legend()
plt.legend(loc="upper right")
return plt.gcf()


Expand Down

0 comments on commit 71bcb18

Please sign in to comment.