diff --git a/observability-lib/dashboards/core-node/component.go b/observability-lib/dashboards/core-node/component.go index f745534ae..1843b4003 100644 --- a/observability-lib/dashboards/core-node/component.go +++ b/observability-lib/dashboards/core-node/component.go @@ -2,6 +2,7 @@ package corenode import ( "fmt" + "strconv" "github.com/grafana/grafana-foundation-sdk/go/alerting" "github.com/grafana/grafana-foundation-sdk/go/cog" @@ -241,6 +242,50 @@ func vars(p *Props) []cog.Builder[dashboard.VariableModel] { return variables } +func healthAverageAlertRule(p *Props, threshold float64, tags map[string]string) grafana.AlertOptions { + return grafana.AlertOptions{ + Title: `Health Avg by Service is less than ` + strconv.FormatFloat(threshold, 'f', -1, 64) + `%`, + Summary: `Uptime less than ` + strconv.FormatFloat(threshold, 'f', -1, 64) + `% over last 15 minutes on one component in a Node`, + Description: `Component {{ index $labels "service_id" }} uptime in the last 15m is {{ index $values "C" }}%`, + RunbookURL: "https://github.com/smartcontractkit/chainlink-common/tree/main/observability-lib", + For: "15m", + Tags: tags, + Query: []grafana.RuleQuery{ + { + Expr: `health{` + p.AlertsFilters + `}`, + RefID: "A", + Datasource: p.MetricsDataSource.UID, + }, + }, + QueryRefCondition: "D", + Condition: []grafana.ConditionQuery{ + { + RefID: "B", + ReduceExpression: &grafana.ReduceExpression{ + Expression: "A", + Reducer: expr.TypeReduceReducerMean, + }, + }, + { + RefID: "C", + MathExpression: &grafana.MathExpression{ + Expression: "$B * 100", + }, + }, + { + RefID: "D", + ThresholdExpression: &grafana.ThresholdExpression{ + Expression: "C", + ThresholdConditionsOptions: grafana.ThresholdConditionsOption{ + Params: []float64{threshold}, + Type: grafana.TypeThresholdTypeLt, + }, + }, + }, + }, + } +} + func headlines(p *Props) []*grafana.Panel { var panels []*grafana.Panel @@ -373,47 +418,10 @@ func headlines(p *Props) []*grafana.Panel { DisplayMode: common.LegendDisplayModeList, Placement: common.LegendPlacementRight, }, - AlertOptions: &grafana.AlertOptions{ - Summary: `Uptime less than 90% over last 15 minutes on one component in a Node`, - Description: `Component {{ index $labels "service_id" }} uptime in the last 15m is {{ index $values "C" }}%`, - RunbookURL: "https://github.com/smartcontractkit/chainlink-common/tree/main/observability-lib", - For: "15m", - Tags: map[string]string{ - "severity": "warning", - }, - Query: []grafana.RuleQuery{ - { - Expr: `health{` + p.AlertsFilters + `}`, - RefID: "A", - Datasource: p.MetricsDataSource.UID, - }, - }, - QueryRefCondition: "D", - Condition: []grafana.ConditionQuery{ - { - RefID: "B", - ReduceExpression: &grafana.ReduceExpression{ - Expression: "A", - Reducer: expr.TypeReduceReducerMean, - }, - }, - { - RefID: "C", - MathExpression: &grafana.MathExpression{ - Expression: "$B * 100", - }, - }, - { - RefID: "D", - ThresholdExpression: &grafana.ThresholdExpression{ - Expression: "C", - ThresholdConditionsOptions: grafana.ThresholdConditionsOption{ - Params: []float64{90}, - Type: grafana.TypeThresholdTypeLt, - }, - }, - }, - }, + AlertsOptions: []grafana.AlertOptions{ + healthAverageAlertRule(p, 90, map[string]string{"severity": "info"}), + healthAverageAlertRule(p, 70, map[string]string{"severity": "warning"}), + healthAverageAlertRule(p, 50, map[string]string{"severity": "critical"}), }, })) @@ -478,32 +486,34 @@ func headlines(p *Props) []*grafana.Panel { }, }, }, - AlertOptions: &grafana.AlertOptions{ - Summary: `ETH Balance is lower than 
threshold`, - Description: `ETH Balance critically low at {{ index $values "A" }} on {{ index $labels "` + p.platformOpts.LabelFilter + `" }}`, - RunbookURL: "https://github.com/smartcontractkit/chainlink-common/tree/main/observability-lib", - For: "15m", - NoDataState: alerting.RuleNoDataStateOK, - Tags: map[string]string{ - "severity": "critical", - }, - Query: []grafana.RuleQuery{ - { - Expr: `eth_balance{` + p.AlertsFilters + `}`, - Instant: true, - RefID: "A", - Datasource: p.MetricsDataSource.UID, + AlertsOptions: []grafana.AlertOptions{ + { + Summary: `ETH Balance is lower than threshold`, + Description: `ETH Balance critically low at {{ index $values "A" }} on {{ index $labels "` + p.platformOpts.LabelFilter + `" }}`, + RunbookURL: "https://github.com/smartcontractkit/chainlink-common/tree/main/observability-lib", + For: "15m", + NoDataState: alerting.RuleNoDataStateOK, + Tags: map[string]string{ + "severity": "critical", }, - }, - QueryRefCondition: "B", - Condition: []grafana.ConditionQuery{ - { - RefID: "B", - ThresholdExpression: &grafana.ThresholdExpression{ - Expression: "A", - ThresholdConditionsOptions: grafana.ThresholdConditionsOption{ - Params: []float64{1}, - Type: grafana.TypeThresholdTypeLt, + Query: []grafana.RuleQuery{ + { + Expr: `eth_balance{` + p.AlertsFilters + `}`, + Instant: true, + RefID: "A", + Datasource: p.MetricsDataSource.UID, + }, + }, + QueryRefCondition: "B", + Condition: []grafana.ConditionQuery{ + { + RefID: "B", + ThresholdExpression: &grafana.ThresholdExpression{ + Expression: "A", + ThresholdConditionsOptions: grafana.ThresholdConditionsOption{ + Params: []float64{1}, + Type: grafana.TypeThresholdTypeLt, + }, }, }, }, @@ -525,32 +535,34 @@ func headlines(p *Props) []*grafana.Panel { }, }, }, - AlertOptions: &grafana.AlertOptions{ - Summary: `Solana Balance is lower than threshold`, - Description: `Solana Balance critically low at {{ index $values "A" }} on {{ index $labels "` + p.platformOpts.LabelFilter + `" }}`, - RunbookURL: "https://github.com/smartcontractkit/chainlink-common/tree/main/observability-lib", - For: "15m", - NoDataState: alerting.RuleNoDataStateOK, - Tags: map[string]string{ - "severity": "critical", - }, - Query: []grafana.RuleQuery{ - { - Expr: `solana_balance{` + p.AlertsFilters + `}`, - Instant: true, - RefID: "A", - Datasource: p.MetricsDataSource.UID, + AlertsOptions: []grafana.AlertOptions{ + { + Summary: `Solana Balance is lower than threshold`, + Description: `Solana Balance critically low at {{ index $values "A" }} on {{ index $labels "` + p.platformOpts.LabelFilter + `" }}`, + RunbookURL: "https://github.com/smartcontractkit/chainlink-common/tree/main/observability-lib", + For: "15m", + NoDataState: alerting.RuleNoDataStateOK, + Tags: map[string]string{ + "severity": "critical", }, - }, - QueryRefCondition: "B", - Condition: []grafana.ConditionQuery{ - { - RefID: "B", - ThresholdExpression: &grafana.ThresholdExpression{ - Expression: "A", - ThresholdConditionsOptions: grafana.ThresholdConditionsOption{ - Params: []float64{1}, - Type: grafana.TypeThresholdTypeLt, + Query: []grafana.RuleQuery{ + { + Expr: `solana_balance{` + p.AlertsFilters + `}`, + Instant: true, + RefID: "A", + Datasource: p.MetricsDataSource.UID, + }, + }, + QueryRefCondition: "B", + Condition: []grafana.ConditionQuery{ + { + RefID: "B", + ThresholdExpression: &grafana.ThresholdExpression{ + Expression: "A", + ThresholdConditionsOptions: grafana.ThresholdConditionsOption{ + Params: []float64{1}, + Type: grafana.TypeThresholdTypeLt, + }, }, }, }, @@ 
-876,32 +888,34 @@ func headTracker(p *Props) []*grafana.Panel { }, }, }, - AlertOptions: &grafana.AlertOptions{ - Summary: `No Headers Received`, - Description: `{{ index $labels "` + p.platformOpts.LabelFilter + `" }} on ChainID {{ index $labels "ChainID" }} has received {{ index $values "A" }} heads over 10 minutes.`, - RunbookURL: "https://github.com/smartcontractkit/chainlink-common/tree/main/observability-lib", - For: "10m", - NoDataState: alerting.RuleNoDataStateOK, - Tags: map[string]string{ - "severity": "critical", - }, - Query: []grafana.RuleQuery{ - { - Expr: `increase(head_tracker_heads_received{` + p.AlertsFilters + `}[10m])`, - Instant: true, - RefID: "A", - Datasource: p.MetricsDataSource.UID, + AlertsOptions: []grafana.AlertOptions{ + { + Summary: `No Headers Received`, + Description: `{{ index $labels "` + p.platformOpts.LabelFilter + `" }} on ChainID {{ index $labels "ChainID" }} has received {{ index $values "A" }} heads over 10 minutes.`, + RunbookURL: "https://github.com/smartcontractkit/chainlink-common/tree/main/observability-lib", + For: "10m", + NoDataState: alerting.RuleNoDataStateOK, + Tags: map[string]string{ + "severity": "critical", }, - }, - QueryRefCondition: "B", - Condition: []grafana.ConditionQuery{ - { - RefID: "B", - ThresholdExpression: &grafana.ThresholdExpression{ - Expression: "A", - ThresholdConditionsOptions: grafana.ThresholdConditionsOption{ - Params: []float64{1}, - Type: grafana.TypeThresholdTypeLt, + Query: []grafana.RuleQuery{ + { + Expr: `increase(head_tracker_heads_received{` + p.AlertsFilters + `}[10m])`, + Instant: true, + RefID: "A", + Datasource: p.MetricsDataSource.UID, + }, + }, + QueryRefCondition: "B", + Condition: []grafana.ConditionQuery{ + { + RefID: "B", + ThresholdExpression: &grafana.ThresholdExpression{ + Expression: "A", + ThresholdConditionsOptions: grafana.ThresholdConditionsOption{ + Params: []float64{1}, + Type: grafana.TypeThresholdTypeLt, + }, }, }, }, diff --git a/observability-lib/dashboards/core-node/test-output.json b/observability-lib/dashboards/core-node/test-output.json index 0c729dbe5..7f43d7735 100644 --- a/observability-lib/dashboards/core-node/test-output.json +++ b/observability-lib/dashboards/core-node/test-output.json @@ -5891,6 +5891,7 @@ { "annotations": { "description": "Component {{ index $labels \"service_id\" }} uptime in the last 15m is {{ index $values \"C\" }}%", + "panel_title": "Health Avg by Service over 15m", "runbook_url": "https://github.com/smartcontractkit/chainlink-common/tree/main/observability-lib", "summary": "Uptime less than 90% over last 15 minutes on one component in a Node" }, @@ -5970,17 +5971,200 @@ "execErrState": "Alerting", "folderUID": "", "for": "15m", + "labels": { + "severity": "info" + }, + "noDataState": "NoData", + "orgID": 0, + "ruleGroup": "", + "title": "Health Avg by Service is less than 90%" + }, + { + "annotations": { + "description": "Component {{ index $labels \"service_id\" }} uptime in the last 15m is {{ index $values \"C\" }}%", + "panel_title": "Health Avg by Service over 15m", + "runbook_url": "https://github.com/smartcontractkit/chainlink-common/tree/main/observability-lib", + "summary": "Uptime less than 70% over last 15 minutes on one component in a Node" + }, + "condition": "D", + "data": [ + { + "datasourceUid": "1", + "model": { + "expr": "health{}", + "legendFormat": "__auto", + "refId": "A" + }, + "refId": "A", + "relativeTimeRange": { + "from": 600, + "to": 0 + } + }, + { + "datasourceUid": "__expr__", + "model": { + "expression": "A", + 
"intervalMs": 1000, + "maxDataPoints": 43200, + "reducer": "mean", + "refId": "B", + "type": "reduce" + }, + "refId": "B", + "relativeTimeRange": { + "from": 600, + "to": 0 + } + }, + { + "datasourceUid": "__expr__", + "model": { + "expression": "$B * 100", + "intervalMs": 1000, + "maxDataPoints": 43200, + "refId": "C", + "type": "math" + }, + "refId": "C", + "relativeTimeRange": { + "from": 600, + "to": 0 + } + }, + { + "datasourceUid": "__expr__", + "model": { + "conditions": [ + { + "evaluator": { + "params": [ + 70, + 0 + ], + "type": "lt" + } + } + ], + "expression": "C", + "intervalMs": 1000, + "maxDataPoints": 43200, + "refId": "D", + "type": "threshold" + }, + "refId": "D", + "relativeTimeRange": { + "from": 600, + "to": 0 + } + } + ], + "execErrState": "Alerting", + "folderUID": "", + "for": "15m", "labels": { "severity": "warning" }, "noDataState": "NoData", "orgID": 0, "ruleGroup": "", - "title": "Health Avg by Service over 15m" + "title": "Health Avg by Service is less than 70%" + }, + { + "annotations": { + "description": "Component {{ index $labels \"service_id\" }} uptime in the last 15m is {{ index $values \"C\" }}%", + "panel_title": "Health Avg by Service over 15m", + "runbook_url": "https://github.com/smartcontractkit/chainlink-common/tree/main/observability-lib", + "summary": "Uptime less than 50% over last 15 minutes on one component in a Node" + }, + "condition": "D", + "data": [ + { + "datasourceUid": "1", + "model": { + "expr": "health{}", + "legendFormat": "__auto", + "refId": "A" + }, + "refId": "A", + "relativeTimeRange": { + "from": 600, + "to": 0 + } + }, + { + "datasourceUid": "__expr__", + "model": { + "expression": "A", + "intervalMs": 1000, + "maxDataPoints": 43200, + "reducer": "mean", + "refId": "B", + "type": "reduce" + }, + "refId": "B", + "relativeTimeRange": { + "from": 600, + "to": 0 + } + }, + { + "datasourceUid": "__expr__", + "model": { + "expression": "$B * 100", + "intervalMs": 1000, + "maxDataPoints": 43200, + "refId": "C", + "type": "math" + }, + "refId": "C", + "relativeTimeRange": { + "from": 600, + "to": 0 + } + }, + { + "datasourceUid": "__expr__", + "model": { + "conditions": [ + { + "evaluator": { + "params": [ + 50, + 0 + ], + "type": "lt" + } + } + ], + "expression": "C", + "intervalMs": 1000, + "maxDataPoints": 43200, + "refId": "D", + "type": "threshold" + }, + "refId": "D", + "relativeTimeRange": { + "from": 600, + "to": 0 + } + } + ], + "execErrState": "Alerting", + "folderUID": "", + "for": "15m", + "labels": { + "severity": "critical" + }, + "noDataState": "NoData", + "orgID": 0, + "ruleGroup": "", + "title": "Health Avg by Service is less than 50%" }, { "annotations": { "description": "ETH Balance critically low at {{ index $values \"A\" }} on {{ index $labels \"instance\" }}", + "panel_title": "ETH Balance", "runbook_url": "https://github.com/smartcontractkit/chainlink-common/tree/main/observability-lib", "summary": "ETH Balance is lower than threshold" }, @@ -6042,6 +6226,7 @@ { "annotations": { "description": "Solana Balance critically low at {{ index $values \"A\" }} on {{ index $labels \"instance\" }}", + "panel_title": "SOL Balance", "runbook_url": "https://github.com/smartcontractkit/chainlink-common/tree/main/observability-lib", "summary": "Solana Balance is lower than threshold" }, @@ -6103,6 +6288,7 @@ { "annotations": { "description": "{{ index $labels \"instance\" }} on ChainID {{ index $labels \"ChainID\" }} has received {{ index $values \"A\" }} heads over 10 minutes.", + "panel_title": "Head Tracker Heads Received 
Rate", "runbook_url": "https://github.com/smartcontractkit/chainlink-common/tree/main/observability-lib", "summary": "No Headers Received" }, diff --git a/observability-lib/grafana/alerts.go b/observability-lib/grafana/alerts.go index f050b89b0..980d3ac74 100644 --- a/observability-lib/grafana/alerts.go +++ b/observability-lib/grafana/alerts.go @@ -158,7 +158,7 @@ func newConditionQuery(options ConditionQuery) *alerting.QueryBuilder { } type AlertOptions struct { - Name string + Title string Summary string Description string RunbookURL string @@ -169,6 +169,7 @@ type AlertOptions struct { Query []RuleQuery QueryRefCondition string Condition []ConditionQuery + PanelTitle string } func NewAlertRule(options *AlertOptions) *alerting.RuleBuilder { @@ -188,16 +189,22 @@ func NewAlertRule(options *AlertOptions) *alerting.RuleBuilder { options.QueryRefCondition = "A" } - rule := alerting.NewRuleBuilder(options.Name). + annotations := map[string]string{ + "summary": options.Summary, + "description": options.Description, + "runbook_url": options.RunbookURL, + } + + if options.PanelTitle != "" { + annotations["panel_title"] = options.PanelTitle + } + + rule := alerting.NewRuleBuilder(options.Title). For(options.For). NoDataState(options.NoDataState). ExecErrState(options.RuleExecErrState). Condition(options.QueryRefCondition). - Annotations(map[string]string{ - "summary": options.Summary, - "description": options.Description, - "runbook_url": options.RunbookURL, - }). + Annotations(annotations). Labels(options.Tags) for _, query := range options.Query { diff --git a/observability-lib/grafana/builder.go b/observability-lib/grafana/builder.go index 90319dd91..7b4aba40a 100644 --- a/observability-lib/grafana/builder.go +++ b/observability-lib/grafana/builder.go @@ -86,8 +86,8 @@ func (b *Builder) AddPanel(panel ...*Panel) { item.heatmapBuilder.Id(panelID) b.dashboardBuilder.WithPanel(item.heatmapBuilder) } - if item.alertBuilder != nil { - b.alertsBuilder = append(b.alertsBuilder, item.alertBuilder) + if item.alertBuilders != nil && len(item.alertBuilders) > 0 { + b.AddAlert(item.alertBuilders...) 
} } } diff --git a/observability-lib/grafana/dashboard.go b/observability-lib/grafana/dashboard.go index 8c5b11773..8b2119a7c 100644 --- a/observability-lib/grafana/dashboard.go +++ b/observability-lib/grafana/dashboard.go @@ -3,6 +3,7 @@ package grafana import ( "encoding/json" "fmt" + "reflect" "github.com/grafana/grafana-foundation-sdk/go/alerting" "github.com/grafana/grafana-foundation-sdk/go/dashboard" @@ -42,7 +43,7 @@ type DeployOptions struct { func alertRuleExist(alerts []alerting.Rule, alert alerting.Rule) bool { for _, a := range alerts { - if a.Title == alert.Title { + if reflect.DeepEqual(a, alert) { return true } } @@ -103,9 +104,11 @@ func (db *Dashboard) DeployToGrafana(options *DeployOptions) error { alert.FolderUID = folder.UID alert.Annotations["__dashboardUid__"] = *newDashboard.UID - panelId := panelIDByTitle(db.Dashboard, alert.Title) + panelId := panelIDByTitle(db.Dashboard, alert.Annotations["panel_title"]) + // we can clean it up as it was only used to get the panelId + delete(alert.Annotations, "panel_title") if panelId != "" { - alert.Annotations["__panelId__"] = panelIDByTitle(db.Dashboard, alert.Title) + alert.Annotations["__panelId__"] = panelId } if alertRuleExist(alertsRule, alert) { // update alert rule if it already exists diff --git a/observability-lib/grafana/dashboard_test.go b/observability-lib/grafana/dashboard_test.go index 11df71395..3f0ed8bf6 100644 --- a/observability-lib/grafana/dashboard_test.go +++ b/observability-lib/grafana/dashboard_test.go @@ -32,31 +32,33 @@ func TestGenerateJSON(t *testing.T) { }, }, }, - AlertOptions: &grafana.AlertOptions{ - Summary: `ETH Balance is lower than threshold`, - Description: `ETH Balance critically low at {{ index $values "A" }}`, - RunbookURL: "https://github.com/smartcontractkit/chainlink-common/tree/main/observability-lib", - For: "1m", - Tags: map[string]string{ - "severity": "warning", - }, - Query: []grafana.RuleQuery{ - { - Expr: `eth_balance`, - Instant: true, - RefID: "A", - Datasource: "datasource-uid", + AlertsOptions: []grafana.AlertOptions{ + { + Summary: `ETH Balance is lower than threshold`, + Description: `ETH Balance critically low at {{ index $values "A" }}`, + RunbookURL: "https://github.com/smartcontractkit/chainlink-common/tree/main/observability-lib", + For: "1m", + Tags: map[string]string{ + "severity": "warning", }, - }, - QueryRefCondition: "B", - Condition: []grafana.ConditionQuery{ - { - RefID: "B", - ThresholdExpression: &grafana.ThresholdExpression{ - Expression: "A", - ThresholdConditionsOptions: grafana.ThresholdConditionsOption{ - Params: []float64{2}, - Type: grafana.TypeThresholdTypeLt, + Query: []grafana.RuleQuery{ + { + Expr: `eth_balance`, + Instant: true, + RefID: "A", + Datasource: "datasource-uid", + }, + }, + QueryRefCondition: "B", + Condition: []grafana.ConditionQuery{ + { + RefID: "B", + ThresholdExpression: &grafana.ThresholdExpression{ + Expression: "A", + ThresholdConditionsOptions: grafana.ThresholdConditionsOption{ + Params: []float64{2}, + Type: grafana.TypeThresholdTypeLt, + }, }, }, }, diff --git a/observability-lib/grafana/panels.go b/observability-lib/grafana/panels.go index a3e67b633..9869a2317 100644 --- a/observability-lib/grafana/panels.go +++ b/observability-lib/grafana/panels.go @@ -112,7 +112,7 @@ type Panel struct { tablePanelBuilder *table.PanelBuilder logPanelBuilder *logs.PanelBuilder heatmapBuilder *heatmap.PanelBuilder - alertBuilder *alerting.RuleBuilder + alertBuilders []*alerting.RuleBuilder } // panel defaults @@ -226,7 +226,7 @@ func 
NewStatPanel(options *StatPanelOptions) *Panel {
 
 type TimeSeriesPanelOptions struct {
 	*PanelOptions
-	AlertOptions       *AlertOptions
+	AlertsOptions      []AlertOptions
 	FillOpacity        float64
 	ScaleDistribution  common.ScaleDistribution
 	LegendOptions      *LegendOptions
@@ -287,17 +287,22 @@ func NewTimeSeriesPanel(options *TimeSeriesPanelOptions) *Panel {
 		newPanel.ColorScheme(dashboard.NewFieldColorBuilder().Mode(options.ColorScheme))
 	}
 
-	if options.AlertOptions != nil {
-		options.AlertOptions.Name = options.Title
-
-		return &Panel{
-			timeSeriesPanelBuilder: newPanel,
-			alertBuilder:           NewAlertRule(options.AlertOptions),
+	var alertBuilders []*alerting.RuleBuilder
+	if len(options.AlertsOptions) > 0 {
+		for _, alert := range options.AlertsOptions {
+			// PanelTitle is stamped on each alert as an internal mechanism to associate the alert with its panel ID at deploy time
+			alert.PanelTitle = options.Title
+			// if an explicit alert title is provided use it, otherwise fall back to the panel title
+			if alert.Title == "" {
+				alert.Title = options.Title
+			}
+			alertBuilders = append(alertBuilders, NewAlertRule(&alert))
		}
	}
 
 	return &Panel{
 		timeSeriesPanelBuilder: newPanel,
+		alertBuilders:          alertBuilders,
 	}
 }
diff --git a/pkg/capabilities/consensus/ocr3/aggregators/identical.go b/pkg/capabilities/consensus/ocr3/aggregators/identical.go
index ed38fc0e2..aa05e7cf4 100644
--- a/pkg/capabilities/consensus/ocr3/aggregators/identical.go
+++ b/pkg/capabilities/consensus/ocr3/aggregators/identical.go
@@ -13,11 +13,12 @@ import (
 	ocrcommon "github.com/smartcontractkit/libocr/commontypes"
 )
 
+// Aggregates by the most frequent observation for each index of a data set
 type identicalAggregator struct {
-	config aggregatorConfig
+	config identicalAggConfig
 }
 
-type aggregatorConfig struct {
+type identicalAggConfig struct {
 	// Length of the list of observations that each node is expected to provide.
 	// Aggregator's output (i.e. EncodableOutcome) will be a values.Map with the same
 	// number of elements and keyed by indices 0,1,2,... (unless KeyOverrides are provided).
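
Reviewer note: a minimal sketch of how the renamed constructor is driven end to end. The config keys ("expectedObservationsLen", "keyOverrides") are the ones used by the test helper getConfigIdenticalAggregator further down; error handling is elided and the wiring is illustrative, not part of this patch:

	// build the aggregator config as a values.Map (same shape the tests use)
	cfg, err := values.NewMap(map[string]any{
		"expectedObservationsLen": 1,
		"keyOverrides":            []string{"outcome"},
	})
	// handle err
	agg, err := aggregators.NewIdenticalAggregator(*cfg)
	// handle err
	outcome, err := agg.Aggregate(lggr, nil, observations, f)
	// handle err; outcome.EncodableOutcome carries the aggregated values
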
@@ -102,7 +103,7 @@ func (a *identicalAggregator) collectHighestCounts(counters []map[[32]byte]*coun } func NewIdenticalAggregator(config values.Map) (*identicalAggregator, error) { - parsedConfig, err := ParseConfig(config) + parsedConfig, err := ParseConfigIdenticalAggregator(config) if err != nil { return nil, fmt.Errorf("failed to parse config (%+v): %w", config, err) } @@ -111,10 +112,10 @@ func NewIdenticalAggregator(config values.Map) (*identicalAggregator, error) { }, nil } -func ParseConfig(config values.Map) (aggregatorConfig, error) { - parsedConfig := aggregatorConfig{} +func ParseConfigIdenticalAggregator(config values.Map) (identicalAggConfig, error) { + parsedConfig := identicalAggConfig{} if err := config.UnwrapTo(&parsedConfig); err != nil { - return aggregatorConfig{}, err + return identicalAggConfig{}, err } if parsedConfig.ExpectedObservationsLen == 0 { parsedConfig.ExpectedObservationsLen = 1 diff --git a/pkg/capabilities/consensus/ocr3/aggregators/identical_test.go b/pkg/capabilities/consensus/ocr3/aggregators/identical_test.go index 711b1ab25..95688e894 100644 --- a/pkg/capabilities/consensus/ocr3/aggregators/identical_test.go +++ b/pkg/capabilities/consensus/ocr3/aggregators/identical_test.go @@ -13,7 +13,7 @@ import ( ) func TestDataFeedsAggregator_Aggregate(t *testing.T) { - config := getConfig(t, nil) + config := getConfigIdenticalAggregator(t, nil) agg, err := aggregators.NewIdenticalAggregator(*config) require.NoError(t, err) @@ -37,7 +37,7 @@ func TestDataFeedsAggregator_Aggregate(t *testing.T) { } func TestDataFeedsAggregator_Aggregate_OverrideWithKeys(t *testing.T) { - config := getConfig(t, []string{"outcome"}) + config := getConfigIdenticalAggregator(t, []string{"outcome"}) agg, err := aggregators.NewIdenticalAggregator(*config) require.NoError(t, err) @@ -61,7 +61,7 @@ func TestDataFeedsAggregator_Aggregate_OverrideWithKeys(t *testing.T) { } func TestDataFeedsAggregator_Aggregate_NoConsensus(t *testing.T) { - config := getConfig(t, []string{"outcome"}) + config := getConfigIdenticalAggregator(t, []string{"outcome"}) agg, err := aggregators.NewIdenticalAggregator(*config) require.NoError(t, err) @@ -81,7 +81,7 @@ func TestDataFeedsAggregator_Aggregate_NoConsensus(t *testing.T) { require.ErrorContains(t, err, "can't reach consensus on observations with index 0") } -func getConfig(t *testing.T, overrideKeys []string) *values.Map { +func getConfigIdenticalAggregator(t *testing.T, overrideKeys []string) *values.Map { unwrappedConfig := map[string]any{ "expectedObservationsLen": len(overrideKeys), "keyOverrides": overrideKeys, diff --git a/pkg/capabilities/consensus/ocr3/aggregators/reduce_aggregator.go b/pkg/capabilities/consensus/ocr3/aggregators/reduce_aggregator.go new file mode 100644 index 000000000..f8c32af4a --- /dev/null +++ b/pkg/capabilities/consensus/ocr3/aggregators/reduce_aggregator.go @@ -0,0 +1,523 @@ +package aggregators + +import ( + "crypto/sha256" + "errors" + "fmt" + "math" + "math/big" + "sort" + "strconv" + "time" + + "github.com/shopspring/decimal" + "google.golang.org/protobuf/proto" + + ocrcommon "github.com/smartcontractkit/libocr/commontypes" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/types" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/values" + "github.com/smartcontractkit/chainlink-common/pkg/values/pb" +) + +const ( + AGGREGATION_METHOD_MEDIAN = "median" + AGGREGATION_METHOD_MODE = "mode" + DEVIATION_TYPE_NONE = "none" + 
DEVIATION_TYPE_PERCENT  = "percent"
+	DEVIATION_TYPE_ABSOLUTE = "absolute"
+	REPORT_FORMAT_MAP       = "map"
+	REPORT_FORMAT_ARRAY     = "array"
+	REPORT_FORMAT_VALUE     = "value"
+
+	DEFAULT_REPORT_FORMAT     = REPORT_FORMAT_MAP
+	DEFAULT_OUTPUT_FIELD_NAME = "Reports"
+)
+
+type ReduceAggConfig struct {
+	// Configuration on how to aggregate one or more data points
+	Fields []AggregationField `mapstructure:"fields" required:"true"`
+	// The top level field name that report data is put into
+	OutputFieldName string `mapstructure:"outputFieldName" json:"outputFieldName" default:"Reports"`
+	// The structure surrounding the report data that is put onto "OutputFieldName"
+	ReportFormat string `mapstructure:"reportFormat" json:"reportFormat" default:"map" jsonschema:"enum=map,enum=array,enum=value"`
+	// Optional key name; when given, the report will contain a nested map under this key,
+	// with the designated Fields moved into it.
+	// If given, one or more fields must be given SubMapField: true
+	SubMapKey string `mapstructure:"subMapKey" json:"subMapKey" default:""`
+}
+
+type AggregationField struct {
+	// An optional check to only report when the difference from the previous report exceeds a certain threshold.
+	// Can only be used when the field is of a type convertible to decimal: string, decimal, int64, big.Int, time.Time, float64
+	// If no deviation is provided on any field, there will always be a report once minimum observations are reached.
+	Deviation       decimal.Decimal `mapstructure:"-" json:"-"`
+	DeviationString string          `mapstructure:"deviation" json:"deviation,omitempty"`
+	// The format of the deviation being provided
+	// * percent - a percentage deviation
+	// * absolute - an unsigned numeric difference
+	DeviationType string `mapstructure:"deviationType" json:"deviationType,omitempty" jsonschema:"enum=percent,enum=absolute,enum=none"`
+	// The key to find a data point within the input data
+	// If omitted, the entire input will be used
+	InputKey string `mapstructure:"inputKey" json:"inputKey"`
+	// How the data set should be aggregated to a single value
+	// * median - take the centermost value of the data set sorted in descending order. Can only be used on types convertible to decimal. Not a true median: with an even number of values, the two middle values are not averaged; the upper one is taken.
+	// * mode - take the most frequent value. If tied, the "first" one counted is used.
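+	//
+	// Worked example (matching the reduce/median/mode implementations below):
+	// five observations sorted descending [9, 8, 3, 2, 1] -> median returns 3;
+	// four observations [9, 8, 2, 1] -> median returns 8 (the upper middle value);
+	// mode over [1, 1, 2] -> returns 1.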
+ Method string `mapstructure:"method" json:"method" jsonschema:"enum=median,enum=mode" required:"true"` + // The key that the aggregated data is put under + // If omitted, the InputKey will be used + OutputKey string `mapstructure:"outputKey" json:"outputKey"` + // If enabled, this field will be moved from the top level map + // into a nested map on the key defined by "SubMapKey" + SubMapField bool `mapstructure:"subMapField" json:"subMapField,omitempty"` +} + +type reduceAggregator struct { + config ReduceAggConfig +} + +var _ types.Aggregator = (*reduceAggregator)(nil) + +// Condenses multiple observations into a single encodable outcome +func (a *reduceAggregator) Aggregate(lggr logger.Logger, previousOutcome *types.AggregationOutcome, observations map[ocrcommon.OracleID][]values.Value, f int) (*types.AggregationOutcome, error) { + if len(observations) < 2*f+1 { + return nil, fmt.Errorf("not enough observations, have %d want %d", len(observations), 2*f+1) + } + + currentState, err := a.initializeCurrentState(lggr, previousOutcome) + if err != nil { + return nil, err + } + + report := map[string]any{} + shouldReport := false + + for _, field := range a.config.Fields { + vals := a.extractValues(lggr, observations, field.InputKey) + + // only proceed if every field has reached the minimum number of observations + if len(vals) < 2*f+1 { + return nil, fmt.Errorf("not enough observations provided %s, have %d want %d", field.InputKey, len(vals), 2*f+1) + } + + singleValue, err := reduce(field.Method, vals) + if err != nil { + return nil, fmt.Errorf("unable to reduce on method %s, err: %s", field.Method, err.Error()) + } + + if field.DeviationType != DEVIATION_TYPE_NONE { + oldValue := (*currentState)[field.OutputKey] + currDeviation, err := deviation(field.DeviationType, oldValue, singleValue) + if oldValue != nil && err != nil { + return nil, fmt.Errorf("unable to determine deviation %s", err.Error()) + } + if oldValue == nil || currDeviation.GreaterThan(field.Deviation) { + shouldReport = true + } + lggr.Debugw("checked deviation", "key", field.OutputKey, "deviationType", field.DeviationType, "currentDeviation", currDeviation.String(), "targetDeviation", field.Deviation.String(), "shouldReport", shouldReport) + } + + (*currentState)[field.OutputKey] = singleValue + if len(field.OutputKey) > 0 { + report[field.OutputKey] = singleValue + } else { + report[field.InputKey] = singleValue + } + } + + // if SubMapKey is provided, move fields in a nested map + if len(a.config.SubMapKey) > 0 { + subMap := map[string]any{} + for _, field := range a.config.Fields { + if field.SubMapField { + if len(field.OutputKey) > 0 { + subMap[field.OutputKey] = report[field.OutputKey] + delete(report, field.OutputKey) + } else { + subMap[field.InputKey] = report[field.InputKey] + delete(report, field.InputKey) + } + } + } + report[a.config.SubMapKey] = subMap + } + + // if none of the AggregationFields define deviation, always report + hasNoDeviation := true + for _, field := range a.config.Fields { + if field.DeviationType != DEVIATION_TYPE_NONE { + hasNoDeviation = false + break + } + } + if hasNoDeviation { + lggr.Debugw("no deviation defined, reporting") + shouldReport = true + } + + stateValuesMap, err := values.WrapMap(currentState) + if err != nil { + return nil, fmt.Errorf("aggregate state wrapmap error: %s", err.Error()) + } + stateBytes, err := proto.Marshal(values.ProtoMap(stateValuesMap)) + if err != nil { + return nil, fmt.Errorf("aggregate state proto marshal error: %s", err.Error()) + } + + toWrap, 
err := formatReport(report, a.config.ReportFormat)
+	if err != nil {
+		return nil, fmt.Errorf("aggregate formatReport error: %s", err.Error())
+	}
+	reportValuesMap, err := values.NewMap(map[string]any{
+		a.config.OutputFieldName: toWrap,
+	})
+	if err != nil {
+		return nil, fmt.Errorf("aggregate new map error: %s", err.Error())
+	}
+	reportProtoMap := values.Proto(reportValuesMap).GetMapValue()
+
+	lggr.Debugw("Aggregation complete", "shouldReport", shouldReport)
+
+	return &types.AggregationOutcome{
+		EncodableOutcome: reportProtoMap,
+		Metadata:         stateBytes,
+		ShouldReport:     shouldReport,
+	}, nil
+}
+
+func (a *reduceAggregator) initializeCurrentState(lggr logger.Logger, previousOutcome *types.AggregationOutcome) (*map[string]values.Value, error) {
+	currentState := map[string]values.Value{}
+
+	if previousOutcome != nil {
+		pbMap := &pb.Map{}
+		if err := proto.Unmarshal(previousOutcome.Metadata, pbMap); err != nil {
+			return nil, fmt.Errorf("initializeCurrentState proto unmarshal error: %s", err.Error())
+		}
+		mv, err := values.FromMapValueProto(pbMap)
+		if err != nil {
+			return nil, fmt.Errorf("initializeCurrentState FromMapValueProto error: %s", err.Error())
+		}
+		err = mv.UnwrapTo(currentState)
+		if err != nil {
+			return nil, fmt.Errorf("initializeCurrentState UnwrapTo error: %s", err.Error())
+		}
+	}
+
+	zeroValue := values.NewDecimal(decimal.Zero)
+	for _, field := range a.config.Fields {
+		if _, ok := currentState[field.OutputKey]; !ok {
+			currentState[field.OutputKey] = zeroValue
+			lggr.Debugw("initializing empty onchain state for feed", "fieldOutputKey", field.OutputKey)
+		}
+	}
+
+	lggr.Debugw("current state initialized", "state", currentState, "previousOutcome", previousOutcome)
+	return &currentState, nil
+}
+
+func (a *reduceAggregator) extractValues(lggr logger.Logger, observations map[ocrcommon.OracleID][]values.Value, aggregationKey string) (vals []values.Value) {
+	for nodeID, nodeObservations := range observations {
+		// we only expect a single observation per node
+		if len(nodeObservations) == 0 || nodeObservations[0] == nil {
+			lggr.Warnf("node %d contributed empty observations", nodeID)
+			continue
+		}
+		if len(nodeObservations) > 1 {
+			lggr.Warnf("node %d contributed more than one observation", nodeID)
+			continue
+		}
+
+		val, err := nodeObservations[0].Unwrap()
+		if err != nil {
+			lggr.Warnf("node %d contributed a Value that could not be unwrapped", nodeID)
+			continue
+		}
+
+		// if the observation data is a complex type, extract the value using the inputKey
+		// values are then re-wrapped here to handle aggregating against Value types
+		// which is used for mode aggregation
+		switch val := val.(type) {
+		case map[string]interface{}:
+			_, ok := val[aggregationKey]
+			if !ok {
+				continue
+			}
+
+			rewrapped, err := values.Wrap(val[aggregationKey])
+			if err != nil {
+				lggr.Warnf("unable to wrap value %v", val[aggregationKey])
+				continue
+			}
+			vals = append(vals, rewrapped)
+		case []interface{}:
+			i, err := strconv.Atoi(aggregationKey)
+			if err != nil {
+				lggr.Warnf("aggregation key %s could not be used to index a list type", aggregationKey)
+				continue
+			}
+			if i < 0 || i >= len(val) {
+				lggr.Warnf("aggregation key %s is out of range for the observed list", aggregationKey)
+				continue
+			}
+			rewrapped, err := values.Wrap(val[i])
+			if err != nil {
+				lggr.Warnf("unable to wrap value %v", val[i])
+				continue
+			}
+			vals = append(vals, rewrapped)
+		default:
+			// not a complex type, use raw value
+			if len(aggregationKey) == 0 {
+				vals = append(vals, nodeObservations[0])
+			} else {
+				lggr.Warnf("aggregation key %s provided, but value is not an indexable type", aggregationKey)
+			}
+		}
+	}
+
+	return vals
+}
+
+func reduce(method string, items []values.Value) (values.Value, error) {
+	switch method {
+	case AGGREGATION_METHOD_MEDIAN:
+		return median(items)
+	case AGGREGATION_METHOD_MODE:
+		return mode(items)
+	default:
+		// invariant: config should already have been validated
+		return nil, fmt.Errorf("unsupported aggregation method %s", method)
+	}
+}
+
+func median(items []values.Value) (values.Value, error) {
+	if len(items) == 0 {
+		// invariant: as long as f > 0 there should be items
+		return nil, errors.New("items cannot be empty")
+	}
+	err := sortAsDecimal(items)
+	if err != nil {
+		return nil, err
+	}
+	// items are sorted in descending order; for an even count this picks the upper middle value
+	return items[(len(items)-1)/2], nil
+}
+
+func sortAsDecimal(items []values.Value) error {
+	var err error
+	sort.Slice(items, func(i, j int) bool {
+		decimalI, errI := toDecimal(items[i])
+		if errI != nil {
+			err = errI
+		}
+		decimalJ, errJ := toDecimal(items[j])
+		if errJ != nil {
+			err = errJ
+		}
+		return decimalI.GreaterThan(decimalJ)
+	})
+	return err
+}
+
+func toDecimal(item values.Value) (decimal.Decimal, error) {
+	unwrapped, err := item.Unwrap()
+	if err != nil {
+		return decimal.NewFromInt(0), err
+	}
+
+	switch v := unwrapped.(type) {
+	case string:
+		deci, err := decimal.NewFromString(v)
+		if err != nil {
+			return decimal.NewFromInt(0), err
+		}
+		return deci, nil
+	case decimal.Decimal:
+		return v, nil
+	case int64:
+		return decimal.NewFromInt(v), nil
+	case *big.Int:
+		// exponent 0: treat the big.Int as a plain integer value
+		return decimal.NewFromBigInt(v, 0), nil
+	case time.Time:
+		return decimal.NewFromInt(v.Unix()), nil
+	case float64:
+		return decimal.NewFromFloat(v), nil
+	default:
+		// unsupported type
+		return decimal.NewFromInt(0), fmt.Errorf("unable to convert type %T to decimal", v)
+	}
+}
+
+func mode(items []values.Value) (values.Value, error) {
+	if len(items) == 0 {
+		// invariant: as long as f > 0 there should be items
+		return nil, errors.New("items cannot be empty")
+	}
+
+	counts := make(map[[32]byte]*counter)
+	for _, item := range items {
+		marshalled, err := proto.MarshalOptions{Deterministic: true}.Marshal(values.Proto(item))
+		if err != nil {
+			// invariant: values should always be able to be proto marshalled
+			return nil, err
+		}
+		sha := sha256.Sum256(marshalled)
+		elem, ok := counts[sha]
+		if !ok {
+			counts[sha] = &counter{
+				fullObservation: item,
+				count:           1,
+			}
+		} else {
+			elem.count++
+		}
+	}
+
+	var maxCount int
+	for _, ctr := range counts {
+		if ctr.count > maxCount {
+			maxCount = ctr.count
+		}
+	}
+
+	var modes []values.Value
+	for _, ctr := range counts {
+		if ctr.count == maxCount {
+			modes = append(modes, ctr.fullObservation)
+		}
+	}
+
+	// if more than one mode is found, the first one collected is chosen
+	return modes[0], nil
+}
+
+func deviation(method string, previousValue values.Value, nextValue values.Value) (decimal.Decimal, error) {
+	prevDeci, err := toDecimal(previousValue)
+	if err != nil {
+		return decimal.NewFromInt(0), err
+	}
+	nextDeci, err := toDecimal(nextValue)
+	if err != nil {
+		return decimal.NewFromInt(0), err
+	}
+
+	diff := prevDeci.Sub(nextDeci).Abs()
+
+	switch method {
+	case DEVIATION_TYPE_ABSOLUTE:
+		return diff, nil
+	case DEVIATION_TYPE_PERCENT:
+		if prevDeci.Cmp(decimal.NewFromInt(0)) == 0 {
+			if diff.Cmp(decimal.NewFromInt(0)) == 0 {
+				return decimal.NewFromInt(0), nil
+			}
+			return decimal.NewFromInt(math.MaxInt), nil
+		}
+		return diff.Div(prevDeci), nil
+	default:
+		return decimal.NewFromInt(0), fmt.Errorf("unsupported deviation method %s", method)
+	}
+}
+
+func formatReport(report map[string]any, format string) (any, error) {
+	
switch format { + case REPORT_FORMAT_ARRAY: + return []map[string]any{report}, nil + case REPORT_FORMAT_MAP: + return report, nil + case REPORT_FORMAT_VALUE: + for _, value := range report { + return value, nil + } + // invariant: validation enforces only one output value + return nil, errors.New("value format must contain at least one output") + default: + return nil, errors.New("unsupported report format") + } +} + +func isOneOf(toCheck string, options []string) bool { + for _, option := range options { + if toCheck == option { + return true + } + } + return false +} + +func NewReduceAggregator(config values.Map) (types.Aggregator, error) { + parsedConfig, err := ParseConfigReduceAggregator(config) + if err != nil { + return nil, fmt.Errorf("failed to parse config (%+v): %w", config, err) + } + return &reduceAggregator{ + config: parsedConfig, + }, nil +} + +func ParseConfigReduceAggregator(config values.Map) (ReduceAggConfig, error) { + parsedConfig := ReduceAggConfig{} + if err := config.UnwrapTo(&parsedConfig); err != nil { + return ReduceAggConfig{}, err + } + + // validations & fill defaults + if len(parsedConfig.Fields) == 0 { + return ReduceAggConfig{}, errors.New("reduce aggregator must contain config for Fields to aggregate") + } + if len(parsedConfig.OutputFieldName) == 0 { + parsedConfig.OutputFieldName = DEFAULT_OUTPUT_FIELD_NAME + } + if len(parsedConfig.ReportFormat) == 0 { + parsedConfig.ReportFormat = DEFAULT_REPORT_FORMAT + } + if len(parsedConfig.Fields) > 1 && parsedConfig.ReportFormat == REPORT_FORMAT_VALUE { + return ReduceAggConfig{}, errors.New("report type of value can only have one field") + } + hasSubMapField := false + outputKeyCount := make(map[any]bool) + for i, field := range parsedConfig.Fields { + if (parsedConfig.ReportFormat == REPORT_FORMAT_ARRAY || parsedConfig.ReportFormat == REPORT_FORMAT_MAP) && len(field.OutputKey) == 0 { + return ReduceAggConfig{}, fmt.Errorf("report type %s or %s must have an OutputKey to put the result under", REPORT_FORMAT_ARRAY, REPORT_FORMAT_MAP) + } + if len(field.DeviationType) == 0 { + field.DeviationType = DEVIATION_TYPE_NONE + parsedConfig.Fields[i].DeviationType = DEVIATION_TYPE_NONE + } + if !isOneOf(field.DeviationType, []string{DEVIATION_TYPE_ABSOLUTE, DEVIATION_TYPE_PERCENT, DEVIATION_TYPE_NONE}) { + return ReduceAggConfig{}, fmt.Errorf("invalid config DeviationType. received: %s. options: [%s, %s, %s]", field.DeviationType, DEVIATION_TYPE_ABSOLUTE, DEVIATION_TYPE_PERCENT, DEVIATION_TYPE_NONE) + } + if field.DeviationType != DEVIATION_TYPE_NONE && len(field.DeviationString) == 0 { + return ReduceAggConfig{}, errors.New("aggregation field deviation must contain DeviationString amount") + } + if field.DeviationType != DEVIATION_TYPE_NONE && len(field.DeviationString) > 0 { + deci, err := decimal.NewFromString(field.DeviationString) + if err != nil { + return ReduceAggConfig{}, fmt.Errorf("reduce aggregator could not parse deviation decimal from string %s", field.DeviationString) + } + parsedConfig.Fields[i].Deviation = deci + } + if len(field.Method) == 0 || !isOneOf(field.Method, []string{AGGREGATION_METHOD_MEDIAN, AGGREGATION_METHOD_MODE}) { + return ReduceAggConfig{}, fmt.Errorf("aggregation field must contain a method. 
options: [%s, %s]", AGGREGATION_METHOD_MEDIAN, AGGREGATION_METHOD_MODE) + } + if len(field.DeviationString) > 0 && field.DeviationType == DEVIATION_TYPE_NONE { + return ReduceAggConfig{}, fmt.Errorf("aggregation field cannot have deviation with a deviation type of %s", DEVIATION_TYPE_NONE) + } + if field.SubMapField { + hasSubMapField = true + } + if outputKeyCount[field.OutputKey] { + return ReduceAggConfig{}, errors.New("multiple fields have the same outputkey, which will overwrite each other") + } + outputKeyCount[field.OutputKey] = true + } + if len(parsedConfig.SubMapKey) > 0 && !hasSubMapField { + return ReduceAggConfig{}, fmt.Errorf("sub Map key %s given, but no fields are marked as sub map fields", parsedConfig.SubMapKey) + } + if hasSubMapField && len(parsedConfig.SubMapKey) == 0 { + return ReduceAggConfig{}, errors.New("fields are marked as sub Map fields, but no sub map key given") + } + if !isOneOf(parsedConfig.ReportFormat, []string{REPORT_FORMAT_ARRAY, REPORT_FORMAT_MAP, REPORT_FORMAT_VALUE}) { + return ReduceAggConfig{}, fmt.Errorf("invalid config ReportFormat. received: %s. options: %s, %s, %s", parsedConfig.ReportFormat, REPORT_FORMAT_ARRAY, REPORT_FORMAT_MAP, REPORT_FORMAT_VALUE) + } + + return parsedConfig, nil +} diff --git a/pkg/capabilities/consensus/ocr3/aggregators/reduce_test.go b/pkg/capabilities/consensus/ocr3/aggregators/reduce_test.go new file mode 100644 index 000000000..66467dd62 --- /dev/null +++ b/pkg/capabilities/consensus/ocr3/aggregators/reduce_test.go @@ -0,0 +1,1131 @@ +package aggregators_test + +import ( + "math/big" + "testing" + "time" + + "github.com/shopspring/decimal" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + + "github.com/smartcontractkit/libocr/commontypes" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/aggregators" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/types" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/datastreams" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/values" + "github.com/smartcontractkit/chainlink-common/pkg/values/pb" +) + +var ( + feedIDA = datastreams.FeedID("0x0001013ebd4ed3f5889fb5a8a52b42675c60c1a8c42bc79eaa72dcd922ac4292") + idABytes = feedIDA.Bytes() + feedIDB = datastreams.FeedID("0x0003c317fec7fad514c67aacc6366bf2f007ce37100e3cddcacd0ccaa1f3746d") + idBBytes = feedIDB.Bytes() + now = time.Now() +) + +func TestReduceAggregator_Aggregate(t *testing.T) { + t.Run("happy path", func(t *testing.T) { + cases := []struct { + name string + fields []aggregators.AggregationField + extraConfig map[string]any + observationsFactory func() map[commontypes.OracleID][]values.Value + shouldReport bool + expectedState any + expectedOutcome map[string]any + }{ + { + name: "aggregate on int64 median", + fields: []aggregators.AggregationField{ + { + InputKey: "FeedID", + OutputKey: "FeedID", + Method: "mode", + }, + { + InputKey: "BenchmarkPrice", + OutputKey: "Price", + Method: "median", + DeviationString: "10", + DeviationType: "percent", + }, + { + InputKey: "Timestamp", + OutputKey: "Timestamp", + Method: "median", + DeviationString: "100", + DeviationType: "absolute", + }, + }, + extraConfig: map[string]any{}, + observationsFactory: func() map[commontypes.OracleID][]values.Value { + mockValue, err := values.WrapMap(map[string]any{ + "FeedID": idABytes[:], + "BenchmarkPrice": int64(100), + "Timestamp": 12341414929, + }) + require.NoError(t, 
err) + return map[commontypes.OracleID][]values.Value{1: {mockValue}, 2: {mockValue}, 3: {mockValue}} + }, + shouldReport: true, + expectedOutcome: map[string]any{ + "Reports": []any{ + map[string]any{ + "FeedID": idABytes[:], + "Timestamp": int64(12341414929), + "Price": int64(100), + }, + }, + }, + expectedState: map[string]any{ + "FeedID": idABytes[:], + "Timestamp": int64(12341414929), + "Price": int64(100), + }, + }, + { + name: "aggregate on decimal median", + fields: []aggregators.AggregationField{ + { + InputKey: "BenchmarkPrice", + OutputKey: "Price", + Method: "median", + DeviationString: "10", + DeviationType: "percent", + }, + }, + extraConfig: map[string]any{}, + observationsFactory: func() map[commontypes.OracleID][]values.Value { + mockValue, err := values.WrapMap(map[string]any{ + "BenchmarkPrice": decimal.NewFromInt(32), + }) + require.NoError(t, err) + return map[commontypes.OracleID][]values.Value{1: {mockValue}, 2: {mockValue}, 3: {mockValue}} + }, + shouldReport: true, + expectedOutcome: map[string]any{ + "Reports": []any{ + map[string]any{ + "Price": decimal.NewFromInt(32), + }, + }, + }, + expectedState: map[string]any{ + "Price": decimal.NewFromInt(32), + }, + }, + { + name: "aggregate on float64 median", + fields: []aggregators.AggregationField{ + { + InputKey: "BenchmarkPrice", + OutputKey: "Price", + Method: "median", + DeviationString: "10", + DeviationType: "percent", + }, + }, + extraConfig: map[string]any{}, + observationsFactory: func() map[commontypes.OracleID][]values.Value { + mockValue, err := values.WrapMap(map[string]any{ + "BenchmarkPrice": float64(1.2), + }) + require.NoError(t, err) + return map[commontypes.OracleID][]values.Value{1: {mockValue}, 2: {mockValue}, 3: {mockValue}} + }, + shouldReport: true, + expectedOutcome: map[string]any{ + "Reports": []any{ + map[string]any{ + "Price": float64(1.2), + }, + }, + }, + expectedState: map[string]any{ + "Price": float64(1.2), + }, + }, + { + name: "aggregate on time median", + fields: []aggregators.AggregationField{ + { + InputKey: "BenchmarkPrice", + OutputKey: "Price", + Method: "median", + DeviationString: "10", + DeviationType: "percent", + }, + }, + extraConfig: map[string]any{}, + observationsFactory: func() map[commontypes.OracleID][]values.Value { + mockValue, err := values.WrapMap(map[string]any{ + "BenchmarkPrice": now, + }) + require.NoError(t, err) + return map[commontypes.OracleID][]values.Value{1: {mockValue}, 2: {mockValue}, 3: {mockValue}} + }, + shouldReport: true, + expectedOutcome: map[string]any{ + "Reports": []any{ + map[string]any{ + "Price": time.Time(now).UTC(), + }, + }, + }, + expectedState: map[string]any{ + "Price": now.UTC(), + }, + }, + { + name: "aggregate on big int median", + fields: []aggregators.AggregationField{ + { + InputKey: "BenchmarkPrice", + OutputKey: "Price", + Method: "median", + DeviationString: "10", + DeviationType: "percent", + }, + }, + extraConfig: map[string]any{}, + observationsFactory: func() map[commontypes.OracleID][]values.Value { + mockValue, err := values.WrapMap(map[string]any{ + "BenchmarkPrice": big.NewInt(100), + }) + require.NoError(t, err) + return map[commontypes.OracleID][]values.Value{1: {mockValue}, 2: {mockValue}, 3: {mockValue}} + }, + shouldReport: true, + expectedOutcome: map[string]any{ + "Reports": []any{ + map[string]any{ + "Price": big.NewInt(100), + }, + }, + }, + expectedState: map[string]any{ + "Price": big.NewInt(100), + }, + }, + { + name: "aggregate on bytes mode", + fields: []aggregators.AggregationField{ + { + InputKey: 
"FeedID", + OutputKey: "FeedID", + Method: "mode", + }, + }, + extraConfig: map[string]any{}, + observationsFactory: func() map[commontypes.OracleID][]values.Value { + mockValue1, err := values.WrapMap(map[string]any{ + "FeedID": idABytes[:], + }) + require.NoError(t, err) + mockValue2, err := values.WrapMap(map[string]any{ + "FeedID": idBBytes[:], + }) + require.NoError(t, err) + return map[commontypes.OracleID][]values.Value{1: {mockValue1}, 2: {mockValue1}, 3: {mockValue2}, 4: {mockValue1}} + }, + shouldReport: true, + expectedOutcome: map[string]any{ + "Reports": []any{ + map[string]any{ + "FeedID": idABytes[:], + }, + }, + }, + expectedState: map[string]any{ + "FeedID": idABytes[:], + }, + }, + { + name: "aggregate on string mode", + fields: []aggregators.AggregationField{ + { + InputKey: "BenchmarkPrice", + OutputKey: "Price", + Method: "mode", + }, + }, + extraConfig: map[string]any{}, + observationsFactory: func() map[commontypes.OracleID][]values.Value { + mockValue1, err := values.WrapMap(map[string]any{ + "BenchmarkPrice": "1", + }) + require.NoError(t, err) + mockValue2, err := values.WrapMap(map[string]any{ + "BenchmarkPrice": "2", + }) + require.NoError(t, err) + return map[commontypes.OracleID][]values.Value{1: {mockValue1}, 2: {mockValue1}, 3: {mockValue2}} + }, + shouldReport: true, + expectedOutcome: map[string]any{ + "Reports": []any{ + map[string]any{ + "Price": "1", + }, + }, + }, + expectedState: map[string]any{ + "Price": "1", + }, + }, + { + name: "aggregate on bool mode", + fields: []aggregators.AggregationField{ + { + InputKey: "BenchmarkPrice", + OutputKey: "Price", + Method: "mode", + }, + }, + extraConfig: map[string]any{}, + observationsFactory: func() map[commontypes.OracleID][]values.Value { + mockValue1, err := values.WrapMap(map[string]any{ + "BenchmarkPrice": true, + }) + require.NoError(t, err) + mockValue2, err := values.WrapMap(map[string]any{ + "BenchmarkPrice": false, + }) + require.NoError(t, err) + return map[commontypes.OracleID][]values.Value{1: {mockValue1}, 2: {mockValue1}, 3: {mockValue2}} + }, + shouldReport: true, + expectedOutcome: map[string]any{ + "Reports": []any{ + map[string]any{ + "Price": true, + }, + }, + }, + expectedState: map[string]any{ + "Price": true, + }, + }, + { + name: "aggregate on non-indexable type", + fields: []aggregators.AggregationField{ + { + // Omitting "InputKey" + OutputKey: "Price", + Method: "median", + }, + }, + extraConfig: map[string]any{}, + observationsFactory: func() map[commontypes.OracleID][]values.Value { + mockValue, err := values.Wrap(1) + require.NoError(t, err) + return map[commontypes.OracleID][]values.Value{1: {mockValue}, 2: {mockValue}, 3: {mockValue}} + }, + shouldReport: true, + expectedOutcome: map[string]any{ + "Reports": []any{ + map[string]any{ + "Price": int64(1), + }, + }, + }, + expectedState: map[string]any{"Price": int64(1)}, + }, + { + name: "aggregate on list type", + fields: []aggregators.AggregationField{ + { + InputKey: "1", + OutputKey: "Price", + Method: "median", + }, + }, + extraConfig: map[string]any{}, + observationsFactory: func() map[commontypes.OracleID][]values.Value { + mockValue, err := values.NewList([]any{"1", "2", "3"}) + require.NoError(t, err) + return map[commontypes.OracleID][]values.Value{1: {mockValue}, 2: {mockValue}, 3: {mockValue}} + }, + shouldReport: true, + expectedOutcome: map[string]any{ + "Reports": []any{ + map[string]any{ + "Price": "2", + }, + }, + }, + expectedState: map[string]any{ + "Price": "2", + }, + }, + { + name: "submap", + fields: 
[]aggregators.AggregationField{ + { + InputKey: "FeedID", + OutputKey: "FeedID", + Method: "mode", + }, + { + InputKey: "BenchmarkPrice", + OutputKey: "Price", + Method: "median", + DeviationString: "10", + DeviationType: "percent", + SubMapField: true, + }, + { + InputKey: "Timestamp", + OutputKey: "Timestamp", + Method: "median", + DeviationString: "100", + DeviationType: "absolute", + }, + }, + extraConfig: map[string]any{ + "SubMapKey": "Report", + }, + observationsFactory: func() map[commontypes.OracleID][]values.Value { + mockValue, err := values.WrapMap(map[string]any{ + "FeedID": idABytes[:], + "BenchmarkPrice": int64(100), + "Timestamp": 12341414929, + }) + require.NoError(t, err) + return map[commontypes.OracleID][]values.Value{1: {mockValue}, 2: {mockValue}, 3: {mockValue}} + }, + shouldReport: true, + expectedOutcome: map[string]any{ + "Reports": []any{ + map[string]any{ + "FeedID": idABytes[:], + "Timestamp": int64(12341414929), + "Report": map[string]any{ + "Price": int64(100), + }, + }, + }, + }, + expectedState: map[string]any{ + "FeedID": idABytes[:], + "Price": int64(100), + "Timestamp": int64(12341414929), + }, + }, + { + name: "report format value", + fields: []aggregators.AggregationField{ + { + OutputKey: "Price", + Method: "median", + }, + }, + extraConfig: map[string]any{ + "reportFormat": "value", + }, + observationsFactory: func() map[commontypes.OracleID][]values.Value { + mockValue, err := values.Wrap(1) + require.NoError(t, err) + return map[commontypes.OracleID][]values.Value{1: {mockValue}, 2: {mockValue}, 3: {mockValue}} + }, + shouldReport: true, + expectedOutcome: map[string]any{ + "Reports": int64(1), + }, + expectedState: map[string]any{"Price": int64(1)}, + }, + { + name: "report format array", + fields: []aggregators.AggregationField{ + { + OutputKey: "Price", + Method: "median", + }, + }, + extraConfig: map[string]any{ + "reportFormat": "array", + }, + observationsFactory: func() map[commontypes.OracleID][]values.Value { + mockValue, err := values.Wrap(1) + require.NoError(t, err) + return map[commontypes.OracleID][]values.Value{1: {mockValue}, 2: {mockValue}, 3: {mockValue}} + }, + shouldReport: true, + expectedOutcome: map[string]any{ + "Reports": []any{map[string]any{"Price": int64(1)}}, + }, + expectedState: map[string]any{"Price": int64(1)}, + }, + } + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + config := getConfigReduceAggregator(t, tt.fields, tt.extraConfig) + agg, err := aggregators.NewReduceAggregator(*config) + require.NoError(t, err) + + pb := &pb.Map{} + outcome, err := agg.Aggregate(logger.Nop(), nil, tt.observationsFactory(), 1) + require.NoError(t, err) + require.Equal(t, tt.shouldReport, outcome.ShouldReport) + + // validate metadata + proto.Unmarshal(outcome.Metadata, pb) + vmap, err := values.FromMapValueProto(pb) + require.NoError(t, err) + state, err := vmap.Unwrap() + require.NoError(t, err) + require.Equal(t, tt.expectedState, state) + + // validate encodable outcome + val, err := values.FromMapValueProto(outcome.EncodableOutcome) + require.NoError(t, err) + topLevelMap, err := val.Unwrap() + require.NoError(t, err) + mm, ok := topLevelMap.(map[string]any) + require.True(t, ok) + + require.NoError(t, err) + + require.Equal(t, tt.expectedOutcome, mm) + }) + } + }) + + t.Run("error path", func(t *testing.T) { + cases := []struct { + name string + previousOutcome *types.AggregationOutcome + fields []aggregators.AggregationField + extraConfig map[string]any + observationsFactory func() 
map[commontypes.OracleID][]values.Value + }{ + { + name: "not enough observations", + previousOutcome: nil, + fields: []aggregators.AggregationField{ + { + Method: "median", + OutputKey: "Price", + }, + }, + extraConfig: map[string]any{}, + observationsFactory: func() map[commontypes.OracleID][]values.Value { + return map[commontypes.OracleID][]values.Value{} + }, + }, + { + name: "empty previous outcome", + previousOutcome: &types.AggregationOutcome{}, + fields: []aggregators.AggregationField{ + { + Method: "median", + OutputKey: "Price", + }, + }, + extraConfig: map[string]any{}, + observationsFactory: func() map[commontypes.OracleID][]values.Value { + mockValue, err := values.Wrap(int64(100)) + require.NoError(t, err) + return map[commontypes.OracleID][]values.Value{1: {mockValue}, 2: {mockValue}, 3: {mockValue}} + }, + }, + { + name: "invalid previous outcome not pb", + previousOutcome: &types.AggregationOutcome{ + Metadata: []byte{1, 2, 3}, + }, + fields: []aggregators.AggregationField{ + { + Method: "median", + OutputKey: "Price", + }, + }, + extraConfig: map[string]any{}, + observationsFactory: func() map[commontypes.OracleID][]values.Value { + mockValue, err := values.Wrap(int64(100)) + require.NoError(t, err) + return map[commontypes.OracleID][]values.Value{1: {mockValue}, 2: {mockValue}, 3: {mockValue}} + }, + }, + { + name: "not enough extracted values", + previousOutcome: nil, + fields: []aggregators.AggregationField{ + { + InputKey: "Price", + OutputKey: "Price", + Method: "median", + }, + }, + extraConfig: map[string]any{}, + observationsFactory: func() map[commontypes.OracleID][]values.Value { + mockValue, err := values.WrapMap(map[string]any{"Price": int64(100)}) + require.NoError(t, err) + mockValueEmpty := values.EmptyMap() + return map[commontypes.OracleID][]values.Value{1: {mockValue}, 2: {mockValue}, 3: {mockValueEmpty}} + }, + }, + { + name: "reduce error median", + previousOutcome: nil, + fields: []aggregators.AggregationField{ + { + Method: "median", + OutputKey: "Price", + }, + }, + extraConfig: map[string]any{}, + observationsFactory: func() map[commontypes.OracleID][]values.Value { + mockValue, err := values.Wrap(true) + require.NoError(t, err) + return map[commontypes.OracleID][]values.Value{1: {mockValue}, 2: {mockValue}, 3: {mockValue}} + }, + }, + } + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + config := getConfigReduceAggregator(t, tt.fields, tt.extraConfig) + agg, err := aggregators.NewReduceAggregator(*config) + require.NoError(t, err) + + _, err = agg.Aggregate(logger.Nop(), tt.previousOutcome, tt.observationsFactory(), 1) + require.Error(t, err) + }) + } + }) +} + +func TestInputChanges(t *testing.T) { + fields := []aggregators.AggregationField{ + { + InputKey: "FeedID", + OutputKey: "FeedID", + Method: "mode", + }, + { + InputKey: "BenchmarkPrice", + OutputKey: "Price", + Method: "median", + DeviationString: "10", + DeviationType: "percent", + }, + { + InputKey: "Timestamp", + OutputKey: "Timestamp", + Method: "median", + DeviationString: "100", + DeviationType: "absolute", + }, + } + config := getConfigReduceAggregator(t, fields, map[string]any{}) + agg, err := aggregators.NewReduceAggregator(*config) + require.NoError(t, err) + + // First Round + mockValue1, err := values.WrapMap(map[string]any{ + "FeedID": idABytes[:], + "BenchmarkPrice": int64(100), + "Timestamp": 12341414929, + }) + require.NoError(t, err) + pb := &pb.Map{} + outcome, err := agg.Aggregate(logger.Nop(), nil, map[commontypes.OracleID][]values.Value{1: 
{mockValue1}, 2: {mockValue1}, 3: {mockValue1}}, 1) + require.NoError(t, err) + shouldReport := true + require.Equal(t, shouldReport, outcome.ShouldReport) + + // validate metadata + proto.Unmarshal(outcome.Metadata, pb) + vmap, err := values.FromMapValueProto(pb) + require.NoError(t, err) + state, err := vmap.Unwrap() + require.NoError(t, err) + expectedState1 := map[string]any{ + "FeedID": idABytes[:], + "Price": int64(100), + "Timestamp": int64(12341414929), + } + require.Equal(t, expectedState1, state) + + // validate encodable outcome + val, err := values.FromMapValueProto(outcome.EncodableOutcome) + require.NoError(t, err) + topLevelMap, err := val.Unwrap() + require.NoError(t, err) + mm, ok := topLevelMap.(map[string]any) + require.True(t, ok) + + require.NoError(t, err) + expectedOutcome1 := map[string]any{ + "Reports": []any{ + map[string]any{ + "FeedID": idABytes[:], + "Timestamp": int64(12341414929), + "Price": int64(100), + }, + }, + } + require.Equal(t, expectedOutcome1, mm) + + // Second Round + mockValue2, err := values.WrapMap(map[string]any{ + "FeedID": true, + "Timestamp": int64(12341414929), + "BenchmarkPrice": int64(100), + }) + require.NoError(t, err) + outcome, err = agg.Aggregate(logger.Nop(), nil, map[commontypes.OracleID][]values.Value{1: {mockValue2}, 2: {mockValue2}, 3: {mockValue2}}, 1) + require.NoError(t, err) + require.Equal(t, shouldReport, outcome.ShouldReport) + + // validate metadata + proto.Unmarshal(outcome.Metadata, pb) + vmap, err = values.FromMapValueProto(pb) + require.NoError(t, err) + state, err = vmap.Unwrap() + require.NoError(t, err) + expectedState2 := map[string]any{ + "FeedID": true, + "Price": int64(100), + "Timestamp": int64(12341414929), + } + require.Equal(t, expectedState2, state) + + // validate encodable outcome + val, err = values.FromMapValueProto(outcome.EncodableOutcome) + require.NoError(t, err) + topLevelMap, err = val.Unwrap() + require.NoError(t, err) + mm, ok = topLevelMap.(map[string]any) + require.True(t, ok) + + require.NoError(t, err) + expectedOutcome2 := map[string]any{ + "Reports": []any{ + map[string]any{ + "FeedID": true, + "Timestamp": int64(12341414929), + "Price": int64(100), + }, + }, + } + + require.Equal(t, expectedOutcome2, mm) + +} + +func TestMedianAggregator_ParseConfig(t *testing.T) { + t.Run("happy path", func(t *testing.T) { + cases := []struct { + name string + inputFactory func() map[string]any + outputFactory func() aggregators.ReduceAggConfig + }{ + { + name: "no inputkey", + inputFactory: func() map[string]any { + return map[string]any{ + "fields": []aggregators.AggregationField{ + { + Method: "median", + OutputKey: "Price", + }, + }, + } + }, + outputFactory: func() aggregators.ReduceAggConfig { + return aggregators.ReduceAggConfig{ + Fields: []aggregators.AggregationField{ + { + InputKey: "", + OutputKey: "Price", + Method: "median", + DeviationString: "", + Deviation: decimal.Decimal{}, + DeviationType: "none", + }, + }, + OutputFieldName: "Reports", + ReportFormat: "map", + } + }, + }, + { + name: "reportFormat map, aggregation method mode, deviation", + inputFactory: func() map[string]any { + return map[string]any{ + "fields": []aggregators.AggregationField{ + { + InputKey: "FeedID", + OutputKey: "FeedId", + Method: "mode", + DeviationString: "1.1", + DeviationType: "absolute", + }, + }, + } + }, + outputFactory: func() aggregators.ReduceAggConfig { + return aggregators.ReduceAggConfig{ + Fields: []aggregators.AggregationField{ + { + InputKey: "FeedID", + OutputKey: "FeedId", + Method: "mode", 
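+						// the parser materializes DeviationString into the decimal Deviation below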
+ DeviationString: "1.1", + Deviation: decimal.NewFromFloat(1.1), + DeviationType: "absolute", + }, + }, + OutputFieldName: "Reports", + ReportFormat: "map", + } + }, + }, + { + name: "reportFormat array, aggregation method median, no deviation", + inputFactory: func() map[string]any { + return map[string]any{ + "fields": []aggregators.AggregationField{ + { + InputKey: "FeedID", + OutputKey: "FeedId", + Method: "median", + }, + }, + "outputFieldName": "Reports", + "reportFormat": "array", + } + }, + outputFactory: func() aggregators.ReduceAggConfig { + return aggregators.ReduceAggConfig{ + Fields: []aggregators.AggregationField{ + { + InputKey: "FeedID", + OutputKey: "FeedId", + Method: "median", + DeviationString: "", + Deviation: decimal.Decimal{}, + DeviationType: "none", + }, + }, + OutputFieldName: "Reports", + ReportFormat: "array", + } + }, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + vMap, err := values.NewMap(tt.inputFactory()) + require.NoError(t, err) + parsedConfig, err := aggregators.ParseConfigReduceAggregator(*vMap) + require.NoError(t, err) + require.Equal(t, tt.outputFactory(), parsedConfig) + }) + } + }) + + t.Run("unhappy path", func(t *testing.T) { + cases := []struct { + name string + configFactory func() *values.Map + }{ + { + name: "empty", + configFactory: func() *values.Map { + return values.EmptyMap() + }, + }, + { + name: "invalid report format", + configFactory: func() *values.Map { + vMap, err := values.NewMap(map[string]any{ + "fields": []aggregators.AggregationField{ + { + InputKey: "FeedID", + OutputKey: "FeedID", + Method: "median", + }, + }, + "reportFormat": "invalid", + }) + require.NoError(t, err) + return vMap + }, + }, + { + name: "field with no method", + configFactory: func() *values.Map { + vMap, err := values.NewMap(map[string]any{ + "fields": []aggregators.AggregationField{ + { + InputKey: "FeedID", + OutputKey: "FeedID", + }, + }, + }) + require.NoError(t, err) + return vMap + }, + }, + { + name: "field with empty method", + configFactory: func() *values.Map { + vMap, err := values.NewMap(map[string]any{ + "fields": []aggregators.AggregationField{ + { + InputKey: "FeedID", + OutputKey: "FeedID", + Method: "", + }, + }, + }) + require.NoError(t, err) + return vMap + }, + }, + { + name: "field with invalid method", + configFactory: func() *values.Map { + vMap, err := values.NewMap(map[string]any{ + "fields": []aggregators.AggregationField{ + { + InputKey: "FeedID", + OutputKey: "FeedID", + Method: "invalid", + }, + }, + }) + require.NoError(t, err) + return vMap + }, + }, + { + name: "field with deviation string but no deviation type", + configFactory: func() *values.Map { + vMap, err := values.NewMap(map[string]any{ + "fields": []aggregators.AggregationField{ + { + InputKey: "FeedID", + OutputKey: "FeedID", + Method: "median", + DeviationString: "1", + }, + }, + }) + require.NoError(t, err) + return vMap + }, + }, + { + name: "field with deviation string but empty deviation type", + configFactory: func() *values.Map { + vMap, err := values.NewMap(map[string]any{ + "fields": []aggregators.AggregationField{ + { + InputKey: "FeedID", + OutputKey: "FeedID", + Method: "median", + DeviationString: "1", + DeviationType: "", + }, + }, + }) + require.NoError(t, err) + return vMap + }, + }, + { + name: "field with invalid deviation type", + configFactory: func() *values.Map { + vMap, err := values.NewMap(map[string]any{ + "fields": []aggregators.AggregationField{ + { + InputKey: "FeedID", + OutputKey: "FeedID", + Method: 
"median", + DeviationString: "1", + DeviationType: "invalid", + }, + }, + }) + require.NoError(t, err) + return vMap + }, + }, + { + name: "field with deviation type but no deviation string", + configFactory: func() *values.Map { + vMap, err := values.NewMap(map[string]any{ + "fields": []aggregators.AggregationField{ + { + InputKey: "FeedID", + OutputKey: "FeedID", + Method: "median", + DeviationType: "absolute", + }, + }, + }) + require.NoError(t, err) + return vMap + }, + }, + { + name: "field with deviation type but empty deviation string", + configFactory: func() *values.Map { + vMap, err := values.NewMap(map[string]any{ + "fields": []aggregators.AggregationField{ + { + InputKey: "FeedID", + OutputKey: "FeedID", + Method: "median", + DeviationType: "absolute", + DeviationString: "", + }, + }, + }) + require.NoError(t, err) + return vMap + }, + }, + { + name: "field with invalid deviation string", + configFactory: func() *values.Map { + vMap, err := values.NewMap(map[string]any{ + "fields": []aggregators.AggregationField{ + { + InputKey: "FeedID", + OutputKey: "FeedID", + Method: "median", + DeviationType: "absolute", + DeviationString: "1-1", + }, + }, + }) + require.NoError(t, err) + return vMap + }, + }, + { + name: "field with sub report, but no sub report key", + configFactory: func() *values.Map { + vMap, err := values.NewMap(map[string]any{ + "fields": []aggregators.AggregationField{ + { + InputKey: "FeedID", + OutputKey: "FeedID", + Method: "median", + SubMapField: true, + }, + }, + }) + require.NoError(t, err) + return vMap + }, + }, + { + name: "sub report key, but no sub report fields", + configFactory: func() *values.Map { + vMap, err := values.NewMap(map[string]any{ + "subMapKey": "Report", + "fields": []aggregators.AggregationField{ + { + InputKey: "FeedID", + OutputKey: "FeedID", + Method: "median", + }, + }, + }) + require.NoError(t, err) + return vMap + }, + }, + { + name: "clashing output keys", + configFactory: func() *values.Map { + vMap, err := values.NewMap(map[string]any{ + "fields": []aggregators.AggregationField{ + { + InputKey: "FeedID", + OutputKey: "FeedID", + Method: "median", + }, + { + InputKey: "FeedID", + OutputKey: "FeedID", + Method: "median", + }, + }, + }) + require.NoError(t, err) + return vMap + }, + }, + { + name: "map/array type, no output key", + configFactory: func() *values.Map { + vMap, err := values.NewMap(map[string]any{ + "fields": []aggregators.AggregationField{ + { + InputKey: "FeedID", + Method: "median", + }, + }, + }) + require.NoError(t, err) + return vMap + }, + }, + { + name: "report type value with multiple fields", + configFactory: func() *values.Map { + vMap, err := values.NewMap(map[string]any{ + "reportFormat": "value", + "fields": []aggregators.AggregationField{ + { + InputKey: "FeedID", + Method: "median", + OutputKey: "FeedID", + }, + { + InputKey: "Price", + Method: "median", + OutputKey: "Price", + }, + }, + }) + require.NoError(t, err) + return vMap + }, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + _, err := aggregators.ParseConfigReduceAggregator(*tt.configFactory()) + require.Error(t, err) + }, + ) + } + }) +} + +func getConfigReduceAggregator(t *testing.T, fields []aggregators.AggregationField, override map[string]any) *values.Map { + unwrappedConfig := map[string]any{ + "fields": fields, + "outputFieldName": "Reports", + "reportFormat": "array", + } + for key, val := range override { + unwrappedConfig[key] = val + } + config, err := values.NewMap(unwrappedConfig) + 
require.NoError(t, err) + return config +} diff --git a/pkg/capabilities/consensus/ocr3/capability_test.go b/pkg/capabilities/consensus/ocr3/capability_test.go index e383a487c..f1a8480ed 100644 --- a/pkg/capabilities/consensus/ocr3/capability_test.go +++ b/pkg/capabilities/consensus/ocr3/capability_test.go @@ -68,62 +68,79 @@ func TestOCR3Capability_Schema(t *testing.T) { } func TestOCR3Capability(t *testing.T) { - n := time.Now() - fc := clockwork.NewFakeClockAt(n) - lggr := logger.Test(t) - - ctx := tests.Context(t) - - s := requests.NewStore() - - cp := newCapability(s, fc, 1*time.Second, mockAggregatorFactory, mockEncoderFactory, lggr, 10) - require.NoError(t, cp.Start(ctx)) - - config, err := values.NewMap( - map[string]any{ - "aggregation_method": "data_feeds", - "aggregation_config": map[string]any{}, - "encoder_config": map[string]any{}, - "encoder": "evm", - "report_id": "ffff", + cases := []struct { + name string + aggregationMethod string + }{ + { + name: "success - aggregation_method data_feeds", + aggregationMethod: "data_feeds", }, - ) - require.NoError(t, err) - - ethUsdValStr := "1.123456" - ethUsdValue, err := decimal.NewFromString(ethUsdValStr) - require.NoError(t, err) - observationKey := "ETH_USD" - obs := []any{map[string]any{observationKey: ethUsdValue}} - inputs, err := values.NewMap(map[string]any{"observations": obs}) - require.NoError(t, err) - - executeReq := capabilities.CapabilityRequest{ - Metadata: capabilities.RequestMetadata{ - WorkflowID: workflowTestID, - WorkflowExecutionID: workflowExecutionTestID, + { + name: "success - aggregation_method reduce", + aggregationMethod: "reduce", }, - Config: config, - Inputs: inputs, } + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + n := time.Now() + fc := clockwork.NewFakeClockAt(n) + lggr := logger.Test(t) + + ctx := tests.Context(t) + + s := requests.NewStore() + + cp := newCapability(s, fc, 1*time.Second, mockAggregatorFactory, mockEncoderFactory, lggr, 10) + require.NoError(t, cp.Start(ctx)) + + config, err := values.NewMap( + map[string]any{ + "aggregation_method": tt.aggregationMethod, + "aggregation_config": map[string]any{}, + "encoder_config": map[string]any{}, + "encoder": "evm", + "report_id": "ffff", + }, + ) + require.NoError(t, err) + + ethUsdValStr := "1.123456" + ethUsdValue, err := decimal.NewFromString(ethUsdValStr) + require.NoError(t, err) + observationKey := "ETH_USD" + obs := []any{map[string]any{observationKey: ethUsdValue}} + inputs, err := values.NewMap(map[string]any{"observations": obs}) + require.NoError(t, err) + + executeReq := capabilities.CapabilityRequest{ + Metadata: capabilities.RequestMetadata{ + WorkflowID: workflowTestID, + WorkflowExecutionID: workflowExecutionTestID, + }, + Config: config, + Inputs: inputs, + } - respCh := executeAsync(ctx, executeReq, cp.Execute) + respCh := executeAsync(ctx, executeReq, cp.Execute) - obsv, err := values.NewList(obs) - require.NoError(t, err) + obsv, err := values.NewList(obs) + require.NoError(t, err) - // Mock the oracle returning a response - mresp, err := values.NewMap(map[string]any{"observations": obsv}) - cp.reqHandler.SendResponse(ctx, requests.Response{ - Value: mresp, - WorkflowExecutionID: workflowExecutionTestID, - }) - require.NoError(t, err) + // Mock the oracle returning a response + mresp, err := values.NewMap(map[string]any{"observations": obsv}) + cp.reqHandler.SendResponse(ctx, requests.Response{ + Value: mresp, + WorkflowExecutionID: workflowExecutionTestID, + }) + require.NoError(t, err) - resp := 
<-respCh - assert.Nil(t, resp.Err) + resp := <-respCh + assert.Nil(t, resp.Err) - assert.Equal(t, mresp, resp.Value) + assert.Equal(t, mresp, resp.Value) + }) + } } func TestOCR3Capability_Eviction(t *testing.T) { diff --git a/pkg/capabilities/consensus/ocr3/models.go b/pkg/capabilities/consensus/ocr3/models.go index 86662dc31..9e7887685 100644 --- a/pkg/capabilities/consensus/ocr3/models.go +++ b/pkg/capabilities/consensus/ocr3/models.go @@ -5,7 +5,7 @@ import ( ) type config struct { - AggregationMethod string `mapstructure:"aggregation_method" json:"aggregation_method" jsonschema:"enum=data_feeds"` + AggregationMethod string `mapstructure:"aggregation_method" json:"aggregation_method" jsonschema:"enum=data_feeds,enum=reduce"` AggregationConfig *values.Map `mapstructure:"aggregation_config" json:"aggregation_config"` Encoder string `mapstructure:"encoder" json:"encoder"` EncoderConfig *values.Map `mapstructure:"encoder_config" json:"encoder_config"` diff --git a/pkg/capabilities/consensus/ocr3/ocr3cap/reduce_consensus.go b/pkg/capabilities/consensus/ocr3/ocr3cap/reduce_consensus.go new file mode 100644 index 000000000..9bff885b0 --- /dev/null +++ b/pkg/capabilities/consensus/ocr3/ocr3cap/reduce_consensus.go @@ -0,0 +1,51 @@ +package ocr3cap + +import ( + "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/aggregators" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk" +) + +// Note this isn't generated because generics isn't supported in json schema + +type ReduceConsensusConfig[T any] struct { + Encoder Encoder + EncoderConfig EncoderConfig + ReportID ReportId + AggregationConfig aggregators.ReduceAggConfig +} + +func (c ReduceConsensusConfig[T]) New(w *sdk.WorkflowSpecFactory, ref string, input ReduceConsensusInput[T]) SignedReportCap { + def := sdk.StepDefinition{ + ID: "offchain_reporting@1.0.0", + Ref: ref, + Inputs: input.ToSteps(), + Config: map[string]any{ + "aggregation_method": "reduce", + "aggregation_config": c.AggregationConfig, + "encoder": c.Encoder, + "encoder_config": c.EncoderConfig, + "report_id": c.ReportID, + }, + CapabilityType: capabilities.CapabilityTypeConsensus, + } + + step := sdk.Step[SignedReport]{Definition: def} + return SignedReportWrapper(step.AddTo(w)) +} + +type ReduceConsensusInput[T any] struct { + Observation sdk.CapDefinition[T] + Encoder Encoder + EncoderConfig EncoderConfig +} + +func (input ReduceConsensusInput[T]) ToSteps() sdk.StepInputs { + return sdk.StepInputs{ + Mapping: map[string]any{ + "observations": sdk.ListOf(input.Observation).Ref(), + "encoder": input.Encoder, + "encoderConfig": input.EncoderConfig, + }, + } +} diff --git a/pkg/capabilities/consensus/ocr3/ocr3cap/reduce_consensus_test.go b/pkg/capabilities/consensus/ocr3/ocr3cap/reduce_consensus_test.go new file mode 100644 index 000000000..a74123d4f --- /dev/null +++ b/pkg/capabilities/consensus/ocr3/ocr3cap/reduce_consensus_test.go @@ -0,0 +1,154 @@ +package ocr3cap_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/cli/cmd/testdata/fixtures/capabilities/basictrigger" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/aggregators" + ocr3 "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/ocr3cap" + 
"github.com/smartcontractkit/chainlink-common/pkg/capabilities/targets/chainwriter" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk/testutils" +) + +func TestReduceConsensus(t *testing.T) { + t.Parallel() + workflow := sdk.NewWorkflowSpecFactory(sdk.NewWorkflowParams{ + Owner: "0x1234", + Name: "Test", + }) + + trigger := basictrigger.TriggerConfig{Name: "1234", Number: 1}.New(workflow) + + consensus := ocr3.ReduceConsensusConfig[basictrigger.TriggerOutputs]{ + Encoder: ocr3.EncoderEVM, + EncoderConfig: ocr3.EncoderConfig{}, + ReportID: "0001", + AggregationConfig: aggregators.ReduceAggConfig{ + Fields: []aggregators.AggregationField{ + { + InputKey: "FeedID", + OutputKey: "FeedID", + Method: "mode", + }, + { + InputKey: "Timestamp", + OutputKey: "Timestamp", + Method: "median", + DeviationString: "3600", // 1 hour in seconds + DeviationType: "absolute", + }, + { + InputKey: "Price", + OutputKey: "Price", + Method: "median", + DeviationString: "0.05", // 5% + DeviationType: "percent", + SubMapField: true, + }, + }, + OutputFieldName: "Reports", + ReportFormat: "array", + SubMapKey: "Report", + }, + }.New(workflow, "consensus", ocr3.ReduceConsensusInput[basictrigger.TriggerOutputs]{ + Observation: trigger, + Encoder: "evm", + EncoderConfig: ocr3.EncoderConfig(map[string]any{ + "abi": "(bytes32 FeedID, bytes Report, uint32 Timestamp)[] Reports", + }), + }) + + chainwriter.TargetConfig{ + Address: "0x1235", + DeltaStage: "45s", + Schedule: "oneAtATime", + }.New(workflow, "chainwriter@1.0.0", chainwriter.TargetInput{SignedReport: consensus}) + + actual, err := workflow.Spec() + require.NoError(t, err) + + expected := sdk.WorkflowSpec{ + Name: "Test", + Owner: "0x1234", + Triggers: []sdk.StepDefinition{ + { + ID: "basic-test-trigger@1.0.0", + Ref: "trigger", + Inputs: sdk.StepInputs{}, + Config: map[string]any{ + "name": "1234", + "number": 1, + }, + CapabilityType: capabilities.CapabilityTypeTrigger, + }, + }, + Actions: []sdk.StepDefinition{}, + Consensus: []sdk.StepDefinition{ + { + ID: "offchain_reporting@1.0.0", + Ref: "consensus", + Inputs: sdk.StepInputs{Mapping: map[string]any{ + "observations": []any{"$(trigger.outputs)"}, + "encoder": "evm", + "encoderConfig": map[string]any{ + "abi": "(bytes32 FeedID, bytes Report, uint32 Timestamp)[] Reports", + }, + }}, + Config: map[string]any{ + "encoder": "EVM", + "encoder_config": map[string]any{}, + "report_id": "0001", + "aggregation_method": "reduce", + "aggregation_config": map[string]any{ + "outputFieldName": "Reports", + "reportFormat": "array", + "subMapKey": "Report", + "Fields": []map[string]any{ + { + "inputKey": "FeedID", + "outputKey": "FeedID", + "method": "mode", + }, + { + "inputKey": "Timestamp", + "outputKey": "Timestamp", + "method": "median", + "deviation": "3600", + "deviationType": "absolute", + }, + { + "inputKey": "Price", + "outputKey": "Price", + "method": "median", + "deviation": "0.05", + "deviationType": "percent", + "subMapField": true, + }, + }, + }, + }, + CapabilityType: capabilities.CapabilityTypeConsensus, + }, + }, + Targets: []sdk.StepDefinition{ + { + ID: "chainwriter@1.0.0", + Inputs: sdk.StepInputs{ + Mapping: map[string]any{"signed_report": "$(consensus.outputs)"}, + }, + Config: map[string]any{ + "address": "0x1235", + "deltaStage": "45s", + "schedule": "oneAtATime", + }, + CapabilityType: capabilities.CapabilityTypeTarget, + }, + }, + } + + testutils.AssertWorkflowSpec(t, expected, actual) +} diff --git 
a/pkg/capabilities/consensus/ocr3/testdata/fixtures/capability/schema.json b/pkg/capabilities/consensus/ocr3/testdata/fixtures/capability/schema.json
index ebdabb38d..a50ff7d88 100644
--- a/pkg/capabilities/consensus/ocr3/testdata/fixtures/capability/schema.json
+++ b/pkg/capabilities/consensus/ocr3/testdata/fixtures/capability/schema.json
@@ -7,7 +7,8 @@
     "aggregation_method": {
       "type": "string",
       "enum": [
-        "data_feeds"
+        "data_feeds",
+        "reduce"
       ]
     },
     "aggregation_config": {
diff --git a/pkg/capabilities/triggers/mercury_trigger.go b/pkg/capabilities/triggers/mercury_trigger.go
index cc456d863..3e9ab1efe 100644
--- a/pkg/capabilities/triggers/mercury_trigger.go
+++ b/pkg/capabilities/triggers/mercury_trigger.go
@@ -246,5 +246,5 @@ func (o *MercuryTriggerService) HealthReport() map[string]error {
 }
 
 func (o *MercuryTriggerService) Name() string {
-	return "MercuryTriggerService"
+	return o.lggr.Name()
 }
diff --git a/pkg/custmsg/custom_message.go b/pkg/custmsg/custom_message.go
index ef2bc4c5d..da2595555 100644
--- a/pkg/custmsg/custom_message.go
+++ b/pkg/custmsg/custom_message.go
@@ -12,7 +12,7 @@ import (
 
 type MessageEmitter interface {
 	// Emit sends a message to the labeler's destination.
-	Emit(string) error
+	Emit(context.Context, string) error
 
 	// WithMapLabels sets the labels for the message to be emitted. Labels are cumulative.
 	WithMapLabels(map[string]string) MessageEmitter
@@ -74,8 +74,8 @@ func (l Labeler) With(keyValues ...string) MessageEmitter {
 	return newCustomMessageLabeler
 }
 
-func (l Labeler) Emit(msg string) error {
-	return sendLogAsCustomMessageW(msg, l.labels)
+func (l Labeler) Emit(ctx context.Context, msg string) error {
+	return sendLogAsCustomMessageW(ctx, msg, l.labels)
 }
 
 func (l Labeler) Labels() map[string]string {
@@ -88,11 +88,11 @@ func (l Labeler) Labels() map[string]string {
 // SendLogAsCustomMessage emits a BaseMessage with msg and labels as data.
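+// The caller-provided context is forwarded to the underlying beholder emitter.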
// any key in labels that is not part of orderedLabelKeys will not be transmitted -func (l Labeler) SendLogAsCustomMessage(msg string) error { - return sendLogAsCustomMessageW(msg, l.labels) +func (l Labeler) SendLogAsCustomMessage(ctx context.Context, msg string) error { + return sendLogAsCustomMessageW(ctx, msg, l.labels) } -func sendLogAsCustomMessageW(msg string, labels map[string]string) error { +func sendLogAsCustomMessageW(ctx context.Context, msg string, labels map[string]string) error { // TODO un-comment after INFOPLAT-1386 // cast to map[string]any //newLabels := map[string]any{} @@ -115,7 +115,7 @@ func sendLogAsCustomMessageW(msg string, labels map[string]string) error { return fmt.Errorf("sending custom message failed to marshal protobuf: %w", err) } - err = beholder.GetEmitter().Emit(context.Background(), payloadBytes, + err = beholder.GetEmitter().Emit(ctx, payloadBytes, "beholder_data_schema", "/beholder-base-message/versions/1", // required "beholder_domain", "platform", // required "beholder_entity", "BaseMessage", // required diff --git a/pkg/loop/config.go b/pkg/loop/config.go index 5165ff1d7..16fa906ec 100644 --- a/pkg/loop/config.go +++ b/pkg/loop/config.go @@ -105,7 +105,7 @@ func (e *EnvConfig) parse() error { if err != nil { return err } - e.TracingAttributes = getAttributes(envTracingAttribute) + e.TracingAttributes = getMap(envTracingAttribute) e.TracingSamplingRatio = getFloat64OrZero(envTracingSamplingRatio) e.TracingTLSCertPath = os.Getenv(envTracingTLSCertPath) } @@ -122,7 +122,7 @@ func (e *EnvConfig) parse() error { return fmt.Errorf("failed to parse %s: %w", envTelemetryEndpoint, err) } e.TelemetryCACertFile = os.Getenv(envTelemetryCACertFile) - e.TelemetryAttributes = getAttributes(envTelemetryAttribute) + e.TelemetryAttributes = getMap(envTelemetryAttribute) e.TelemetryTraceSampleRatio = getFloat64OrZero(envTelemetryTraceSampleRatio) } return nil @@ -158,14 +158,18 @@ func getValidCollectorTarget() (string, error) { return tracingCollectorTarget, nil } -func getAttributes(envKeyPrefix string) map[string]string { - tracingAttributes := make(map[string]string) +func getMap(envKeyPrefix string) map[string]string { + m := make(map[string]string) for _, env := range os.Environ() { if strings.HasPrefix(env, envKeyPrefix) { - tracingAttributes[strings.TrimPrefix(env, envKeyPrefix)] = os.Getenv(env) + key, value, found := strings.Cut(env, "=") + if found { + key = strings.TrimPrefix(key, envKeyPrefix) + m[key] = value + } } } - return tracingAttributes + return m } // Any errors in parsing result in a sampling ratio of 0.0. 
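For readers skimming the diff: os.Environ entries are whole `KEY=value` strings, so the old getAttributes trimmed the prefix from the entire entry and then called os.Getenv on that mangled name, yielding map keys that still contained `=value` and empty values. The rewritten getMap cuts on the first `=` before trimming. A minimal, self-contained sketch of the new behavior (the `MY_PREFIX_` names are illustrative, not the real tracing/telemetry prefix constants):

	package main

	import (
		"fmt"
		"os"
		"strings"
	)

	// getMap collects environment variables sharing a prefix into a map,
	// keyed by the rest of the variable name after the prefix.
	func getMap(envKeyPrefix string) map[string]string {
		m := make(map[string]string)
		for _, env := range os.Environ() { // each entry looks like "KEY=value"
			if strings.HasPrefix(env, envKeyPrefix) {
				if key, value, found := strings.Cut(env, "="); found {
					m[strings.TrimPrefix(key, envKeyPrefix)] = value
				}
			}
		}
		return m
	}

	func main() {
		_ = os.Setenv("MY_PREFIX_FOO", "bar")
		fmt.Println(getMap("MY_PREFIX_")) // map[FOO:bar]
	}
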
diff --git a/pkg/loop/config_test.go b/pkg/loop/config_test.go index e0bcc1d5e..f719ae566 100644 --- a/pkg/loop/config_test.go +++ b/pkg/loop/config_test.go @@ -2,6 +2,7 @@ package loop import ( "net/url" + "os" "strconv" "strings" "testing" @@ -153,6 +154,35 @@ func TestEnvConfig_AsCmdEnv(t *testing.T) { assert.Equal(t, "42", got[envTelemetryAttribute+"baz"]) } +func TestGetMap(t *testing.T) { + os.Setenv("TEST_PREFIX_KEY1", "value1") + os.Setenv("TEST_PREFIX_KEY2", "value2") + os.Setenv("OTHER_KEY", "othervalue") + + defer func() { + os.Unsetenv("TEST_PREFIX_KEY1") + os.Unsetenv("TEST_PREFIX_KEY2") + os.Unsetenv("OTHER_KEY") + }() + + result := getMap("TEST_PREFIX_") + + expected := map[string]string{ + "KEY1": "value1", + "KEY2": "value2", + } + + if len(result) != len(expected) { + t.Errorf("Expected map length %d, got %d", len(expected), len(result)) + } + + for k, v := range expected { + if result[k] != v { + t.Errorf("Expected key %s to have value %s, but got %s", k, v, result[k]) + } + } +} + func TestManagedGRPCClientConfig(t *testing.T) { t.Parallel() diff --git a/pkg/loop/reportingplugins/loopp_service_test.go b/pkg/loop/reportingplugins/loopp_service_test.go index e17c0a231..e4ba1bd31 100644 --- a/pkg/loop/reportingplugins/loopp_service_test.go +++ b/pkg/loop/reportingplugins/loopp_service_test.go @@ -50,39 +50,42 @@ func TestLOOPPService(t *testing.T) { {Plugin: reportingplugins.PluginServiceName}, } for _, ts := range tests { - looppSvc := reportingplugins.NewLOOPPService(logger.Test(t), loop.GRPCOpts{}, func() *exec.Cmd { - return NewHelperProcessCommand(ts.Plugin) - }, - core.ReportingPluginServiceConfig{}, - nettest.MockConn{}, - pipelinetest.PipelineRunner, - telemetrytest.Telemetry, - errorlogtest.ErrorLog, - keyvaluestoretest.KeyValueStore{}, - relayersettest.RelayerSet{}) - hook := looppSvc.XXXTestHook() - servicetest.Run(t, looppSvc) - - t.Run("control", func(t *testing.T) { - reportingplugintest.RunFactory(t, looppSvc) - }) - - t.Run("Kill", func(t *testing.T) { - hook.Kill() - - // wait for relaunch - time.Sleep(2 * goplugin.KeepAliveTickDuration) - - reportingplugintest.RunFactory(t, looppSvc) - }) - - t.Run("Reset", func(t *testing.T) { - hook.Reset() - - // wait for relaunch - time.Sleep(2 * goplugin.KeepAliveTickDuration) - - reportingplugintest.RunFactory(t, looppSvc) + t.Run(ts.Plugin, func(t *testing.T) { + t.Parallel() + looppSvc := reportingplugins.NewLOOPPService(logger.Test(t), loop.GRPCOpts{}, func() *exec.Cmd { + return NewHelperProcessCommand(ts.Plugin) + }, + core.ReportingPluginServiceConfig{}, + nettest.MockConn{}, + pipelinetest.PipelineRunner, + telemetrytest.Telemetry, + errorlogtest.ErrorLog, + keyvaluestoretest.KeyValueStore{}, + relayersettest.RelayerSet{}) + hook := looppSvc.XXXTestHook() + servicetest.Run(t, looppSvc) + + t.Run("control", func(t *testing.T) { + reportingplugintest.RunFactory(t, looppSvc) + }) + + t.Run("Kill", func(t *testing.T) { + hook.Kill() + + // wait for relaunch + time.Sleep(2 * goplugin.KeepAliveTickDuration) + + reportingplugintest.RunFactory(t, looppSvc) + }) + + t.Run("Reset", func(t *testing.T) { + hook.Reset() + + // wait for relaunch + time.Sleep(2 * goplugin.KeepAliveTickDuration) + + reportingplugintest.RunFactory(t, looppSvc) + }) }) } } diff --git a/pkg/loop/reportingplugins/ocr3/loopp_service_test.go b/pkg/loop/reportingplugins/ocr3/loopp_service_test.go index 5b17c263f..b15531cc5 100644 --- a/pkg/loop/reportingplugins/ocr3/loopp_service_test.go +++ b/pkg/loop/reportingplugins/ocr3/loopp_service_test.go @@ 
-54,40 +54,43 @@ func TestLOOPPService(t *testing.T) {
 		},
 	}
 	for _, ts := range tests {
-		looppSvc := NewLOOPPService(logger.Test(t), loop.GRPCOpts{}, func() *exec.Cmd {
-			return NewHelperProcessCommand(ts.Plugin)
-		},
-			core.ReportingPluginServiceConfig{},
-			nettest.MockConn{},
-			pipelinetest.PipelineRunner,
-			telemetrytest.Telemetry,
-			errorlogtest.ErrorLog,
-			core.CapabilitiesRegistry(nil),
-			keyvaluestoretest.KeyValueStore{},
-			relayersettest.RelayerSet{})
-		hook := looppSvc.XXXTestHook()
-		servicetest.Run(t, looppSvc)
-
-		t.Run("control", func(t *testing.T) {
-			ocr3test.OCR3ReportingPluginFactory(t, looppSvc)
-		})
-
-		t.Run("Kill", func(t *testing.T) {
-			hook.Kill()
-
-			// wait for relaunch
-			time.Sleep(2 * goplugin.KeepAliveTickDuration)
-
-			ocr3test.OCR3ReportingPluginFactory(t, looppSvc)
-		})
-
-		t.Run("Reset", func(t *testing.T) {
-			hook.Reset()
-
-			// wait for relaunch
-			time.Sleep(2 * goplugin.KeepAliveTickDuration)
-
-			ocr3test.OCR3ReportingPluginFactory(t, looppSvc)
+		t.Run(ts.Plugin, func(t *testing.T) {
+			t.Parallel()
+			looppSvc := NewLOOPPService(logger.Test(t), loop.GRPCOpts{}, func() *exec.Cmd {
+				return NewHelperProcessCommand(ts.Plugin)
+			},
+				core.ReportingPluginServiceConfig{},
+				nettest.MockConn{},
+				pipelinetest.PipelineRunner,
+				telemetrytest.Telemetry,
+				errorlogtest.ErrorLog,
+				core.CapabilitiesRegistry(nil),
+				keyvaluestoretest.KeyValueStore{},
+				relayersettest.RelayerSet{})
+			hook := looppSvc.XXXTestHook()
+			servicetest.Run(t, looppSvc)
+
+			t.Run("control", func(t *testing.T) {
+				ocr3test.OCR3ReportingPluginFactory(t, looppSvc)
+			})
+
+			t.Run("Kill", func(t *testing.T) {
+				hook.Kill()
+
+				// wait for relaunch
+				time.Sleep(2 * goplugin.KeepAliveTickDuration)
+
+				ocr3test.OCR3ReportingPluginFactory(t, looppSvc)
+			})
+
+			t.Run("Reset", func(t *testing.T) {
+				hook.Reset()
+
+				// wait for relaunch
+				time.Sleep(2 * goplugin.KeepAliveTickDuration)
+
+				ocr3test.OCR3ReportingPluginFactory(t, looppSvc)
+			})
 		})
 	}
 }
diff --git a/pkg/types/codec.go b/pkg/types/codec.go
index 93ae8ce59..395610911 100644
--- a/pkg/types/codec.go
+++ b/pkg/types/codec.go
@@ -29,6 +29,66 @@ type Decoder interface {
 	GetMaxDecodingSize(ctx context.Context, n int, itemType string) (int, error)
 }
 
+/*
+Codec is an interface that provides encoding and decoding functionality for a specific type identified by a name.
+Because there are many types that a [ContractReader] or [ChainWriter] can either accept or return, all encoding
+instructions provided by the codec are based on the type name.
+
+Starting from the lowest level, take, for instance, a [big.Int] encoder where we want the output to be big endian
+binary encoded.
+
+	typeCodec, _ := binary.BigEndian().BigInt(32, true)
+
+This allows us to encode and decode [big.Int] values with big endian encoding using the [encodings.TypeCodec] interface.
+
+	encodedBytes := []byte{}
+
+	originalValue := big.NewInt(42)
+	encodedBytes, _ = typeCodec.Encode(originalValue, encodedBytes) // new encoded bytes are appended to existing
+
+	value, _, _ := typeCodec.Decode(encodedBytes)
+
+The additional [encodings.TypeCodec] methods such as 'GetType() reflect.Type' allow composition. This is useful for
+creating a struct codec such as the one defined in encodings/struct.go.
+
+	tlCodec, _ := encodings.NewStructCodec([]encodings.NamedTypeCodec{{Name: "Value", Codec: typeCodec}})
+
+This provides an [encodings.TopLevelCodec], which is an [encodings.TypeCodec] with a total size of all encoded elements.
+Going up another level, we create a [Codec] from a map of [encodings.TypeCodec] instances using
+[encodings.CodecFromTypeCodec].
+
+	codec := encodings.CodecFromTypeCodec{"SomeStruct": tlCodec}
+
+	type SomeStruct struct {
+		Value *big.Int
+	}
+
+	encodedStructBytes, _ := codec.Encode(ctx, SomeStruct{Value: big.NewInt(42)}, "SomeStruct")
+
+	var someStruct SomeStruct
+	_ = codec.Decode(ctx, encodedStructBytes, &someStruct, "SomeStruct")
+
+Therefore the 'itemType' passed to [Encode] and [Decode] references the key in the map of [encodings.TypeCodec]
+instances. It is also worth noting that a `TopLevelCodec` can also be added to a `CodecFromTypeCodec` map. This allows
+the [encodings.SizeAtTopLevel] method to be referenced when [encodings.GetMaxEncodingSize] is called on the [Codec].
+
+Also, when the type is unknown to the caller, the decoded type for an 'itemName' can be retrieved from the codec to be
+used for decoding. The `CreateType` method returns an instance of the expected type, using reflection under the hood
+and the overall composition of `TypeCodec` instances. This allows proper types to be conveyed to the caller through the
+GRPC interface, where data may be JSON encoded, passed through GRPC, and JSON decoded on the other side.
+
+	decodedStruct, _ := codec.CreateType("SomeStruct", false)
+	_ = codec.Decode(ctx, encodedStructBytes, &decodedStruct, "SomeStruct")
+
+The `encodings` package provides a `Builder` interface that allows for the creation of any encoding type. This is useful
+for creating custom encodings such as the EVM ABI encoding. An encoder implements the `Builder` interface and plugs
+directly into `TypeCodec`.
+
+From the perspective of a `ContractReader` instance, the `itemType` at the top level is the `readIdentifier`, which
+can be imagined as `contractName + methodName`, given that a contract method call returns some configured value that
+would need its own codec. Each implementation of `ContractReader` maps the names to codecs differently on the inside,
+but from the level of the interface, the `itemType` is the `readIdentifier`.
+*/
 type Codec interface {
 	Encoder
 	Decoder
diff --git a/pkg/types/example_codec_test.go b/pkg/types/example_codec_test.go
new file mode 100644
index 000000000..94dec8155
--- /dev/null
+++ b/pkg/types/example_codec_test.go
@@ -0,0 +1,46 @@
+package types_test
+
+import (
+	"context"
+	"fmt"
+	"math/big"
+
+	"github.com/smartcontractkit/chainlink-common/pkg/codec/encodings"
+	"github.com/smartcontractkit/chainlink-common/pkg/codec/encodings/binary"
+)
+
+// ExampleCodec provides a minimal example of constructing and using a codec.
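+// It mirrors the walkthrough in the Codec interface comment above; errors are
+// discarded with blank identifiers only to keep the example short.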
+func ExampleCodec() {
+	ctx := context.Background()
+	typeCodec, _ := binary.BigEndian().BigInt(32, true)
+
+	// start with empty encoded bytes
+	encodedBytes := []byte{}
+	originalValue := big.NewInt(42)
+
+	encodedBytes, _ = typeCodec.Encode(originalValue, encodedBytes) // new encoded bytes are appended to existing
+	value, _, _ := typeCodec.Decode(encodedBytes)
+
+	// originalValue is the same as value
+	fmt.Printf("%+v == %+v\n", originalValue, value)
+
+	// TopLevelCodec is a TypeCodec that has a total size of all encoded elements
+	tlCodec, _ := encodings.NewStructCodec([]encodings.NamedTypeCodec{{Name: "Value", Codec: typeCodec}})
+	codec := encodings.CodecFromTypeCodec{"SomeStruct": tlCodec}
+
+	type SomeStruct struct {
+		Value *big.Int
+	}
+
+	originalStruct := SomeStruct{Value: big.NewInt(42)}
+	encodedStructBytes, _ := codec.Encode(ctx, originalStruct, "SomeStruct")
+
+	var someStruct SomeStruct
+	_ = codec.Decode(ctx, encodedStructBytes, &someStruct, "SomeStruct")
+
+	decodedStruct, _ := codec.CreateType("SomeStruct", false)
+	_ = codec.Decode(ctx, encodedStructBytes, &decodedStruct, "SomeStruct")
+
+	// encoded struct is equal to decoded struct using defined type and/or CreateType
+	fmt.Printf("%+v == %+v == %+v\n", originalStruct, someStruct, decodedStruct)
+}
diff --git a/pkg/workflows/secrets/secrets.go b/pkg/workflows/secrets/secrets.go
index 40a6408d1..443e2821a 100644
--- a/pkg/workflows/secrets/secrets.go
+++ b/pkg/workflows/secrets/secrets.go
@@ -3,6 +3,7 @@ package secrets
 import (
 	"crypto/rand"
 	"encoding/base64"
+	"encoding/hex"
 	"encoding/json"
 	"fmt"
 
@@ -151,3 +152,44 @@ func DecryptSecretsForNode(
 	return payload.Secrets, nil
 }
+
+func ValidateEncryptedSecrets(secretsData []byte, encryptionPublicKeys map[string][32]byte, workflowOwner string) error {
+	var encryptedSecrets EncryptedSecretsResult
+	err := json.Unmarshal(secretsData, &encryptedSecrets)
+	if err != nil {
+		return fmt.Errorf("failed to parse encrypted secrets JSON: %w", err)
+	}
+
+	if encryptedSecrets.Metadata.WorkflowOwner != workflowOwner {
+		return fmt.Errorf("the workflow owner in the encrypted secrets metadata: %s does not match the input workflow owner: %s", encryptedSecrets.Metadata.WorkflowOwner, workflowOwner)
+	}
+
+	// Verify that the encryptedSecrets values are all valid base64 strings
+	for _, encryptedSecret := range encryptedSecrets.EncryptedSecrets {
+		_, err := base64.StdEncoding.DecodeString(encryptedSecret)
+		if err != nil {
+			return fmt.Errorf("the encrypted secrets JSON payload contains encrypted secrets which are not in base64 format: %w", err)
+		}
+	}
+
+	// Check that every p2pId in the metadata's NodePublicEncryptionKeys has an entry in encryptedSecrets.EncryptedSecrets
+	for p2pId := range encryptedSecrets.Metadata.NodePublicEncryptionKeys {
+		if _, ok := encryptedSecrets.EncryptedSecrets[p2pId]; !ok {
+			return fmt.Errorf("no encrypted secret found for node with p2pId: %s. Ensure secrets have been correctly encrypted for this DON", p2pId)
+		}
+	}
+
+	// Check that the encryption public keys in the encrypted secrets metadata match the ones in encryptionPublicKeys
+	for p2pId, keyFromMetadata := range encryptedSecrets.Metadata.NodePublicEncryptionKeys {
+		encryptionPublicKey, ok := encryptionPublicKeys[p2pId]
+		if !ok {
+			return fmt.Errorf("encryption key not found for node with p2pId: %s. 
Ensure secrets have been correctly encrypted for this DON", p2pId) + } + + if keyFromMetadata != hex.EncodeToString(encryptionPublicKey[:]) { + return fmt.Errorf("the encryption public key in the encrypted secrets metadata does not match the one in the workflow registry. Ensure secrets have been correctly encrypted for this DON") + } + } + + return nil +} diff --git a/pkg/workflows/secrets/secrets_test.go b/pkg/workflows/secrets/secrets_test.go index 8f5e6c56a..cf192b5b0 100644 --- a/pkg/workflows/secrets/secrets_test.go +++ b/pkg/workflows/secrets/secrets_test.go @@ -3,6 +3,8 @@ package secrets import ( "crypto/rand" "encoding/base64" + "encoding/hex" + "encoding/json" "errors" "testing" @@ -191,3 +193,85 @@ func TestEncryptDecrypt(t *testing.T) { }) } + +func TestValidateEncryptedSecrets(t *testing.T) { + // Helper function to generate a valid base64 encoded string + validBase64 := func(input string) string { + return base64.StdEncoding.EncodeToString([]byte(input)) + } + + // Define a key for testing + keyFromMetadata := [32]byte{1, 2, 3} + + // Valid JSON input with matching workflow owner + validInput := map[string]interface{}{ + "encryptedSecrets": map[string]string{ + "09ca39cd924653c72fbb0e458b629c3efebdad3e29e7cd0b5760754d919ed829": validBase64("secret1"), + }, + "metadata": map[string]interface{}{ + "workflowOwner": "correctOwner", + "nodePublicEncryptionKeys": map[string]string{ + "09ca39cd924653c72fbb0e458b629c3efebdad3e29e7cd0b5760754d919ed829": hex.EncodeToString(keyFromMetadata[:]), + }, + }, + } + + // Serialize the valid input + validData, _ := json.Marshal(validInput) + + // Define test cases + tests := []struct { + name string + inputData []byte + encryptionPublicKeys map[string][32]byte + workflowOwner string + shouldError bool + }{ + { + name: "Valid input", + inputData: validData, + workflowOwner: "correctOwner", + encryptionPublicKeys: map[string][32]byte{ + "09ca39cd924653c72fbb0e458b629c3efebdad3e29e7cd0b5760754d919ed829": {1, 2, 3}, + }, + shouldError: false, + }, + { + name: "Invalid base64 encoded secret", + inputData: []byte(`{"encryptedSecrets": {"09ca39cd924653c72fbb0e458b629c3efebdad3e29e7cd0b5760754d919ed829": "invalid-base64!"}}`), + workflowOwner: "correctOwner", + encryptionPublicKeys: map[string][32]byte{ + "09ca39cd924653c72fbb0e458b629c3efebdad3e29e7cd0b5760754d919ed829": {1, 2, 3}, + }, + shouldError: true, + }, + { + name: "Missing public key", + inputData: validData, + workflowOwner: "correctOwner", + encryptionPublicKeys: map[string][32]byte{ + "some-other-id": {1, 2, 3}, + }, + shouldError: true, + }, + { + name: "Mismatched workflow owner", + inputData: validData, + workflowOwner: "incorrectOwner", + encryptionPublicKeys: map[string][32]byte{ + "09ca39cd924653c72fbb0e458b629c3efebdad3e29e7cd0b5760754d919ed829": {1, 2, 3}, + }, + shouldError: true, + }, + } + + // Run the test cases + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := ValidateEncryptedSecrets(test.inputData, test.encryptionPublicKeys, test.workflowOwner) + if (err != nil) != test.shouldError { + t.Errorf("Expected error: %v, got: %v", test.shouldError, err != nil) + } + }) + } +} diff --git a/pkg/workflows/wasm/host/module.go b/pkg/workflows/wasm/host/module.go index cce55d57e..821bcf013 100644 --- a/pkg/workflows/wasm/host/module.go +++ b/pkg/workflows/wasm/host/module.go @@ -29,7 +29,7 @@ import ( type RequestData struct { fetchRequestsCounter int response *wasmpb.Response - callWithCtx func(func(context.Context) (*wasmpb.FetchResponse, error)) 
(*wasmpb.FetchResponse, error) + ctx func() context.Context } type store struct { @@ -72,7 +72,7 @@ func (r *store) delete(id string) { var ( defaultTickInterval = 100 * time.Millisecond defaultTimeout = 2 * time.Second - defaultMaxMemoryMBs = 256 + defaultMinMemoryMBs = 128 DefaultInitialFuel = uint64(100_000_000) defaultMaxFetchRequests = 5 ) @@ -85,6 +85,7 @@ type ModuleConfig struct { TickInterval time.Duration Timeout *time.Duration MaxMemoryMBs int64 + MinMemoryMBs int64 InitialFuel uint64 Logger logger.Logger IsUncompressed bool @@ -100,9 +101,10 @@ type ModuleConfig struct { } type Module struct { - engine *wasmtime.Engine - module *wasmtime.Module - linker *wasmtime.Linker + engine *wasmtime.Engine + module *wasmtime.Module + linker *wasmtime.Linker + wconfig *wasmtime.Config requestStore *store @@ -160,11 +162,15 @@ func NewModule(modCfg *ModuleConfig, binary []byte, opts ...func(*ModuleConfig)) modCfg.Timeout = &defaultTimeout } - // Take the max of the default and the configured max memory mbs. + if modCfg.MinMemoryMBs == 0 { + modCfg.MinMemoryMBs = int64(defaultMinMemoryMBs) + } + + // Take the max of the min and the configured max memory mbs. // We do this because Go requires a minimum of 16 megabytes to run, - // and local testing has shown that with less than 64 mbs, some + // and local testing has shown that with less than the min, some // binaries may error sporadically. - modCfg.MaxMemoryMBs = int64(math.Max(float64(defaultMaxMemoryMBs), float64(modCfg.MaxMemoryMBs))) + modCfg.MaxMemoryMBs = int64(math.Max(float64(modCfg.MinMemoryMBs), float64(modCfg.MaxMemoryMBs))) cfg := wasmtime.NewConfig() cfg.SetEpochInterruption(true) @@ -172,8 +178,10 @@ func NewModule(modCfg *ModuleConfig, binary []byte, opts ...func(*ModuleConfig)) cfg.SetConsumeFuel(true) } - engine := wasmtime.NewEngineWithConfig(cfg) + cfg.CacheConfigLoadDefault() + cfg.SetCraneliftOptLevel(wasmtime.OptLevelSpeedAndSize) + engine := wasmtime.NewEngineWithConfig(cfg) if !modCfg.IsUncompressed { rdr := brotli.NewReader(bytes.NewBuffer(binary)) decompedBinary, err := io.ReadAll(rdr) @@ -201,29 +209,7 @@ func NewModule(modCfg *ModuleConfig, binary []byte, opts ...func(*ModuleConfig)) err = linker.FuncWrap( "env", "sendResponse", - func(caller *wasmtime.Caller, ptr int32, ptrlen int32) int32 { - b, innerErr := wasmRead(caller, ptr, ptrlen) - if innerErr != nil { - logger.Errorf("error calling sendResponse: %s", err) - return ErrnoFault - } - - var resp wasmpb.Response - innerErr = proto.Unmarshal(b, &resp) - if innerErr != nil { - logger.Errorf("error calling sendResponse: %s", innerErr) - return ErrnoFault - } - - storedReq, innerErr := requestStore.get(resp.Id) - if innerErr != nil { - logger.Errorf("error calling sendResponse: %s", innerErr) - return ErrnoFault - } - storedReq.response = &resp - - return ErrnoSuccess - }, + createSendResponseFn(logger, requestStore), ) if err != nil { return nil, fmt.Errorf("error wrapping sendResponse func: %w", err) @@ -232,48 +218,7 @@ func NewModule(modCfg *ModuleConfig, binary []byte, opts ...func(*ModuleConfig)) err = linker.FuncWrap( "env", "log", - func(caller *wasmtime.Caller, ptr int32, ptrlen int32) { - b, innerErr := wasmRead(caller, ptr, ptrlen) - if innerErr != nil { - logger.Errorf("error calling log: %s", err) - return - } - - var raw map[string]interface{} - innerErr = json.Unmarshal(b, &raw) - if innerErr != nil { - return - } - - level := raw["level"] - delete(raw, "level") - - msg := raw["msg"].(string) - delete(raw, "msg") - delete(raw, "ts") - - var args 
[]interface{} - for k, v := range raw { - args = append(args, k, v) - } - - switch level { - case "debug": - logger.Debugw(msg, args...) - case "info": - logger.Infow(msg, args...) - case "warn": - logger.Warnw(msg, args...) - case "error": - logger.Errorw(msg, args...) - case "panic": - logger.Panicw(msg, args...) - case "fatal": - logger.Fatalw(msg, args...) - default: - logger.Infow(msg, args...) - } - }, + createLogFn(logger), ) if err != nil { return nil, fmt.Errorf("error wrapping log func: %w", err) @@ -291,16 +236,17 @@ func NewModule(modCfg *ModuleConfig, binary []byte, opts ...func(*ModuleConfig)) err = linker.FuncWrap( "env", "emit", - createEmitFn(logger, modCfg.Labeler, wasmRead, wasmWrite, wasmWriteUInt32), + createEmitFn(logger, requestStore, modCfg.Labeler, wasmRead, wasmWrite, wasmWriteUInt32), ) if err != nil { return nil, fmt.Errorf("error wrapping emit func: %w", err) } m := &Module{ - engine: engine, - module: mod, - linker: linker, + engine: engine, + module: mod, + linker: linker, + wconfig: cfg, requestStore: requestStore, @@ -336,6 +282,7 @@ func (m *Module) Close() { m.linker.Close() m.engine.Close() m.module.Close() + m.wconfig.Close() } func (m *Module) Run(ctx context.Context, request *wasmpb.Request) (*wasmpb.Response, error) { @@ -348,9 +295,7 @@ func (m *Module) Run(ctx context.Context, request *wasmpb.Request) (*wasmpb.Resp } // we add the request context to the store to make it available to the Fetch fn - err := m.requestStore.add(request.Id, &RequestData{callWithCtx: func(fn func(context.Context) (*wasmpb.FetchResponse, error)) (*wasmpb.FetchResponse, error) { - return fn(ctx) - }}) + err := m.requestStore.add(request.Id, &RequestData{ctx: func() context.Context { return ctx }}) if err != nil { return nil, fmt.Errorf("error adding ctx to the store: %w", err) } @@ -368,6 +313,8 @@ func (m *Module) Run(ctx context.Context, request *wasmpb.Request) (*wasmpb.Resp reqstr := base64.StdEncoding.EncodeToString(reqpb) wasi := wasmtime.NewWasiConfig() + defer wasi.Close() + wasi.SetArgv([]string{"wasi", reqstr}) store.SetWasi(wasi) @@ -436,6 +383,34 @@ func containsCode(err error, code int) bool { return strings.Contains(err.Error(), fmt.Sprintf("exit status %d", code)) } +// createSendResponseFn injects the dependency required by a WASM guest to +// send a response back to the host. 
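+// The returned closure reads a protobuf-encoded wasmpb.Response out of guest
+// memory, looks the originating request up by its ID in the request store, and
+// records the response there so the host side of Run can retrieve it.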
+func createSendResponseFn(logger logger.Logger, requestStore *store) func(caller *wasmtime.Caller, ptr int32, ptrlen int32) int32 { + return func(caller *wasmtime.Caller, ptr int32, ptrlen int32) int32 { + b, innerErr := wasmRead(caller, ptr, ptrlen) + if innerErr != nil { + logger.Errorf("error calling sendResponse: %s", innerErr) + return ErrnoFault + } + + var resp wasmpb.Response + innerErr = proto.Unmarshal(b, &resp) + if innerErr != nil { + logger.Errorf("error calling sendResponse: %s", innerErr) + return ErrnoFault + } + + storedReq, innerErr := requestStore.get(resp.Id) + if innerErr != nil { + logger.Errorf("error calling sendResponse: %s", innerErr) + return ErrnoFault + } + storedReq.response = &resp + + return ErrnoSuccess + } +} + func createFetchFn( logger logger.Logger, reader unsafeReaderFunc, @@ -499,13 +474,7 @@ func createFetchFn( } storedRequest.fetchRequestsCounter++ - fetchResp, innerErr := storedRequest.callWithCtx(func(ctx context.Context) (*wasmpb.FetchResponse, error) { - if ctx == nil { - return nil, errors.New("context is nil") - } - - return modCfg.Fetch(ctx, req) - }) + fetchResp, innerErr := modCfg.Fetch(storedRequest.ctx(), req) if innerErr != nil { logger.Errorf("%s: %s", errFetchSfx, innerErr) return writeErr(innerErr) @@ -533,6 +502,7 @@ func createFetchFn( // Emit, if any, are returned in the Error Message of the response. func createEmitFn( l logger.Logger, + requestStore *store, e custmsg.MessageEmitter, reader unsafeReaderFunc, writer unsafeWriterFunc, @@ -577,12 +547,18 @@ func createEmitFn( return writeErr(err) } - msg, labels, err := toEmissible(b) + reqID, msg, labels, err := toEmissible(b) + if err != nil { + return writeErr(err) + } + + req, err := requestStore.get(reqID) if err != nil { + logErr(fmt.Errorf("failed to get request from store: %s", err)) return writeErr(err) } - if err := e.WithMapLabels(labels).Emit(msg); err != nil { + if err := e.WithMapLabels(labels).Emit(req.ctx(), msg); err != nil { return writeErr(err) } @@ -590,9 +566,55 @@ func createEmitFn( } } +// createLogFn injects dependencies and builds the log function exposed by the WASM. +func createLogFn(logger logger.Logger) func(caller *wasmtime.Caller, ptr int32, ptrlen int32) { + return func(caller *wasmtime.Caller, ptr int32, ptrlen int32) { + b, innerErr := wasmRead(caller, ptr, ptrlen) + if innerErr != nil { + logger.Errorf("error calling log: %s", innerErr) + return + } + + var raw map[string]interface{} + innerErr = json.Unmarshal(b, &raw) + if innerErr != nil { + return + } + + level := raw["level"] + delete(raw, "level") + + msg := raw["msg"].(string) + delete(raw, "msg") + delete(raw, "ts") + + var args []interface{} + for k, v := range raw { + args = append(args, k, v) + } + + switch level { + case "debug": + logger.Debugw(msg, args...) + case "info": + logger.Infow(msg, args...) + case "warn": + logger.Warnw(msg, args...) + case "error": + logger.Errorw(msg, args...) + case "panic": + logger.Panicw(msg, args...) + case "fatal": + logger.Fatalw(msg, args...) + default: + logger.Infow(msg, args...) 
+ } + } +} + type unimplementedMessageEmitter struct{} -func (u *unimplementedMessageEmitter) Emit(string) error { +func (u *unimplementedMessageEmitter) Emit(context.Context, string) error { return errors.New("unimplemented") } @@ -608,18 +630,18 @@ func (u *unimplementedMessageEmitter) Labels() map[string]string { return nil } -func toEmissible(b []byte) (string, map[string]string, error) { +func toEmissible(b []byte) (string, string, map[string]string, error) { msg := &wasmpb.EmitMessageRequest{} if err := proto.Unmarshal(b, msg); err != nil { - return "", nil, err + return "", "", nil, err } validated, err := toValidatedLabels(msg) if err != nil { - return "", nil, err + return "", "", nil, err } - return msg.Message, validated, nil + return msg.RequestId, msg.Message, validated, nil } func toValidatedLabels(msg *wasmpb.EmitMessageRequest) (map[string]string, error) { diff --git a/pkg/workflows/wasm/host/module_test.go b/pkg/workflows/wasm/host/module_test.go index a4cd06ddd..a19c43fa2 100644 --- a/pkg/workflows/wasm/host/module_test.go +++ b/pkg/workflows/wasm/host/module_test.go @@ -13,17 +13,18 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/custmsg" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink-common/pkg/values/pb" wasmpb "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/pb" ) type mockMessageEmitter struct { - e func(string, map[string]string) error + e func(context.Context, string, map[string]string) error labels map[string]string } -func (m *mockMessageEmitter) Emit(msg string) error { - return m.e(msg, m.labels) +func (m *mockMessageEmitter) Emit(ctx context.Context, msg string) error { + return m.e(ctx, msg, m.labels) } func (m *mockMessageEmitter) WithMapLabels(labels map[string]string) custmsg.MessageEmitter { @@ -40,7 +41,7 @@ func (m *mockMessageEmitter) Labels() map[string]string { return m.labels } -func newMockMessageEmitter(e func(string, map[string]string) error) custmsg.MessageEmitter { +func newMockMessageEmitter(e func(context.Context, string, map[string]string) error) custmsg.MessageEmitter { return &mockMessageEmitter{e: e} } @@ -48,14 +49,31 @@ func newMockMessageEmitter(e func(string, map[string]string) error) custmsg.Mess // access functions are injected as mocks. 
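+// The tests below seed the request store with a RequestData whose ctx callback
+// returns the test context, since createEmitFn now resolves the emission
+// context through the store by request ID.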
func Test_createEmitFn(t *testing.T) { t.Run("success", func(t *testing.T) { + ctxKey := "key" + ctxValue := "test-value" + ctx := tests.Context(t) + ctx = context.WithValue(ctx, ctxKey, "test-value") + store := &store{ + m: make(map[string]*RequestData), + mu: sync.RWMutex{}, + } + reqId := "random-id" + err := store.add( + reqId, + &RequestData{ctx: func() context.Context { return ctx }}) + require.NoError(t, err) emitFn := createEmitFn( logger.Test(t), - newMockMessageEmitter(func(_ string, _ map[string]string) error { + store, + newMockMessageEmitter(func(ctx context.Context, _ string, _ map[string]string) error { + v := ctx.Value(ctxKey) + assert.Equal(t, ctxValue, v) return nil }), unsafeReaderFunc(func(_ *wasmtime.Caller, _, _ int32) ([]byte, error) { b, err := proto.Marshal(&wasmpb.EmitMessageRequest{ - Message: "hello, world", + RequestId: reqId, + Message: "hello, world", Labels: &pb.Map{ Fields: map[string]*pb.Value{ "foo": { @@ -81,9 +99,14 @@ func Test_createEmitFn(t *testing.T) { }) t.Run("success without labels", func(t *testing.T) { + store := &store{ + m: make(map[string]*RequestData), + mu: sync.RWMutex{}, + } emitFn := createEmitFn( logger.Test(t), - newMockMessageEmitter(func(_ string, _ map[string]string) error { + store, + newMockMessageEmitter(func(_ context.Context, _ string, _ map[string]string) error { return nil }), unsafeReaderFunc(func(_ *wasmtime.Caller, _, _ int32) ([]byte, error) { @@ -103,6 +126,10 @@ func Test_createEmitFn(t *testing.T) { }) t.Run("successfully write error to memory on failure to read", func(t *testing.T) { + store := &store{ + m: make(map[string]*RequestData), + mu: sync.RWMutex{}, + } respBytes, err := proto.Marshal(&wasmpb.EmitMessageResponse{ Error: &wasmpb.Error{ Message: assert.AnError.Error(), @@ -112,6 +139,7 @@ func Test_createEmitFn(t *testing.T) { emitFn := createEmitFn( logger.Test(t), + store, nil, unsafeReaderFunc(func(_ *wasmtime.Caller, _, _ int32) ([]byte, error) { return nil, assert.AnError @@ -130,6 +158,14 @@ func Test_createEmitFn(t *testing.T) { }) t.Run("failure to emit writes error to memory", func(t *testing.T) { + store := &store{ + m: make(map[string]*RequestData), + mu: sync.RWMutex{}, + } + reqId := "random-id" + store.add(reqId, &RequestData{ + ctx: func() context.Context { return tests.Context(t) }, + }) respBytes, err := proto.Marshal(&wasmpb.EmitMessageResponse{ Error: &wasmpb.Error{ Message: assert.AnError.Error(), @@ -139,11 +175,14 @@ func Test_createEmitFn(t *testing.T) { emitFn := createEmitFn( logger.Test(t), - newMockMessageEmitter(func(_ string, _ map[string]string) error { + store, + newMockMessageEmitter(func(_ context.Context, _ string, _ map[string]string) error { return assert.AnError }), unsafeReaderFunc(func(_ *wasmtime.Caller, _, _ int32) ([]byte, error) { - b, err := proto.Marshal(&wasmpb.EmitMessageRequest{}) + b, err := proto.Marshal(&wasmpb.EmitMessageRequest{ + RequestId: reqId, + }) assert.NoError(t, err) return b, nil }), @@ -161,6 +200,10 @@ func Test_createEmitFn(t *testing.T) { }) t.Run("bad read failure to unmarshal protos", func(t *testing.T) { + store := &store{ + m: make(map[string]*RequestData), + mu: sync.RWMutex{}, + } badData := []byte("not proto bufs") msg := &wasmpb.EmitMessageRequest{} marshallErr := proto.Unmarshal(badData, msg) @@ -175,6 +218,7 @@ func Test_createEmitFn(t *testing.T) { emitFn := createEmitFn( logger.Test(t), + store, nil, unsafeReaderFunc(func(_ *wasmtime.Caller, _, _ int32) ([]byte, error) { return badData, nil @@ -203,9 +247,7 @@ func 
@@ -203,9 +247,7 @@ func TestCreateFetchFn(t *testing.T) {
 
 	// we add the request data to the store so that the fetch function can find it
 	store.m[testID] = &RequestData{
-		callWithCtx: func(fn func(context.Context) (*wasmpb.FetchResponse, error)) (*wasmpb.FetchResponse, error) {
-			return fn(context.Background())
-		},
+		ctx: func() context.Context { return tests.Context(t) },
 	}
 
 	fetchFn := createFetchFn(
@@ -348,54 +390,6 @@
 		assert.Equal(t, ErrnoSuccess, gotCode)
 	})
 
-	t.Run("NOK-fetch_fails_stored_ctx_is_nil", func(t *testing.T) {
-		store := &store{
-			m:  make(map[string]*RequestData),
-			mu: sync.RWMutex{},
-		}
-
-		// we add the request data to the store so that the fetch function can find it
-		store.m[testID] = &RequestData{
-			callWithCtx: func(fn func(context.Context) (*wasmpb.FetchResponse, error)) (*wasmpb.FetchResponse, error) {
-				return fn(nil)
-			},
-		}
-
-		fetchFn := createFetchFn(
-			logger.Test(t),
-			unsafeReaderFunc(func(_ *wasmtime.Caller, _, _ int32) ([]byte, error) {
-				b, err := proto.Marshal(&wasmpb.FetchRequest{
-					Id: testID,
-				})
-				assert.NoError(t, err)
-				return b, nil
-			}),
-			unsafeWriterFunc(func(c *wasmtime.Caller, src []byte, ptr, len int32) int64 {
-				// the error is handled and written to the buffer
-				resp := &wasmpb.FetchResponse{}
-				err := proto.Unmarshal(src, resp)
-				require.NoError(t, err)
-				expectedErr := "context is nil"
-				assert.Equal(t, expectedErr, resp.ErrorMessage)
-				return 0
-			}),
-			unsafeFixedLengthWriterFunc(func(c *wasmtime.Caller, ptr int32, val uint32) int64 {
-				return 0
-			}),
-			&ModuleConfig{
-				Logger: logger.Test(t),
-				Fetch: func(ctx context.Context, req *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error) {
-					return &wasmpb.FetchResponse{}, nil
-				},
-				MaxFetchRequests: 1,
-			},
-			store,
-		)
-
-		gotCode := fetchFn(new(wasmtime.Caller), 0, 0, 0, 0)
-		assert.Equal(t, ErrnoSuccess, gotCode)
-	})
-
 	t.Run("NOK-fetch_returns_an_error", func(t *testing.T) {
 		store := &store{
 			m:  make(map[string]*RequestData),
@@ -404,9 +398,7 @@ func TestCreateFetchFn(t *testing.T) {
 
 		// we add the request data to the store so that the fetch function can find it
 		store.m[testID] = &RequestData{
-			callWithCtx: func(fn func(context.Context) (*wasmpb.FetchResponse, error)) (*wasmpb.FetchResponse, error) {
-				return fn(context.Background())
-			},
+			ctx: func() context.Context { return tests.Context(t) },
 		}
 
 		fetchFn := createFetchFn(
@@ -452,9 +444,7 @@
 
 		// we add the request data to the store so that the fetch function can find it
 		store.m[testID] = &RequestData{
-			callWithCtx: func(fn func(context.Context) (*wasmpb.FetchResponse, error)) (*wasmpb.FetchResponse, error) {
-				return fn(context.Background())
-			},
+			ctx: func() context.Context { return tests.Context(t) },
 		}
 
 		fetchFn := createFetchFn(
@@ -493,9 +483,7 @@
 
 		// we add the request data to the store so that the fetch function can find it
 		store.m[testID] = &RequestData{
-			callWithCtx: func(fn func(context.Context) (*wasmpb.FetchResponse, error)) (*wasmpb.FetchResponse, error) {
-				return fn(context.Background())
-			},
+			ctx: func() context.Context { return tests.Context(t) },
 		}
 
 		fetchFn := createFetchFn(
@@ -651,8 +639,10 @@ func Test_toValidatedLabels(t *testing.T) {
 
 func Test_toEmissible(t *testing.T) {
 	t.Run("success", func(t *testing.T) {
+		reqID := "random-id"
 		msg := &wasmpb.EmitMessageRequest{
-			Message: "hello, world",
+			RequestId: reqID,
+			Message:   "hello, world",
 			Labels: &pb.Map{
 				Fields: map[string]*pb.Value{
 					"test": {
@@ -667,14 +657,15 @@ func Test_toEmissible(t *testing.T) {
 		b, err := proto.Marshal(msg)
 		assert.NoError(t, err)
 
-		gotMsg, gotLabels, err := toEmissible(b)
+		rid, gotMsg, gotLabels, err := toEmissible(b)
 		assert.NoError(t, err)
 		assert.Equal(t, "hello, world", gotMsg)
 		assert.Equal(t, map[string]string{"test": "value"}, gotLabels)
+		assert.Equal(t, reqID, rid)
 	})
 
 	t.Run("fails with bad message", func(t *testing.T) {
-		_, _, err := toEmissible([]byte("not proto bufs"))
+		_, _, _, err := toEmissible([]byte("not proto bufs"))
 		assert.Error(t, err)
 	})
 }
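With RequestData carrying a plain context getter instead of the old callWithCtx callback, the NOK-fetch_fails_stored_ctx_is_nil case above loses its target: there is no seam left through which a nil context can be injected, so the subtest is deleted rather than rewritten. The host-side fetch path reduces to a lookup plus a direct call, roughly as follows (an illustrative reduction of the pattern the remaining subtests set up, not the package's literal implementation; get is the assumed accessor from the earlier sketch):

	// fetchForRequest resolves the stored context for a request and runs
	// the configured fetch on it.
	func fetchForRequest(
		s *store,
		id string,
		fetch func(context.Context, *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error),
		req *wasmpb.FetchRequest,
	) (*wasmpb.FetchResponse, error) {
		data, err := s.get(id)
		if err != nil {
			return nil, err
		}
		return fetch(data.ctx(), req)
	}
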
diff --git a/pkg/workflows/wasm/host/wasm_test.go b/pkg/workflows/wasm/host/wasm_test.go
index 9e290d74e..cccaf9443 100644
--- a/pkg/workflows/wasm/host/wasm_test.go
+++ b/pkg/workflows/wasm/host/wasm_test.go
@@ -57,7 +57,7 @@ const (
 	emitBinaryCmd = "test/emit/cmd"
 )
 
-func createTestBinary(outputPath, path string, compress bool, t *testing.T) []byte {
+func createTestBinary(outputPath, path string, uncompressed bool, t *testing.T) []byte {
 	cmd := exec.Command("go", "build", "-o", path, fmt.Sprintf("github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host/%s", outputPath)) // #nosec
 	cmd.Env = append(os.Environ(), "GOOS=wasip1", "GOARCH=wasm")
 
@@ -67,7 +67,7 @@
 	binary, err := os.ReadFile(path)
 	require.NoError(t, err)
 
-	if !compress {
+	if uncompressed {
 		return binary
 	}
 
@@ -83,13 +83,15 @@
 }
 
 func Test_GetWorkflowSpec(t *testing.T) {
+	t.Parallel()
 	ctx := tests.Context(t)
 	binary := createTestBinary(successBinaryCmd, successBinaryLocation, true, t)
 
 	spec, err := GetWorkflowSpec(
 		ctx,
 		&ModuleConfig{
-			Logger: logger.Test(t),
+			Logger:         logger.Test(t),
+			IsUncompressed: true,
 		},
 		binary,
 		[]byte(""),
@@ -101,6 +103,7 @@
 }
 
 func Test_GetWorkflowSpec_UncompressedBinary(t *testing.T) {
+	t.Parallel()
 	ctx := tests.Context(t)
 	binary := createTestBinary(successBinaryCmd, successBinaryLocation, false, t)
 
@@ -108,7 +111,7 @@
 		ctx,
 		&ModuleConfig{
 			Logger:         logger.Test(t),
-			IsUncompressed: true,
+			IsUncompressed: false,
 		},
 		binary,
 		[]byte(""),
@@ -126,7 +129,8 @@ func Test_GetWorkflowSpec_BinaryErrors(t *testing.T) {
 	_, err := GetWorkflowSpec(
 		ctx,
 		&ModuleConfig{
-			Logger: logger.Test(t),
+			Logger:         logger.Test(t),
+			IsUncompressed: true,
 		},
 		failBinary,
 		[]byte(""),
@@ -136,6 +140,7 @@
 }
 
 func Test_GetWorkflowSpec_Timeout(t *testing.T) {
+	t.Parallel()
 	ctx := tests.Context(t)
 	binary := createTestBinary(successBinaryCmd, successBinaryLocation, true, t)
 
@@ -143,8 +148,9 @@
 	_, err := GetWorkflowSpec(
 		ctx,
 		&ModuleConfig{
-			Timeout: &d,
-			Logger:  logger.Test(t),
+			Timeout:        &d,
+			Logger:         logger.Test(t),
+			IsUncompressed: true,
 		},
 		binary, // use the success binary with a zero timeout
 		[]byte(""),
@@ -154,12 +160,14 @@
 }
 
 func Test_Compute_Logs(t *testing.T) {
+	t.Parallel()
 	ctx := tests.Context(t)
 	binary := createTestBinary(logBinaryCmd, logBinaryLocation, true, t)
 
 	logger, logs := logger.TestObserved(t, zapcore.InfoLevel)
 	m, err := NewModule(&ModuleConfig{
-		Logger: logger,
+		Logger:         logger,
+		IsUncompressed: true,
 		Fetch: func(ctx context.Context, req *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error) {
 			return nil, nil
 		},
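Note the inverted polarity above: createTestBinary's third argument now reads as "uncompressed", matching ModuleConfig.IsUncompressed, so the two flags should agree at every call site. A small helper sketch capturing that pairing (illustrative only — the tests here inline the same three lines, and the *Module return type is inferred from the m, err := assignments in this file):

	// newUncompressedModule builds a raw (uncompressed) test binary and a
	// module configured to load it as-is; passing false/false instead would
	// exercise the compressed path end to end.
	func newUncompressedModule(t *testing.T, cmd, loc string) *Module {
		binary := createTestBinary(cmd, loc, true, t)
		m, err := NewModule(&ModuleConfig{
			Logger:         logger.Test(t),
			IsUncompressed: true,
		}, binary)
		require.NoError(t, err)
		return m
	}
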
@@ -203,6 +211,7 @@ func Test_Compute_Logs(t *testing.T) {
 }
 
 func Test_Compute_Emit(t *testing.T) {
+	t.Parallel()
 	binary := createTestBinary(emitBinaryCmd, emitBinaryLocation, true, t)
 
 	lggr := logger.Test(t)
@@ -231,12 +240,15 @@
 	}
 
 	t.Run("successfully call emit with metadata in labels", func(t *testing.T) {
+		ctx := tests.Context(t)
 		m, err := NewModule(&ModuleConfig{
-			Logger:  lggr,
-			Fetch:   fetchFunc,
-			Labeler: newMockMessageEmitter(func(msg string, kvs map[string]string) error {
+			Logger:         lggr,
+			Fetch:          fetchFunc,
+			IsUncompressed: true,
+			Labeler: newMockMessageEmitter(func(gotCtx context.Context, msg string, kvs map[string]string) error {
 				t.Helper()
 
+				assert.Equal(t, ctx, gotCtx)
 				assert.Equal(t, "testing emit", msg)
 				assert.Equal(t, "this is a test field content", kvs["test-string-field-key"])
 				assert.Equal(t, "workflow-id", kvs["workflow_id"])
@@ -250,7 +262,6 @@
 
 		m.Start()
 
-		ctx := tests.Context(t)
 		_, err = m.Run(ctx, req)
 		assert.Nil(t, err)
 	})
@@ -259,9 +270,10 @@
 		lggr, logs := logger.TestObserved(t, zapcore.InfoLevel)
 
 		m, err := NewModule(&ModuleConfig{
-			Logger: lggr,
-			Fetch:  fetchFunc,
-			Labeler: newMockMessageEmitter(func(msg string, kvs map[string]string) error {
+			Logger:         lggr,
+			Fetch:          fetchFunc,
+			IsUncompressed: true,
+			Labeler: newMockMessageEmitter(func(_ context.Context, msg string, kvs map[string]string) error {
 				t.Helper()
 
 				assert.Equal(t, "testing emit", msg)
@@ -300,9 +312,10 @@
 		lggr := logger.Test(t)
 
 		m, err := NewModule(&ModuleConfig{
-			Logger: lggr,
-			Fetch:  fetchFunc,
-			Labeler: newMockMessageEmitter(func(msg string, labels map[string]string) error {
+			Logger:         lggr,
+			Fetch:          fetchFunc,
+			IsUncompressed: true,
+			Labeler: newMockMessageEmitter(func(_ context.Context, msg string, labels map[string]string) error {
 				return nil
 			}), // never called
 		}, binary)
@@ -333,9 +346,11 @@
 }
 
 func Test_Compute_Fetch(t *testing.T) {
+	t.Parallel()
 	binary := createTestBinary(fetchBinaryCmd, fetchBinaryLocation, true, t)
 
 	t.Run("OK_default_runtime_cfg", func(t *testing.T) {
+		t.Parallel()
 		ctx := tests.Context(t)
 		expected := sdk.FetchResponse{
 			ExecutionError: false,
@@ -345,7 +360,8 @@
 		}
 
 		m, err := NewModule(&ModuleConfig{
-			Logger: logger.Test(t),
+			Logger:         logger.Test(t),
+			IsUncompressed: true,
 			Fetch: func(ctx context.Context, req *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error) {
 				return &wasmpb.FetchResponse{
 					ExecutionError: expected.ExecutionError,
@@ -385,6 +401,7 @@
 	})
 
 	t.Run("OK_custom_runtime_cfg", func(t *testing.T) {
+		t.Parallel()
 		ctx := tests.Context(t)
 		expected := sdk.FetchResponse{
 			ExecutionError: false,
@@ -394,7 +411,8 @@
 		}
 
 		m, err := NewModule(&ModuleConfig{
-			Logger: logger.Test(t),
+			Logger:         logger.Test(t),
+			IsUncompressed: true,
 			Fetch: func(ctx context.Context, req *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error) {
 				return &wasmpb.FetchResponse{
 					ExecutionError: expected.ExecutionError,
@@ -437,11 +455,13 @@
 	})
 
 	t.Run("NOK_fetch_error_returned", func(t *testing.T) {
+		t.Parallel()
 		ctx := tests.Context(t)
 		logger, logs := logger.TestObserved(t, zapcore.InfoLevel)
 
 		m, err := NewModule(&ModuleConfig{
-			Logger: logger,
+			Logger:         logger,
+			IsUncompressed: true,
 			Fetch: func(ctx context.Context, req *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error) {
 				return nil, assert.AnError
 			},
@@ -481,6 +501,7 @@ func Test_Compute_Fetch(t *testing.T) {
 	})
 
 	t.Run("OK_context_propagation", func(t *testing.T) {
+		t.Parallel()
 		type testkey string
 		var key testkey = "test-key"
 		var expectedValue string = "test-value"
@@ -493,7 +514,8 @@
 		}
 
 		m, err := NewModule(&ModuleConfig{
-			Logger: logger.Test(t),
+			Logger:         logger.Test(t),
+			IsUncompressed: true,
 			Fetch: func(ctx context.Context, req *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error) {
 				return &wasmpb.FetchResponse{
 					ExecutionError: expected.ExecutionError,
@@ -538,8 +560,10 @@
 	})
 
 	t.Run("OK_context_cancelation", func(t *testing.T) {
+		t.Parallel()
 		m, err := NewModule(&ModuleConfig{
-			Logger: logger.Test(t),
+			Logger:         logger.Test(t),
+			IsUncompressed: true,
 			Fetch: func(ctx context.Context, req *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error) {
 				select {
 				case <-ctx.Done():
@@ -579,6 +603,7 @@
 	})
 
 	t.Run("NOK_exceed_amout_of_defined_max_fetch_calls", func(t *testing.T) {
+		t.Parallel()
 		binary := createTestBinary(fetchlimitBinaryCmd, fetchlimitBinaryLocation, true, t)
 		ctx := tests.Context(t)
 		expected := sdk.FetchResponse{
@@ -589,7 +614,8 @@
 		}
 
 		m, err := NewModule(&ModuleConfig{
-			Logger: logger.Test(t),
+			Logger:         logger.Test(t),
+			IsUncompressed: true,
 			Fetch: func(ctx context.Context, req *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error) {
 				return &wasmpb.FetchResponse{
 					ExecutionError: expected.ExecutionError,
@@ -622,6 +648,7 @@
 	})
 
 	t.Run("NOK_exceed_amout_of_default_max_fetch_calls", func(t *testing.T) {
+		t.Parallel()
 		binary := createTestBinary(fetchlimitBinaryCmd, fetchlimitBinaryLocation, true, t)
 		ctx := tests.Context(t)
 		expected := sdk.FetchResponse{
@@ -632,7 +659,8 @@
 		}
 
 		m, err := NewModule(&ModuleConfig{
-			Logger: logger.Test(t),
+			Logger:         logger.Test(t),
+			IsUncompressed: true,
 			Fetch: func(ctx context.Context, req *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error) {
 				return &wasmpb.FetchResponse{
 					ExecutionError: expected.ExecutionError,
@@ -664,6 +692,7 @@
 	})
 
 	t.Run("OK_making_up_to_max_fetch_calls", func(t *testing.T) {
+		t.Parallel()
 		binary := createTestBinary(fetchlimitBinaryCmd, fetchlimitBinaryLocation, true, t)
 		ctx := tests.Context(t)
 		expected := sdk.FetchResponse{
@@ -674,7 +703,8 @@
 		}
 
 		m, err := NewModule(&ModuleConfig{
-			Logger: logger.Test(t),
+			Logger:         logger.Test(t),
+			IsUncompressed: true,
 			Fetch: func(ctx context.Context, req *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error) {
 				return &wasmpb.FetchResponse{
 					ExecutionError: expected.ExecutionError,
@@ -707,6 +737,7 @@
 	})
 
 	t.Run("OK_multiple_request_reusing_module", func(t *testing.T) {
+		t.Parallel()
 		binary := createTestBinary(fetchlimitBinaryCmd, fetchlimitBinaryLocation, true, t)
 		ctx := tests.Context(t)
 		expected := sdk.FetchResponse{
@@ -717,7 +748,8 @@
 		}
 
 		m, err := NewModule(&ModuleConfig{
-			Logger: logger.Test(t),
+			Logger:         logger.Test(t),
+			IsUncompressed: true,
 			Fetch: func(ctx context.Context, req *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error) {
 				return &wasmpb.FetchResponse{
 					ExecutionError: expected.ExecutionError,
@@ -756,10 +788,11 @@
 }
 
 func TestModule_Errors(t *testing.T) {
+	t.Parallel()
 	ctx := tests.Context(t)
 	binary := createTestBinary(successBinaryCmd, successBinaryLocation, true, t)
 
-	m, err := NewModule(&ModuleConfig{Logger: logger.Test(t)}, binary)
+	m, err := NewModule(&ModuleConfig{IsUncompressed: true, Logger: logger.Test(t)}, binary)
 	require.NoError(t, err)
 
 	_, err = m.Run(ctx, nil)
@@ -806,7 +839,7 @@ func TestModule_Sandbox_Memory(t *testing.T) {
 	ctx := tests.Context(t)
 	binary := createTestBinary(oomBinaryCmd, oomBinaryLocation, true, t)
 
-	m, err := NewModule(&ModuleConfig{Logger: logger.Test(t)}, binary)
+	m, err := NewModule(&ModuleConfig{IsUncompressed: true, Logger: logger.Test(t)}, binary)
 	require.NoError(t, err)
 
 	m.Start()
@@ -820,10 +853,11 @@ func TestModule_Sandbox_Memory(t *testing.T) {
 }
 
 func TestModule_Sandbox_SleepIsStubbedOut(t *testing.T) {
+	t.Parallel()
 	ctx := tests.Context(t)
 	binary := createTestBinary(sleepBinaryCmd, sleepBinaryLocation, true, t)
 
-	m, err := NewModule(&ModuleConfig{Logger: logger.Test(t)}, binary)
+	m, err := NewModule(&ModuleConfig{IsUncompressed: true, Logger: logger.Test(t)}, binary)
 	require.NoError(t, err)
 
 	m.Start()
@@ -849,7 +883,7 @@ func TestModule_Sandbox_Timeout(t *testing.T) {
 	binary := createTestBinary(sleepBinaryCmd, sleepBinaryLocation, true, t)
 
 	tmt := 10 * time.Millisecond
-	m, err := NewModule(&ModuleConfig{Logger: logger.Test(t), Timeout: &tmt}, binary)
+	m, err := NewModule(&ModuleConfig{IsUncompressed: true, Logger: logger.Test(t), Timeout: &tmt}, binary)
 	require.NoError(t, err)
 
 	m.Start()
@@ -865,10 +899,11 @@ func TestModule_Sandbox_Timeout(t *testing.T) {
 }
 
 func TestModule_Sandbox_CantReadFiles(t *testing.T) {
+	t.Parallel()
 	ctx := tests.Context(t)
 	binary := createTestBinary(filesBinaryCmd, filesBinaryLocation, true, t)
 
-	m, err := NewModule(&ModuleConfig{Logger: logger.Test(t)}, binary)
+	m, err := NewModule(&ModuleConfig{IsUncompressed: true, Logger: logger.Test(t)}, binary)
 	require.NoError(t, err)
 
 	m.Start()
@@ -892,10 +927,11 @@ func TestModule_Sandbox_CantReadFiles(t *testing.T) {
 }
 
 func TestModule_Sandbox_CantCreateDir(t *testing.T) {
+	t.Parallel()
 	ctx := tests.Context(t)
 	binary := createTestBinary(dirsBinaryCmd, dirsBinaryLocation, true, t)
 
-	m, err := NewModule(&ModuleConfig{Logger: logger.Test(t)}, binary)
+	m, err := NewModule(&ModuleConfig{IsUncompressed: true, Logger: logger.Test(t)}, binary)
 	require.NoError(t, err)
 
 	m.Start()
@@ -919,10 +955,11 @@ func TestModule_Sandbox_CantCreateDir(t *testing.T) {
 }
 
 func TestModule_Sandbox_HTTPRequest(t *testing.T) {
+	t.Parallel()
 	ctx := tests.Context(t)
 	binary := createTestBinary(httpBinaryCmd, httpBinaryLocation, true, t)
 
-	m, err := NewModule(&ModuleConfig{Logger: logger.Test(t)}, binary)
+	m, err := NewModule(&ModuleConfig{IsUncompressed: true, Logger: logger.Test(t)}, binary)
 	require.NoError(t, err)
 
 	m.Start()
@@ -946,10 +983,11 @@ func TestModule_Sandbox_HTTPRequest(t *testing.T) {
 }
 
 func TestModule_Sandbox_ReadEnv(t *testing.T) {
+	t.Parallel()
 	ctx := tests.Context(t)
 	binary := createTestBinary(envBinaryCmd, envBinaryLocation, true, t)
 
-	m, err := NewModule(&ModuleConfig{Logger: logger.Test(t)}, binary)
+	m, err := NewModule(&ModuleConfig{IsUncompressed: true, Logger: logger.Test(t)}, binary)
 	require.NoError(t, err)
 
 	m.Start()
@@ -977,6 +1015,7 @@ func TestModule_Sandbox_ReadEnv(t *testing.T) {
 }
 
 func TestModule_Sandbox_RandomGet(t *testing.T) {
+	t.Parallel()
 	req := &wasmpb.Request{
 		Id: uuid.New().String(),
 		Message: &wasmpb.Request_ComputeRequest{
@@ -996,7 +1035,8 @@ func TestModule_Sandbox_RandomGet(t *testing.T) {
 		binary := createTestBinary(randBinaryCmd, randBinaryLocation, true, t)
 
 		m, err := NewModule(&ModuleConfig{
-			Logger: logger.Test(t),
+			Logger:         logger.Test(t),
+			IsUncompressed: true,
 			Determinism: &DeterminismConfig{
 				Seed: 42,
 			},
@@ -1014,7 +1054,8 @@ func TestModule_Sandbox_RandomGet(t *testing.T) {
 		binary := createTestBinary(randBinaryCmd, randBinaryLocation, true, t)
 
 		m, err := NewModule(&ModuleConfig{
-			Logger: logger.Test(t),
+			Logger:         logger.Test(t),
+			IsUncompressed: true,
 		}, binary)
 		require.NoError(t, err)
 
diff --git a/pkg/workflows/wasm/pb/wasm.pb.go b/pkg/workflows/wasm/pb/wasm.pb.go
index 32b9a3aba..25b8ffb21 100644
--- a/pkg/workflows/wasm/pb/wasm.pb.go
+++ b/pkg/workflows/wasm/pb/wasm.pb.go
@@ -759,8 +759,9 @@ type EmitMessageRequest struct {
 	sizeCache     protoimpl.SizeCache
 	unknownFields protoimpl.UnknownFields
 
-	Message string   `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"`
-	Labels  *pb1.Map `protobuf:"bytes,2,opt,name=labels,proto3" json:"labels,omitempty"`
+	Message   string   `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"`
+	Labels    *pb1.Map `protobuf:"bytes,2,opt,name=labels,proto3" json:"labels,omitempty"`
+	RequestId string   `protobuf:"bytes,3,opt,name=requestId,proto3" json:"requestId,omitempty"`
 }
 
 func (x *EmitMessageRequest) Reset() {
@@ -809,6 +810,13 @@ func (x *EmitMessageRequest) GetLabels() *pb1.Map {
 	return nil
 }
 
+func (x *EmitMessageRequest) GetRequestId() string {
+	if x != nil {
+		return x.RequestId
+	}
+	return ""
+}
+
 type Error struct {
 	state         protoimpl.MessageState
 	sizeCache     protoimpl.SizeCache
@@ -1009,22 +1017,24 @@ var file_workflows_wasm_pb_wasm_proto_rawDesc = []byte{
 	0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0b, 0x2e, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73,
 	0x2e, 0x4d, 0x61, 0x70, 0x52, 0x07, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x12, 0x12, 0x0a,
 	0x04, 0x62, 0x6f, 0x64, 0x79, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x62, 0x6f, 0x64,
-	0x79, 0x22, 0x53, 0x0a, 0x12, 0x45, 0x6d, 0x69, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
+	0x79, 0x22, 0x71, 0x0a, 0x12, 0x45, 0x6d, 0x69, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
 	0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61,
 	0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67,
 	0x65, 0x12, 0x23, 0x0a, 0x06, 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28,
 	0x0b, 0x32, 0x0b, 0x2e, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x2e, 0x4d, 0x61, 0x70, 0x52, 0x06,
-	0x6c, 0x61, 0x62, 0x65, 0x6c, 0x73, 0x22, 0x21, 0x0a, 0x05, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x12,
-	0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09,
-	0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x37, 0x0a, 0x13, 0x45, 0x6d, 0x69,
-	0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
-	0x12, 0x20, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32,
-	0x0a, 0x2e, 0x73, 0x64, 0x6b, 0x2e, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x52, 0x05, 0x65, 0x72, 0x72,
-	0x6f, 0x72, 0x42, 0x43, 0x5a, 0x41, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d,
-	0x2f, 0x73, 0x6d, 0x61, 0x72, 0x74, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x61, 0x63, 0x74, 0x6b, 0x69,
-	0x74, 0x2f, 0x63, 0x68, 0x61, 0x69, 0x6e, 0x6c, 0x69, 0x6e, 0x6b, 0x2d, 0x63, 0x6f, 0x6d, 0x6d,
-	0x6f, 0x6e, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x77, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, 0x77, 0x73,
-	0x2f, 0x73, 0x64, 0x6b, 0x2f, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
+	0x6c, 0x61, 0x62, 0x65, 0x6c, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73,
+	0x74, 0x49, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x72, 0x65, 0x71, 0x75, 0x65,
+	0x73, 0x74, 0x49, 0x64, 0x22, 0x21, 0x0a, 0x05, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x12, 0x18, 0x0a,
+	0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07,
+	0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x37, 0x0a, 0x13, 0x45, 0x6d, 0x69, 0x74, 0x4d,
+	0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x20,
+	0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0a, 0x2e,
+	0x73, 0x64, 0x6b, 0x2e, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72,
+	0x42, 0x43, 0x5a, 0x41, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73,
+	0x6d, 0x61, 0x72, 0x74, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x61, 0x63, 0x74, 0x6b, 0x69, 0x74, 0x2f,
+	0x63, 0x68, 0x61, 0x69, 0x6e, 0x6c, 0x69, 0x6e, 0x6b, 0x2d, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e,
+	0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x77, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, 0x77, 0x73, 0x2f, 0x73,
+	0x64, 0x6b, 0x2f, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
 }
 
 var (
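wasm.pb.go is generated output, so the rawDesc bytes above simply re-encode the one-line schema change that follows in wasm.proto; hand-edits belong in the .proto, with the Go file regenerated. The new field is reached through the generated accessors, e.g. this small round-trip sketch:

	// requestIDRoundTrip shows the new field surviving a marshal/unmarshal
	// cycle; GetRequestId returns "" on a nil message, so callers need no
	// nil checks.
	func requestIDRoundTrip() (string, error) {
		in := &wasmpb.EmitMessageRequest{RequestId: "req-123", Message: "hello"}
		b, err := proto.Marshal(in)
		if err != nil {
			return "", err
		}
		out := &wasmpb.EmitMessageRequest{}
		if err := proto.Unmarshal(b, out); err != nil {
			return "", err
		}
		return out.GetRequestId(), nil
	}
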
diff --git a/pkg/workflows/wasm/pb/wasm.proto b/pkg/workflows/wasm/pb/wasm.proto
index 68af34c31..4f252d1f3 100644
--- a/pkg/workflows/wasm/pb/wasm.proto
+++ b/pkg/workflows/wasm/pb/wasm.proto
@@ -81,6 +81,7 @@ message FetchResponse {
 message EmitMessageRequest {
   string message = 1;
   values.Map labels = 2;
+  string requestId = 3;
 }
 
 message Error { string message = 1; }
diff --git a/pkg/workflows/wasm/runner_test.go b/pkg/workflows/wasm/runner_test.go
index 7f8bb8453..aaf4659e8 100644
--- a/pkg/workflows/wasm/runner_test.go
+++ b/pkg/workflows/wasm/runner_test.go
@@ -213,6 +213,7 @@ func TestRunner_Run_GetWorkflowSpec(t *testing.T) {
 func Test_createEmitFn(t *testing.T) {
 	var (
 		l         = logger.Test(t)
+		reqId     = "random-id"
 		sdkConfig = &RuntimeConfig{
 			MaxFetchResponseSizeBytes: 1_000,
 			Metadata: &capabilities.RequestMetadata{
@@ -221,6 +222,7 @@
 				WorkflowName:  "workflow_name",
 				WorkflowOwner: "workflow_owner_address",
 			},
+			RequestID: &reqId,
 		}
 		giveMsg    = "testing guest"
 		giveLabels = map[string]string{
diff --git a/pkg/workflows/wasm/sdk.go b/pkg/workflows/wasm/sdk.go
index 577fc9801..341c580b1 100644
--- a/pkg/workflows/wasm/sdk.go
+++ b/pkg/workflows/wasm/sdk.go
@@ -109,8 +109,9 @@ func createEmitFn(
 
 		// Marshal the message and labels into a protobuf message
 		b, err := proto.Marshal(&wasmpb.EmitMessageRequest{
-			Message: msg,
-			Labels:  values.ProtoMap(vm),
+			RequestId: *sdkConfig.RequestID,
+			Message:   msg,
+			Labels:    values.ProtoMap(vm),
 		})
 		if err != nil {
 			return err
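One caveat on the guest-side change above: RuntimeConfig.RequestID is a pointer, and the marshal dereferences it unconditionally, so any path that builds a RuntimeConfig without a request ID would panic here. A defensive variant, in case such a path exists (a sketch only; the surrounding code may already guarantee the field is set, as runner_test.go does by passing &reqId):

	// Fall back to an empty request ID rather than panicking on a nil pointer.
	reqID := ""
	if sdkConfig.RequestID != nil {
		reqID = *sdkConfig.RequestID
	}
	b, err := proto.Marshal(&wasmpb.EmitMessageRequest{
		RequestId: reqID,
		Message:   msg,
		Labels:    values.ProtoMap(vm),
	})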