Skip to content

Commit

Permalink
Add special color kwarg
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew-S-Rosen committed Feb 3, 2022
1 parent 0a2f259 commit 382140b
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions ptable_trends.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ def ptable_plotter(
cbar_height: float = None,
cbar_standoff: int = 12,
cbar_fontsize: int = 14,
blank_color: str = "#140F0E",
blank_color: str = "#c4c4c4",
under_value: float = None,
under_color: str = "#140F0E",
over_value: float = None,
over_color: str = "#c4c4c4",
over_color: str = "#140F0E",
special_elements: List[str] = None,
special_color: str = "#6F3023",
) -> figure:

"""
Expand Down Expand Up @@ -79,7 +81,10 @@ def ptable_plotter(
Values >= over_value will be colored with over_color.
under_color : str
Hexadecial color to be used for the upper bound color.
special_elements: List[str]
List of elements to be colored with special_color.
special_color: str
Hexadecimal color to be used for the special elements.
Returns
-------
figure
Expand Down Expand Up @@ -163,9 +168,7 @@ def ptable_plotter(
color_scale = ScalarMappable(norm=norm, cmap=cmap).to_rgba(data, alpha=None)

# Set blank color
color_list = []
for i in range(len(elements)):
color_list.append(blank_color)
color_list = [blank_color] * len(elements)

# Compare elements in dataset with elements in periodic table
for i, data_element in enumerate(data_elements):
Expand All @@ -178,13 +181,17 @@ def ptable_plotter(
warnings.warn("Invalid chemical symbol: " + data_element)
if color_list[element_index] != blank_color:
warnings.warn("Multiple entries for element " + data_element)
if under_value is not None and data[i] <= under_value:
elif under_value is not None and data[i] <= under_value:
color_list[element_index] = under_color
elif over_value is not None and data[i] >= over_value:
color_list[element_index] = over_color
else:
color_list[element_index] = to_hex(color_scale[i])

for k, v in elements["symbol"].iteritems():
if v in special_elements:
color_list[k] = special_color

# Define figure properties for visualizing data
source = ColumnDataSource(
data=dict(
Expand Down Expand Up @@ -248,6 +255,3 @@ def ptable_plotter(
show_(p)

return p


ptable_plotter("ptable.csv", show=True, over_value=0)

0 comments on commit 382140b

Please sign in to comment.