Skip to content

Commit

Permalink
Fix median cut frequencies not summing to one
Browse files Browse the repository at this point in the history
  • Loading branch information
qTipTip committed Jul 7, 2024
1 parent 2781630 commit 4d18e46
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 37 deletions.
4 changes: 2 additions & 2 deletions Pylette/src/color_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ def median_cut_extraction(arr: np.ndarray, height: int, width: int, palette_size

arr = arr.reshape((width * height, -1))
c = [ColorBox(arr)]
full_box_size = c[0].size

# Each iteration, find the largest box, split it, remove original box from list of boxes, and add the two new boxes.
while len(c) < palette_size:
largest_c_idx = np.argmax(c)
# add the two new boxes to the list, while removing the split box.
c = c[:largest_c_idx] + c[largest_c_idx].split() + c[largest_c_idx + 1 :]

colors = [Color(tuple(map(int, box.average)), box.size / full_box_size) for box in c]
total_pixels = width * height
colors = [Color(tuple(map(int, box.average)), box.pixel_count / total_pixels) for box in c]

return colors

Expand Down
10 changes: 10 additions & 0 deletions Pylette/src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,13 @@ def split(self) -> list["ColorBox"]:
ColorBox(self.colors[:median_index]),
ColorBox(self.colors[median_index:]),
]

@property
def pixel_count(self) -> int:
"""
Returns the number of pixels in the ColorBox.
Returns:
int: The number of pixels in the ColorBox.
"""
return len(self.colors)
40 changes: 5 additions & 35 deletions tests/integration/test_colorspaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,7 @@ def test_kmean_extracted_palette(test_image_path_as_str):
@pytest.mark.parametrize("palette_size", [1, 5, 10, 100])
@pytest.mark.parametrize(
"extraction_mode",
[
"KM",
pytest.param(
"MC",
marks=pytest.mark.skip("Currently a bug in the MC algorithm, causing frequencies not summing to one"),
),
],
["KM", "MC"],
)
def test_palette_invariants_with_image_path(test_image_path_as_str, palette_size, extraction_mode):
palette = extract_colors(
Expand Down Expand Up @@ -92,13 +86,7 @@ def test_palette_invariants_with_image_path(test_image_path_as_str, palette_size
@pytest.mark.parametrize("palette_size", [1, 5, 10, 100])
@pytest.mark.parametrize(
"extraction_mode",
[
"KM",
pytest.param(
"MC",
marks=pytest.mark.skip("Currently a bug in the MC algorithm, causing frequencies not summing to one"),
),
],
["KM", "MC"],
)
def test_palette_invariants_with_image_pathlike(test_image_path_as_pathlike, palette_size, extraction_mode):
palette = extract_colors(
Expand Down Expand Up @@ -131,13 +119,7 @@ def test_palette_invariants_with_image_pathlike(test_image_path_as_pathlike, pal
@pytest.mark.parametrize("palette_size", [1, 5, 10, 100])
@pytest.mark.parametrize(
"extraction_mode",
[
"KM",
pytest.param(
"MC",
marks=pytest.mark.skip("Currently a bug in the MC algorithm, causing frequencies not summing to one"),
),
],
["KM", "MC"],
)
def test_palette_invariants_with_image_bytes(test_image_as_bytes, palette_size, extraction_mode):
palette = extract_colors(
Expand Down Expand Up @@ -170,13 +152,7 @@ def test_palette_invariants_with_image_bytes(test_image_as_bytes, palette_size,
@pytest.mark.parametrize("palette_size", [1, 5, 10, 100])
@pytest.mark.parametrize(
"extraction_mode",
[
"KM",
pytest.param(
"MC",
marks=pytest.mark.skip("Currently a bug in the MC algorithm, causing frequencies not summing to one"),
),
],
["KM", "MC"],
)
def test_palette_invariants_with_opencv(test_image_from_opencv, palette_size, extraction_mode):
palette = extract_colors(
Expand Down Expand Up @@ -209,13 +185,7 @@ def test_palette_invariants_with_opencv(test_image_from_opencv, palette_size, ex
@pytest.mark.parametrize("palette_size", [1, 5, 10, 100])
@pytest.mark.parametrize(
"extraction_mode",
[
"KM",
pytest.param(
"MC",
marks=pytest.mark.skip("Currently a bug in the MC algorithm, causing frequencies not summing to one"),
),
],
["KM", "MC"],
)
def test_palette_invariants_with_image_url(test_image_as_url, palette_size, extraction_mode):
palette = extract_colors(
Expand Down

0 comments on commit 4d18e46

Please sign in to comment.