diff --git a/ptable_trends.py b/ptable_trends.py index 41eed17..da6c96d 100644 --- a/ptable_trends.py +++ b/ptable_trends.py @@ -31,11 +31,13 @@ def ptable_plotter( cbar_height: float = None, cbar_standoff: int = 12, cbar_fontsize: int = 14, - blank_color: str = "#140F0E", + blank_color: str = "#c4c4c4", under_value: float = None, under_color: str = "#140F0E", over_value: float = None, - over_color: str = "#c4c4c4", + over_color: str = "#140F0E", + special_elements: List[str] = None, + special_color: str = "#6F3023", ) -> figure: """ @@ -79,7 +81,10 @@ def ptable_plotter( Values >= over_value will be colored with over_color. under_color : str Hexadecial color to be used for the upper bound color. - + special_elements: List[str] + List of elements to be colored with special_color. + special_color: str + Hexadecimal color to be used for the special elements. Returns ------- figure @@ -163,9 +168,7 @@ def ptable_plotter( color_scale = ScalarMappable(norm=norm, cmap=cmap).to_rgba(data, alpha=None) # Set blank color - color_list = [] - for i in range(len(elements)): - color_list.append(blank_color) + color_list = [blank_color] * len(elements) # Compare elements in dataset with elements in periodic table for i, data_element in enumerate(data_elements): @@ -178,13 +181,17 @@ def ptable_plotter( warnings.warn("Invalid chemical symbol: " + data_element) if color_list[element_index] != blank_color: warnings.warn("Multiple entries for element " + data_element) - if under_value is not None and data[i] <= under_value: + elif under_value is not None and data[i] <= under_value: color_list[element_index] = under_color elif over_value is not None and data[i] >= over_value: color_list[element_index] = over_color else: color_list[element_index] = to_hex(color_scale[i]) + for k, v in elements["symbol"].iteritems(): + if v in special_elements: + color_list[k] = special_color + # Define figure properties for visualizing data source = ColumnDataSource( data=dict( @@ -248,6 +255,3 @@ def ptable_plotter( show_(p) return p - - -ptable_plotter("ptable.csv", show=True, over_value=0)