From bb4f838ce0e7667c3ab97d918a06021dd9af34de Mon Sep 17 00:00:00 2001 From: "yu.deng" Date: Tue, 19 Nov 2024 11:40:49 +0800 Subject: [PATCH] feat: support CtxOption for SetCtx --- plugins/wasm-go/extensions/ai-proxy/main.go | 6 +- plugins/wasm-go/pkg/wrapper/log_wrapper.go | 45 +++-- plugins/wasm-go/pkg/wrapper/plugin_wrapper.go | 166 +++++++++++++----- 3 files changed, 155 insertions(+), 62 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 7527fb8e4b..01ff015b07 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -190,14 +190,14 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName) action, err := handler.OnResponseHeaders(ctx, apiName, log) if err == nil { - checkStream(&ctx, &log) + checkStream(&ctx, log) return action } _ = util.SendResponse(500, "ai-proxy.proc_resp_headers_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process response headers: %v", err)) return types.ActionContinue } - checkStream(&ctx, &log) + checkStream(&ctx, log) _, needHandleBody := activeProvider.(provider.ResponseBodyHandler) _, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler) if !needHandleBody && !needHandleStreamingBody { @@ -254,7 +254,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi return types.ActionContinue } -func checkStream(ctx *wrapper.HttpContext, log *wrapper.Log) { +func checkStream(ctx *wrapper.HttpContext, log wrapper.Log) { contentType, err := proxywasm.GetHttpResponseHeader("Content-Type") if err != nil || !strings.HasPrefix(contentType, "text/event-stream") { if err != nil { diff --git a/plugins/wasm-go/pkg/wrapper/log_wrapper.go b/plugins/wasm-go/pkg/wrapper/log_wrapper.go index 65c0aa346c..b8da27db39 100644 --- a/plugins/wasm-go/pkg/wrapper/log_wrapper.go +++ b/plugins/wasm-go/pkg/wrapper/log_wrapper.go @@ -31,11 +31,26 @@ const ( LogLevelCritical ) -type Log struct { +type Log interface { + Trace(msg string) + Tracef(format string, args ...interface{}) + Debug(msg string) + Debugf(format string, args ...interface{}) + Info(msg string) + Infof(format string, args ...interface{}) + Warn(msg string) + Warnf(format string, args ...interface{}) + Error(msg string) + Errorf(format string, args ...interface{}) + Critical(msg string) + Criticalf(format string, args ...interface{}) +} + +type DefaultLog struct { pluginName string } -func (l Log) log(level LogLevel, msg string) { +func (l *DefaultLog) log(level LogLevel, msg string) { requestIDRaw, _ := proxywasm.GetProperty([]string{"x_request_id"}) requestID := string(requestIDRaw) if requestID == "" { @@ -58,7 +73,7 @@ func (l Log) log(level LogLevel, msg string) { } } -func (l Log) logFormat(level LogLevel, format string, args ...interface{}) { +func (l *DefaultLog) logFormat(level LogLevel, format string, args ...interface{}) { requestIDRaw, _ := proxywasm.GetProperty([]string{"x_request_id"}) requestID := string(requestIDRaw) if requestID == "" { @@ -81,50 +96,50 @@ func (l Log) logFormat(level LogLevel, format string, args ...interface{}) { } } -func (l Log) Trace(msg string) { +func (l *DefaultLog) Trace(msg string) { l.log(LogLevelTrace, msg) } -func (l Log) Tracef(format string, args ...interface{}) { +func (l *DefaultLog) Tracef(format string, args ...interface{}) { l.logFormat(LogLevelTrace, format, args...) } -func (l Log) Debug(msg string) { +func (l *DefaultLog) Debug(msg string) { l.log(LogLevelDebug, msg) } -func (l Log) Debugf(format string, args ...interface{}) { +func (l *DefaultLog) Debugf(format string, args ...interface{}) { l.logFormat(LogLevelDebug, format, args...) } -func (l Log) Info(msg string) { +func (l *DefaultLog) Info(msg string) { l.log(LogLevelInfo, msg) } -func (l Log) Infof(format string, args ...interface{}) { +func (l *DefaultLog) Infof(format string, args ...interface{}) { l.logFormat(LogLevelInfo, format, args...) } -func (l Log) Warn(msg string) { +func (l *DefaultLog) Warn(msg string) { l.log(LogLevelWarn, msg) } -func (l Log) Warnf(format string, args ...interface{}) { +func (l *DefaultLog) Warnf(format string, args ...interface{}) { l.logFormat(LogLevelWarn, format, args...) } -func (l Log) Error(msg string) { +func (l *DefaultLog) Error(msg string) { l.log(LogLevelError, msg) } -func (l Log) Errorf(format string, args ...interface{}) { +func (l *DefaultLog) Errorf(format string, args ...interface{}) { l.logFormat(LogLevelError, format, args...) } -func (l Log) Critical(msg string) { +func (l *DefaultLog) Critical(msg string) { l.log(LogLevelCritical, msg) } -func (l Log) Criticalf(format string, args ...interface{}) { +func (l *DefaultLog) Criticalf(format string, args ...interface{}) { l.logFormat(LogLevelCritical, format, args...) } diff --git a/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go b/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go index cce0456f8c..3690f4fa44 100644 --- a/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go +++ b/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go @@ -98,79 +98,157 @@ func RegisteTickFunc(tickPeriod int64, tickFunc func()) { globalOnTickFuncs = append(globalOnTickFuncs, TickFuncEntry{0, tickPeriod, tickFunc}) } -func SetCtx[PluginConfig any](pluginName string, setFuncs ...SetPluginFunc[PluginConfig]) { - proxywasm.SetVMContext(NewCommonVmCtx(pluginName, setFuncs...)) +func SetCtx[PluginConfig any](pluginName string, options ...CtxOption[PluginConfig]) { + proxywasm.SetVMContext(NewCommonVmCtx(pluginName, options...)) } -type SetPluginFunc[PluginConfig any] func(*CommonVmCtx[PluginConfig]) +func SetCtxWithOptions[PluginConfig any](pluginName string, options ...CtxOption[PluginConfig]) { + proxywasm.SetVMContext(NewCommonVmCtxWithOptions(pluginName, options...)) +} -func ParseConfigBy[PluginConfig any](f ParseConfigFunc[PluginConfig]) SetPluginFunc[PluginConfig] { - return func(ctx *CommonVmCtx[PluginConfig]) { - ctx.parseConfig = f - } +type CtxOption[PluginConfig any] interface { + Apply(*CommonVmCtx[PluginConfig]) } -func ParseOverrideConfigBy[PluginConfig any](f ParseConfigFunc[PluginConfig], g ParseRuleConfigFunc[PluginConfig]) SetPluginFunc[PluginConfig] { - return func(ctx *CommonVmCtx[PluginConfig]) { - ctx.parseConfig = f - ctx.parseRuleConfig = g - } +type parseConfigOption[PluginConfig any] struct { + f ParseConfigFunc[PluginConfig] } -func ProcessRequestHeadersBy[PluginConfig any](f onHttpHeadersFunc[PluginConfig]) SetPluginFunc[PluginConfig] { - return func(ctx *CommonVmCtx[PluginConfig]) { - ctx.onHttpRequestHeaders = f - } +func (o parseConfigOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) { + ctx.parseConfig = o.f } -func ProcessRequestBodyBy[PluginConfig any](f onHttpBodyFunc[PluginConfig]) SetPluginFunc[PluginConfig] { - return func(ctx *CommonVmCtx[PluginConfig]) { - ctx.onHttpRequestBody = f - } +func ParseConfigBy[PluginConfig any](f ParseConfigFunc[PluginConfig]) CtxOption[PluginConfig] { + return parseConfigOption[PluginConfig]{f} } -func ProcessStreamingRequestBodyBy[PluginConfig any](f onHttpStreamingBodyFunc[PluginConfig]) SetPluginFunc[PluginConfig] { - return func(ctx *CommonVmCtx[PluginConfig]) { - ctx.onHttpStreamingRequestBody = f - } +type parseOverrideConfigOption[PluginConfig any] struct { + parseConfigF ParseConfigFunc[PluginConfig] + parseRuleConfigF ParseRuleConfigFunc[PluginConfig] } -func ProcessResponseHeadersBy[PluginConfig any](f onHttpHeadersFunc[PluginConfig]) SetPluginFunc[PluginConfig] { - return func(ctx *CommonVmCtx[PluginConfig]) { - ctx.onHttpResponseHeaders = f - } +func (o *parseOverrideConfigOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) { + ctx.parseConfig = o.parseConfigF + ctx.parseRuleConfig = o.parseRuleConfigF } -func ProcessResponseBodyBy[PluginConfig any](f onHttpBodyFunc[PluginConfig]) SetPluginFunc[PluginConfig] { - return func(ctx *CommonVmCtx[PluginConfig]) { - ctx.onHttpResponseBody = f - } +func ParseOverrideConfigBy[PluginConfig any](f ParseConfigFunc[PluginConfig], g ParseRuleConfigFunc[PluginConfig]) CtxOption[PluginConfig] { + return &parseOverrideConfigOption[PluginConfig]{f, g} } -func ProcessStreamingResponseBodyBy[PluginConfig any](f onHttpStreamingBodyFunc[PluginConfig]) SetPluginFunc[PluginConfig] { - return func(ctx *CommonVmCtx[PluginConfig]) { - ctx.onHttpStreamingResponseBody = f - } +type onProcessRequestHeadersOption[PluginConfig any] struct { + f onHttpHeadersFunc[PluginConfig] } -func ProcessStreamDoneBy[PluginConfig any](f onHttpStreamDoneFunc[PluginConfig]) SetPluginFunc[PluginConfig] { - return func(ctx *CommonVmCtx[PluginConfig]) { - ctx.onHttpStreamDone = f - } +func (o *onProcessRequestHeadersOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) { + ctx.onHttpRequestHeaders = o.f +} + +func ProcessRequestHeadersBy[PluginConfig any](f onHttpHeadersFunc[PluginConfig]) CtxOption[PluginConfig] { + return &onProcessRequestHeadersOption[PluginConfig]{f} +} + +type onProcessRequestBodyOption[PluginConfig any] struct { + f onHttpBodyFunc[PluginConfig] +} + +func (o *onProcessRequestBodyOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) { + ctx.onHttpRequestBody = o.f +} + +func ProcessRequestBodyBy[PluginConfig any](f onHttpBodyFunc[PluginConfig]) CtxOption[PluginConfig] { + return &onProcessRequestBodyOption[PluginConfig]{f} +} + +type onProcessStreamingRequestBodyOption[PluginConfig any] struct { + f onHttpStreamingBodyFunc[PluginConfig] +} + +func (o *onProcessStreamingRequestBodyOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) { + ctx.onHttpStreamingRequestBody = o.f +} + +func ProcessStreamingRequestBodyBy[PluginConfig any](f onHttpStreamingBodyFunc[PluginConfig]) CtxOption[PluginConfig] { + return &onProcessStreamingRequestBodyOption[PluginConfig]{f} +} + +type onProcessResponseHeadersOption[PluginConfig any] struct { + f onHttpHeadersFunc[PluginConfig] +} + +func (o *onProcessResponseHeadersOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) { + ctx.onHttpResponseHeaders = o.f +} + +func ProcessResponseHeadersBy[PluginConfig any](f onHttpHeadersFunc[PluginConfig]) CtxOption[PluginConfig] { + return &onProcessResponseHeadersOption[PluginConfig]{f} +} + +type onProcessResponseBodyOption[PluginConfig any] struct { + f onHttpBodyFunc[PluginConfig] +} + +func (o *onProcessResponseBodyOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) { + ctx.onHttpResponseBody = o.f +} + +func ProcessResponseBodyBy[PluginConfig any](f onHttpBodyFunc[PluginConfig]) CtxOption[PluginConfig] { + return &onProcessResponseBodyOption[PluginConfig]{f} +} + +type onProcessStreamingResponseBodyOption[PluginConfig any] struct { + f onHttpStreamingBodyFunc[PluginConfig] +} + +func (o *onProcessStreamingResponseBodyOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) { + ctx.onHttpStreamingResponseBody = o.f +} + +func ProcessStreamingResponseBodyBy[PluginConfig any](f onHttpStreamingBodyFunc[PluginConfig]) CtxOption[PluginConfig] { + return &onProcessStreamingResponseBodyOption[PluginConfig]{f} +} + +type onProcessStreamDoneOption[PluginConfig any] struct { + f onHttpStreamDoneFunc[PluginConfig] +} + +func (o *onProcessStreamDoneOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) { + ctx.onHttpStreamDone = o.f +} + +func ProcessStreamDoneBy[PluginConfig any](f onHttpStreamDoneFunc[PluginConfig]) CtxOption[PluginConfig] { + return &onProcessStreamDoneOption[PluginConfig]{f} +} + +type logOption[PluginConfig any] struct { + logger Log +} + +func (o *logOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) { + ctx.log = o.logger +} + +func WithLogger[PluginConfig any](logger Log) CtxOption[PluginConfig] { + return &logOption[PluginConfig]{logger} } func parseEmptyPluginConfig[PluginConfig any](gjson.Result, *PluginConfig, Log) error { return nil } -func NewCommonVmCtx[PluginConfig any](pluginName string, setFuncs ...SetPluginFunc[PluginConfig]) *CommonVmCtx[PluginConfig] { +func NewCommonVmCtx[PluginConfig any](pluginName string, options ...CtxOption[PluginConfig]) *CommonVmCtx[PluginConfig] { + logger := &DefaultLog{pluginName} + opts := append([]CtxOption[PluginConfig]{WithLogger[PluginConfig](logger)}, options...) + return NewCommonVmCtxWithOptions(pluginName, opts...) +} + +func NewCommonVmCtxWithOptions[PluginConfig any](pluginName string, options ...CtxOption[PluginConfig]) *CommonVmCtx[PluginConfig] { ctx := &CommonVmCtx[PluginConfig]{ pluginName: pluginName, - log: Log{pluginName}, hasCustomConfig: true, } - for _, set := range setFuncs { - set(ctx) + for _, opt := range options { + opt.Apply(ctx) } if ctx.parseConfig == nil { var config PluginConfig