-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(sdk): add radicalbit_platform_sdk plot module (#204)
* feat(sdk): add placeholder chart (WIP) * feat(sdk): add chart_sdk * feat(sdk): fix version * feat(sdk): rename file and format with ruff * feat: add numerical_bar_chart * feat(chart_sdk_: add test to NumericalBarChart * feat(chart_sdk): binary wip * feat(chart_sdk): add binary distribution chart * feat(chart_sdk): add multiclassification data distribution chart * feat(chart_sdk): add regression distribution chart + utils to get bucket data fromatted * feat(chart_sdk): ruff fix * feat(chart_sdk): add legend and title on every chart * feat(chart_sdk): remove console.debug * feat(chart_sdk): add assert into binaryCharts * feat(sdk): move chart_sdk inside sdk * feat(sdk): add linearChart into sdk * feat(sdk): add confusionMatrix chart * feat(sdk): fix test_chart notebook * feat(sdk): add predictionActual chart * feat(sdk): add residualScatterChart * feat(sdk): add residualBucket chart * feat(sdk): ruff format * feat(sdk): ruff fix * feat(sdk): ruff fixies * feat(sdk): fix linear chart in multiclass * feat(sdk) replace print chart titel with display fn * feat(sdk): remove Option form color in regressionChart * feat(sdk): replace model_dump + get with list comprehension * feat(sdk): remove placeholder chart --------- Co-authored-by: Luca Tagliabue <[email protected]>
- Loading branch information
Showing
20 changed files
with
4,075 additions
and
551 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from .chart_data import NumericalBarChartData, ConfusionMatrixChartData | ||
from .chart import Chart | ||
|
||
__all__ = [ | ||
'ConfusionMatrixChartData', | ||
'NumericalBarChartData', | ||
'Chart' | ||
] |
8 changes: 8 additions & 0 deletions
8
sdk/radicalbit_platform_sdk/charts/binary_classification/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from .binary_chart import BinaryChart | ||
from .binary_chart_data import BinaryDistributionChartData, BinaryLinearChartData | ||
|
||
__all__ = [ | ||
'BinaryChart', | ||
'BinaryDistributionChartData', | ||
'BinaryLinearChartData' | ||
] |
149 changes: 149 additions & 0 deletions
149
sdk/radicalbit_platform_sdk/charts/binary_classification/binary_chart.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
from ipecharts import EChartsRawWidget | ||
|
||
from ..utils import get_chart_header | ||
from .binary_chart_data import BinaryDistributionChartData, BinaryLinearChartData | ||
|
||
|
||
class BinaryChart: | ||
def __init__(self) -> None: | ||
pass | ||
|
||
def distribution_chart(self, data: BinaryDistributionChartData) -> EChartsRawWidget: | ||
assert len(data.reference_data) <= 2 | ||
assert len(data.y_axis_label) <= 2 | ||
|
||
if data.current_data: | ||
assert len(data.current_data) <= 2 | ||
|
||
reference_json_data = [binary_data.model_dump() for binary_data in data.reference_data] | ||
current_data_json = [binary_data.model_dump() for binary_data in data.current_data] if data.current_data else [] | ||
|
||
reference_series_data = { | ||
'title': data.title, | ||
'type': 'bar', | ||
'itemStyle': {'color': '#9B99A1'}, | ||
'data': reference_json_data, | ||
'color': '#9B99A1', | ||
'name': 'Reference', | ||
'label': { | ||
'show': True, | ||
'position': 'insideRight', | ||
'fontWeight': 'bold', | ||
'color': '#FFFFFF', | ||
}, | ||
} | ||
|
||
current_series_data = { | ||
'title': data.title + '_current', | ||
'type': 'bar', | ||
'itemStyle': {}, | ||
'data': current_data_json, | ||
'color': '#3695d9', | ||
'name': 'Current', | ||
'label': { | ||
'show': True, | ||
'position': 'insideRight', | ||
'fontWeight': 'bold', | ||
'color': '#FFFFFF', | ||
}, | ||
} | ||
|
||
series = ( | ||
[reference_series_data] | ||
if not data.current_data | ||
else [reference_series_data, current_series_data] | ||
) | ||
|
||
option = { | ||
'grid': { | ||
'left': 0, | ||
'right': 20, | ||
'bottom': 0, | ||
'top': 40, | ||
'containLabel': True, | ||
}, | ||
'xAxis': { | ||
'type': 'value', | ||
'axisLabel': {'fontSize': 9, 'color': '#9b99a1'}, | ||
'splitLine': {'lineStyle': {'color': '#9f9f9f54'}}, | ||
}, | ||
'yAxis': { | ||
'type': 'category', | ||
'axisTick': {'show': False}, | ||
'axisLine': {'show': False}, | ||
'splitLine': {'show': False}, | ||
'axisLabel': {'fontSize': 12, 'color': '#9B99A1'}, | ||
'data': data.y_axis_label, | ||
}, | ||
'emphasis': {'disabled': True}, | ||
'barCategoryGap': '21%', | ||
'barGap': '0', | ||
'itemStyle': {'borderWidth': 1, 'borderColor': 'rgba(201, 25, 25, 1)'}, | ||
'series': series, | ||
} | ||
|
||
option.update(get_chart_header(title=data.title)) | ||
|
||
return EChartsRawWidget(option=option) | ||
|
||
def linear_chart(self, data: BinaryLinearChartData) -> EChartsRawWidget: | ||
|
||
reference_series_data = { | ||
'name': 'Reference', | ||
'type': 'line', | ||
'lineStyle': {'width': 2.2, 'color': '#9B99A1', 'type': 'dotted'}, | ||
'symbol': 'none', | ||
'data': data.reference_data, | ||
'itemStyle': {'color': '#9B99A1'}, | ||
'endLabel': {'show': True, 'color': '#9B99A1'}, | ||
'color': '#9B99A1', | ||
} | ||
|
||
current_series_data = { | ||
'name': data.title, | ||
'type': 'line', | ||
'lineStyle': {'width': 2.2, 'color': '#73B2E0'}, | ||
'symbol': 'none', | ||
'data': data.current_data, | ||
'itemStyle': {'color': '#73B2E0'}, | ||
} | ||
|
||
series = [reference_series_data, current_series_data] | ||
|
||
options = { | ||
'tooltip': { | ||
'trigger': 'axis', | ||
'crosshairs': True, | ||
'axisPointer': {'type': 'cross', 'label': {'show': True}}, | ||
}, | ||
'yAxis': { | ||
'type': 'value', | ||
'axisLabel': {'fontSize': 9, 'color': '#9b99a1'}, | ||
'splitLine': {'lineStyle': {'color': '#9f9f9f54'}}, | ||
'scale': True, | ||
}, | ||
'xAxis': { | ||
'type': 'time', | ||
'axisTick': {'show': False}, | ||
'axisLine': {'show': False}, | ||
'splitLine': {'show': False}, | ||
'axisLabel': {'fontSize': 12, 'color': '#9b99a1'}, | ||
'scale': True, | ||
}, | ||
'grid': { | ||
'bottom': 0, | ||
'top': 32, | ||
'left': 0, | ||
'right': 64, | ||
'containLabel': True, | ||
}, | ||
'series': series, | ||
'legend': { | ||
'show': True, | ||
'textStyle': {'color': '#9B99A1'}, | ||
}, | ||
} | ||
|
||
options.update(get_chart_header(title=data.title)) | ||
|
||
return EChartsRawWidget(option=options) |
22 changes: 22 additions & 0 deletions
22
sdk/radicalbit_platform_sdk/charts/binary_classification/binary_chart_data.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from typing import List, Optional | ||
|
||
from pydantic import BaseModel | ||
|
||
|
||
class BinaryDistributionData(BaseModel): | ||
percentage: float | ||
count: float | ||
value: float | ||
|
||
|
||
class BinaryDistributionChartData(BaseModel): | ||
title: str | ||
y_axis_label: List[str] | ||
reference_data: List[BinaryDistributionData] | ||
current_data: Optional[List[BinaryDistributionData]] = None | ||
|
||
|
||
class BinaryLinearChartData(BaseModel): | ||
title: str | ||
reference_data: List[List[str]] | ||
current_data: List[List[str]] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
from ipecharts import EChartsRawWidget | ||
import numpy as np | ||
|
||
from .chart_data import ConfusionMatrixChartData, NumericalBarChartData | ||
from .utils import get_chart_header, get_formatted_bucket_data | ||
|
||
|
||
class Chart: | ||
def __init__(self) -> None: | ||
pass | ||
|
||
def numerical_bar_chart(self, data: NumericalBarChartData) -> EChartsRawWidget: | ||
bucket_data_formatted = get_formatted_bucket_data(bucket_data=data.bucket_data) | ||
|
||
reference_data_json = { | ||
'title': 'reference', | ||
'type': 'bar', | ||
'name': 'Reference', | ||
'itemStyle': {'color': '#9B99A1'}, | ||
'data': data.reference_data, | ||
} | ||
|
||
current_data_json = { | ||
'title': 'current', | ||
'type': 'bar', | ||
'name': 'Current', | ||
'itemStyle': {'color': '#3695D9'}, | ||
'data': data.current_data, | ||
} | ||
|
||
series = ( | ||
[reference_data_json] | ||
if not data.current_data | ||
else [reference_data_json, current_data_json] | ||
) | ||
|
||
option = { | ||
'grid': { | ||
'left': 0, | ||
'right': 20, | ||
'bottom': 0, | ||
'top': 40, | ||
'containLabel': True, | ||
}, | ||
'xAxis': { | ||
'type': 'category', | ||
'axisTick': {'show': False}, | ||
'axisLine': {'show': False}, | ||
'splitLine': {'show': False}, | ||
'axisLabel': { | ||
'fontSize': 12, | ||
'interval': 0, | ||
'color': '#9B99A1', | ||
'rotate': 20, | ||
}, | ||
'data': bucket_data_formatted, | ||
}, | ||
'yAxis': { | ||
'type': 'value', | ||
'axisLabel': {'fontSize': 9, 'color': '#9B99A1'}, | ||
'splitLine': {'lineStyle': {'color': '#9f9f9f54'}}, | ||
}, | ||
'emphasis': {'disabled': True}, | ||
'barCategoryGap': '0', | ||
'barGap': '0', | ||
'itemStyle': {'borderWidth': 1, 'borderColor': 'rgba(201, 25, 25, 1)'}, | ||
'series': series, | ||
} | ||
|
||
option.update(get_chart_header(title=data.title)) | ||
|
||
return EChartsRawWidget(option=option) | ||
|
||
def confusion_matrix_chart( | ||
self, data: ConfusionMatrixChartData | ||
) -> EChartsRawWidget: | ||
assert len(data.matrix) == len(data.axis_label) * len( | ||
data.axis_label | ||
), 'axis_label count and matrix item count are not compatibile' | ||
|
||
np_matrix = np.matrix(data.matrix) | ||
|
||
options = { | ||
'yAxis': { | ||
'type': 'category', | ||
'axisTick': {'show': False}, | ||
'axisLine': {'show': False}, | ||
'splitLine': {'show': False}, | ||
'axisLabel': {'fontSize': 12, 'color': '#9B99A1'}, | ||
'data': data.axis_label, | ||
'name': 'Actual', | ||
'nameGap': 25, | ||
'nameLocation': 'middle', | ||
}, | ||
'xAxis': { | ||
'type': 'category', | ||
'axisTick': {'show': False}, | ||
'axisLine': {'show': False}, | ||
'splitLine': {'show': False}, | ||
'axisLabel': { | ||
'fontSize': 12, | ||
'interval': 0, | ||
'color': '#9b99a1', | ||
'rotate': 45, | ||
}, | ||
'data': data.axis_label.reverse(), | ||
'name': 'Predicted', | ||
'nameGap': 25, | ||
'nameLocation': 'middle', | ||
}, | ||
'grid': {'bottom': 60, 'top': 0, 'left': 44, 'right': 80}, | ||
'emphasis': {'disabled': True}, | ||
'axis': {'axisLabel': {'fontSize': 9, 'color': '#9b99a1'}}, | ||
'visualMap': { | ||
'calculable': True, | ||
'orient': 'vertical', | ||
'right': 'right', | ||
'top': 'center', | ||
'itemHeight': '250rem', | ||
'max': np_matrix.max(), | ||
'inRange': {'color': data.color}, | ||
}, | ||
'series': { | ||
'name': '', | ||
'type': 'heatmap', | ||
'label': {'show': True}, | ||
'data': data.matrix, | ||
}, | ||
} | ||
|
||
return EChartsRawWidget(option=options) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from typing import List, Optional | ||
|
||
from pydantic import BaseModel | ||
|
||
|
||
class NumericalBarChartData(BaseModel): | ||
title: str | ||
bucket_data: List[str] | ||
reference_data: List[float] | ||
current_data: Optional[List[float]] = None | ||
|
||
|
||
class ConfusionMatrixChartData(BaseModel): | ||
axis_label: List[str] | ||
matrix: List[List[float]] | ||
color: Optional[List[str]] = ['#FFFFFF', '#9B99A1'] |
9 changes: 9 additions & 0 deletions
9
sdk/radicalbit_platform_sdk/charts/multi_classification/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from .multi_class_chart import MultiClassificationChart | ||
from .multi_class_chart_data import MultiClassificationDistributionChartData, MultiClassificationLinearChartData, MultiClassificationLinearData | ||
|
||
__all__ = [ | ||
'MultiClassificationChart', | ||
'MultiClassificationDistributionChartData', | ||
'MultiClassificationLinearChartData', | ||
'MultiClassificationLinearData' | ||
] |
Oops, something went wrong.