Skip to content

Commit

Permalink
Fix indices sometimes formatted as float
Browse files Browse the repository at this point in the history
  • Loading branch information
alugowski committed Aug 29, 2023
1 parent f8c6f56 commit cd41b96
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 11 deletions.
12 changes: 6 additions & 6 deletions matrepr/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,28 +242,28 @@ def get_row_labels(self) -> Optional[Iterable[Optional[int]]]:
return None

if self.dot_row is None:
return self.orig_row_labels if self.orig_row_labels else [str(i) for i in range(self.orig_shape[0])]
return self.orig_row_labels if self.orig_row_labels else list(range(self.orig_shape[0]))
else:
pre_dot_end, post_dot_start = self.get_dot_indices_row()
if self.orig_row_labels is None:
# generate indices
# noinspection PyTypeChecker
return [str(i) for i in range(pre_dot_end)] + [None]\
+ [str(i) for i in range(post_dot_start, self.orig_shape[0])]
return list(range(pre_dot_end)) + [None]\
+ list(range(post_dot_start, self.orig_shape[0]))
else:
return self.orig_row_labels[:pre_dot_end] + [None] \
+ self.orig_row_labels[post_dot_start:self.orig_shape[0]]

def get_col_labels(self) -> Iterable[Optional[int]]:
if self.dot_col is None:
return self.orig_col_labels if self.orig_col_labels else [str(i) for i in range(self.orig_shape[1])]
return self.orig_col_labels if self.orig_col_labels else list(range(self.orig_shape[1]))
else:
pre_dot_end, post_dot_start = self.get_dot_indices_col()
if self.orig_col_labels is None:
# generate indices
# noinspection PyTypeChecker
return [str(i) for i in range(pre_dot_end)] + [None]\
+ [str(i) for i in range(post_dot_start, self.orig_shape[1])]
return list(range(pre_dot_end)) + [None]\
+ list(range(post_dot_start, self.orig_shape[1]))
else:
return self.orig_col_labels[:pre_dot_end] + [None] \
+ self.orig_col_labels[post_dot_start:self.orig_shape[1]]
Expand Down
9 changes: 6 additions & 3 deletions matrepr/html_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,13 @@ def _attributes_to_string(self, attributes: dict) -> str:
ret = " " + ret
return ret

def pprint(self, obj, current_indent=0):
def pprint(self, obj, current_indent=0, is_index=False):
if obj is None:
return ""

if is_index and isinstance(obj, int):
return int(obj)

if isinstance(obj, (int, float)):
return self.floatfmt(obj)

Expand Down Expand Up @@ -107,7 +110,7 @@ def _write_matrix(self, mat: MatrixAdapterRow,
self.write(f"<th></th>", indent=cell_indent)
for col_label in mat.get_col_labels():
attr = ' style="text-align: center;"' if self.center_header else ""
self.write(f"<th{attr}>{self.pprint(col_label)}</th>", indent=cell_indent)
self.write(f"<th{attr}>{self.pprint(col_label, is_index=True)}</th>", indent=cell_indent)
self.write("</tr>", indent=body_indent)
self.write("</thead>", indent=body_indent)

Expand All @@ -116,7 +119,7 @@ def _write_matrix(self, mat: MatrixAdapterRow,
for row_idx in range(nrows):
self.write("<tr>", body_indent)
if row_labels:
self.write(f"<th>{self.pprint(next(row_labels))}</th>", cell_indent)
self.write(f"<th>{self.pprint(next(row_labels), is_index=True)}</th>", cell_indent)

col_range = (0, ncols)
for col_idx, cell in enumerate(mat.get_dense_row(row_idx, col_range=col_range)):
Expand Down
5 changes: 4 additions & 1 deletion matrepr/latex_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,13 @@ def __init__(self, max_rows, max_cols, num_after_dots, title_latex, latex_matrix
self.floatfmt = lambda f: python_scientific_to_latex_times10(format(f))
self.indent_width = 4

def pprint(self, obj):
def pprint(self, obj, is_index=False):
if obj is None:
return ""

if is_index and isinstance(obj, int):
return int(obj)

if isinstance(obj, (int, float)):
return self.floatfmt(obj)

Expand Down
12 changes: 11 additions & 1 deletion matrepr/list_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# Use of this source code is governed by the BSD 2-clause license found in the LICENSE.txt file.
# SPDX-License-Identifier: BSD-2-Clause

from typing import List, Optional

from .adapters import MatrixAdapter, MatrixAdapterRow, Truncated2DMatrix, to_trunc, DupeList
from .base_formatter import unicode_dots

Expand Down Expand Up @@ -30,6 +32,12 @@ def _single_line(s: str) -> str:
return " ".join(lines)


def _to_str(iterable) -> Optional[List]:
if iterable is None:
return None
return [None if x is None else str(x) for x in iterable]


class ListConverter:
def __init__(self, max_rows, max_cols, num_after_dots, floatfmt, **_):
super().__init__()
Expand Down Expand Up @@ -119,4 +127,6 @@ def to_lists_and_labels(self, mat: MatrixAdapter, is_1d_ok=True):
mat.get_shape()[1] > self.max_cols:
mat = to_trunc(mat, self.max_rows, self.max_cols, self.num_after_dots)

return self._write_matrix(mat, is_vector=is_vector), mat.get_row_labels(), mat.get_col_labels()
return self._write_matrix(mat, is_vector=is_vector),\
_to_str(mat.get_row_labels()),\
_to_str(mat.get_col_labels())

0 comments on commit cd41b96

Please sign in to comment.