From 385785dcee4fb290deafa43555acb028f8ec6c78 Mon Sep 17 00:00:00 2001 From: Liu Ziming Date: Wed, 31 Jul 2024 13:52:51 +0800 Subject: [PATCH] ci: add more tests (#21) add more tests! --- .../instrumenter/db/db_client_extractor.go | 20 +- .../db/db_client_extractor_test.go | 93 +++++ .../db/db_span_name_extractor_test.go | 18 + .../instrumenter/http/http_attrs_extractor.go | 26 +- .../http/http_attrs_extractor_test.go | 325 ++++++++++++++++++ .../http/http_span_name_extractor.go | 12 +- .../http/http_span_name_extractor_test.go | 68 ++++ .../http/http_status_code_converter_test.go | 23 ++ .../message/message_attrs_extractor.go | 40 +-- .../message/message_attrs_extractor_test.go | 156 +++++++++ .../message/message_span_name_extractor.go | 5 +- .../message_span_name_extractor_test.go | 98 ++++++ .../net/network_attrs_extractor.go | 36 +- .../net/network_attrs_extractor_test.go | 91 +++++ .../instrumenter/rpc/rpc_attrs_extractor.go | 13 +- .../rpc/rpc_attrs_extractor_test.go | 94 +++++ .../rpc/rpc_span_name_extractor_test.go | 45 +++ pkg/verifier/access.go | 4 - pkg/verifier/access_test.go | 84 +++++ pkg/verifier/runner.go | 2 +- pkg/verifier/runner_test.go | 97 ++++++ pkg/verifier/util.go | 3 +- pkg/verifier/util_test.go | 32 +- pkg/verifier/verifier.go | 1 + pkg/verifier/verifier_test.go | 40 +++ 25 files changed, 1323 insertions(+), 103 deletions(-) create mode 100644 pkg/inst-api-semconv/instrumenter/db/db_client_extractor_test.go create mode 100644 pkg/inst-api-semconv/instrumenter/db/db_span_name_extractor_test.go create mode 100644 pkg/inst-api-semconv/instrumenter/http/http_attrs_extractor_test.go create mode 100644 pkg/inst-api-semconv/instrumenter/http/http_span_name_extractor_test.go create mode 100644 pkg/inst-api-semconv/instrumenter/http/http_status_code_converter_test.go create mode 100644 pkg/inst-api-semconv/instrumenter/message/message_attrs_extractor_test.go create mode 100644 pkg/inst-api-semconv/instrumenter/message/message_span_name_extractor_test.go create mode 100644 pkg/inst-api-semconv/instrumenter/net/network_attrs_extractor_test.go create mode 100644 pkg/inst-api-semconv/instrumenter/rpc/rpc_attrs_extractor_test.go create mode 100644 pkg/inst-api-semconv/instrumenter/rpc/rpc_span_name_extractor_test.go create mode 100644 pkg/verifier/access_test.go create mode 100644 pkg/verifier/runner_test.go create mode 100644 pkg/verifier/verifier_test.go diff --git a/pkg/inst-api-semconv/instrumenter/db/db_client_extractor.go b/pkg/inst-api-semconv/instrumenter/db/db_client_extractor.go index 60d83675..5b8f17ce 100644 --- a/pkg/inst-api-semconv/instrumenter/db/db_client_extractor.go +++ b/pkg/inst-api-semconv/instrumenter/db/db_client_extractor.go @@ -4,15 +4,9 @@ import ( "context" "github.com/alibaba/opentelemetry-go-auto-instrumentation/pkg/inst-api/utils" "go.opentelemetry.io/otel/attribute" + semconv "go.opentelemetry.io/otel/semconv/v1.19.0" ) -const db_name = attribute.Key("db.name") -const db_system = attribute.Key("db.system") -const db_user = attribute.Key("db.user") -const db_connection_string = attribute.Key("db.connection_string") -const db_statement = attribute.Key("db.statement") -const db_operation = attribute.Key("db.operation") - type DbClientCommonAttrsExtractor[REQUEST any, RESPONSE any, GETTER DbClientCommonAttrsGetter[REQUEST]] struct { getter GETTER } @@ -23,16 +17,16 @@ func (d *DbClientCommonAttrsExtractor[REQUEST, RESPONSE, GETTER]) GetSpanKey() a func (d *DbClientCommonAttrsExtractor[REQUEST, RESPONSE, GETTER]) OnStart(attributes []attribute.KeyValue, parentContext context.Context, request REQUEST) []attribute.KeyValue { attributes = append(attributes, attribute.KeyValue{ - Key: db_name, + Key: semconv.DBNameKey, Value: attribute.StringValue(d.getter.GetName(request)), }, attribute.KeyValue{ - Key: db_system, + Key: semconv.DBSystemKey, Value: attribute.StringValue(d.getter.GetSystem(request)), }, attribute.KeyValue{ - Key: db_user, + Key: semconv.DBUserKey, Value: attribute.StringValue(d.getter.GetUser(request)), }, attribute.KeyValue{ - Key: db_connection_string, + Key: semconv.DBConnectionStringKey, Value: attribute.StringValue(d.getter.GetConnectionString(request)), }) return attributes @@ -49,10 +43,10 @@ type DbClientAttrsExtractor[REQUEST any, RESPONSE any, GETTER DbClientAttrsGette func (d *DbClientAttrsExtractor[REQUEST, RESPONSE, GETTER]) OnStart(attrs []attribute.KeyValue, parentContext context.Context, request REQUEST) []attribute.KeyValue { attrs = d.base.OnStart(attrs, parentContext, request) attrs = append(attrs, attribute.KeyValue{ - Key: db_statement, + Key: semconv.DBStatementKey, Value: attribute.StringValue(d.base.getter.GetStatement(request)), }, attribute.KeyValue{ - Key: db_operation, + Key: semconv.DBOperationKey, Value: attribute.StringValue(d.base.getter.GetOperation(request)), }) return attrs diff --git a/pkg/inst-api-semconv/instrumenter/db/db_client_extractor_test.go b/pkg/inst-api-semconv/instrumenter/db/db_client_extractor_test.go new file mode 100644 index 00000000..96ad04fb --- /dev/null +++ b/pkg/inst-api-semconv/instrumenter/db/db_client_extractor_test.go @@ -0,0 +1,93 @@ +package db + +import ( + "context" + "github.com/alibaba/opentelemetry-go-auto-instrumentation/pkg/inst-api/utils" + "go.opentelemetry.io/otel/attribute" + semconv "go.opentelemetry.io/otel/semconv/v1.19.0" + "log" + "testing" +) + +type testRequest struct { + Name string + Operation string +} + +type testResponse struct { +} + +type mongoAttrsGetter struct { +} + +func (m mongoAttrsGetter) GetSystem(request testRequest) string { + return "test" +} + +func (m mongoAttrsGetter) GetUser(request testRequest) string { + return "test" +} + +func (m mongoAttrsGetter) GetName(request testRequest) string { + if request.Name != "" { + return request.Name + } + return "" +} + +func (m mongoAttrsGetter) GetConnectionString(request testRequest) string { + return "test" +} + +func (m mongoAttrsGetter) GetStatement(request testRequest) string { + return "test" +} + +func (m mongoAttrsGetter) GetOperation(request testRequest) string { + if request.Operation != "" { + return request.Operation + } + return "" +} + +func TestGetSpanKey(t *testing.T) { + dbExtractor := &DbClientAttrsExtractor[testRequest, any, mongoAttrsGetter]{} + if dbExtractor.GetSpanKey() != utils.DB_CLIENT_KEY { + t.Fatalf("Should have returned DB_CLIENT_KEY") + } +} + +func TestDbClientExtractorStart(t *testing.T) { + dbExtractor := DbClientAttrsExtractor[testRequest, testResponse, mongoAttrsGetter]{} + attrs := make([]attribute.KeyValue, 0) + parentContext := context.Background() + attrs = dbExtractor.OnStart(attrs, parentContext, testRequest{Name: "test"}) + if attrs[0].Key != semconv.DBNameKey || attrs[0].Value.AsString() != "test" { + t.Fatalf("db name should be test") + } + if attrs[1].Key != semconv.DBSystemKey || attrs[1].Value.AsString() != "test" { + t.Fatalf("db system should be test") + } + if attrs[2].Key != semconv.DBUserKey || attrs[2].Value.AsString() != "test" { + t.Fatalf("db user should be test") + } + if attrs[3].Key != semconv.DBConnectionStringKey || attrs[3].Value.AsString() != "test" { + t.Fatalf("db connection key should be test") + } + if attrs[4].Key != semconv.DBStatementKey || attrs[4].Value.AsString() != "test" { + t.Fatalf("db statement key should be test") + } + if attrs[5].Key != semconv.DBOperationKey || attrs[5].Value.AsString() != "" { + t.Fatalf("db operation key should be empty") + } +} + +func TestDbClientExtractorEnd(t *testing.T) { + dbExtractor := DbClientAttrsExtractor[testRequest, testResponse, mongoAttrsGetter]{} + attrs := make([]attribute.KeyValue, 0) + parentContext := context.Background() + attrs = dbExtractor.OnEnd(attrs, parentContext, testRequest{Name: "test"}, testResponse{}, nil) + if len(attrs) != 0 { + log.Fatal("attrs should be empty") + } +} diff --git a/pkg/inst-api-semconv/instrumenter/db/db_span_name_extractor_test.go b/pkg/inst-api-semconv/instrumenter/db/db_span_name_extractor_test.go new file mode 100644 index 00000000..d5199951 --- /dev/null +++ b/pkg/inst-api-semconv/instrumenter/db/db_span_name_extractor_test.go @@ -0,0 +1,18 @@ +package db + +import "testing" + +func TestDbNameExtractor(t *testing.T) { + dbSpanNameExtractor := DBSpanNameExtractor[testRequest]{ + getter: mongoAttrsGetter{}, + } + if dbSpanNameExtractor.Extract(testRequest{}) != "DB Query" { + t.Fatalf("Should have returned DB_QUERY") + } + if dbSpanNameExtractor.Extract(testRequest{Name: "test"}) != "test" { + t.Fatalf("Should have returned test") + } + if dbSpanNameExtractor.Extract(testRequest{Operation: "op_test"}) != "op_test" { + t.Fatalf("Should have returned op_test") + } +} diff --git a/pkg/inst-api-semconv/instrumenter/http/http_attrs_extractor.go b/pkg/inst-api-semconv/instrumenter/http/http_attrs_extractor.go index 03e4f098..42acc25b 100644 --- a/pkg/inst-api-semconv/instrumenter/http/http_attrs_extractor.go +++ b/pkg/inst-api-semconv/instrumenter/http/http_attrs_extractor.go @@ -4,19 +4,9 @@ import ( "context" "github.com/alibaba/opentelemetry-go-auto-instrumentation/pkg/inst-api-semconv/instrumenter/net" "go.opentelemetry.io/otel/attribute" + semconv "go.opentelemetry.io/otel/semconv/v1.26.0" ) -const http_request_method = attribute.Key("http.request.method") -const http_response_status_code = attribute.Key("http.response.status_code") -const http_route = attribute.Key("http.route") - -const network_protocol_name = attribute.Key("network.protocol.name") -const network_protocol_version = attribute.Key("network.protocol.version") - -const url_full = attribute.Key("url.full") - -const user_agent_original = attribute.Key("user_agent.original") - // TODO: http.route type HttpCommonAttrsExtractor[REQUEST any, RESPONSE any, GETTER1 HttpCommonAttrsGetter[REQUEST, RESPONSE], GETTER2 net.NetworkAttrsGetter[REQUEST, RESPONSE]] struct { @@ -27,7 +17,7 @@ type HttpCommonAttrsExtractor[REQUEST any, RESPONSE any, GETTER1 HttpCommonAttrs func (h *HttpCommonAttrsExtractor[REQUEST, RESPONSE, GETTER1, GETTER2]) OnStart(attributes []attribute.KeyValue, parentContext context.Context, request REQUEST) []attribute.KeyValue { attributes = append(attributes, attribute.KeyValue{ - Key: http_request_method, + Key: semconv.HTTPRequestMethodKey, Value: attribute.StringValue(h.httpGetter.GetRequestMethod(request)), }) return attributes @@ -38,13 +28,13 @@ func (h *HttpCommonAttrsExtractor[REQUEST, RESPONSE, GETTER, GETTER2]) OnEnd(att protocolName := h.netGetter.GetNetworkProtocolName(request, response) protocolVersion := h.netGetter.GetNetworkProtocolVersion(request, response) attributes = append(attributes, attribute.KeyValue{ - Key: http_response_status_code, + Key: semconv.HTTPResponseStatusCodeKey, Value: attribute.IntValue(statusCode), }, attribute.KeyValue{ - Key: network_protocol_name, + Key: semconv.NetworkProtocolNameKey, Value: attribute.StringValue(protocolName), }, attribute.KeyValue{ - Key: network_protocol_version, + Key: semconv.NetworkProtocolVersionKey, Value: attribute.StringValue(protocolVersion), }) return attributes @@ -60,7 +50,7 @@ func (h *HttpClientAttrsExtractor[REQUEST, RESPONSE, GETTER1, GETTER2]) OnStart( fullUrl := h.base.httpGetter.GetUrlFull(request) // TODO: add resend count attributes = append(attributes, attribute.KeyValue{ - Key: url_full, + Key: semconv.URLFullKey, Value: attribute.StringValue(fullUrl), }) return attributes @@ -88,10 +78,10 @@ func (h *HttpServerAttrsExtractor[REQUEST, RESPONSE, GETTER1, GETTER2, GETTER3]) } else { firstUserAgent = "" } - attributes = append(attributes, attribute.KeyValue{Key: http_route, + attributes = append(attributes, attribute.KeyValue{Key: semconv.HTTPRouteKey, Value: attribute.StringValue(h.base.httpGetter.GetHttpRoute(request)), }, attribute.KeyValue{ - Key: user_agent_original, + Key: semconv.UserAgentOriginalKey, Value: attribute.StringValue(firstUserAgent), }) return attributes diff --git a/pkg/inst-api-semconv/instrumenter/http/http_attrs_extractor_test.go b/pkg/inst-api-semconv/instrumenter/http/http_attrs_extractor_test.go new file mode 100644 index 00000000..7c54866e --- /dev/null +++ b/pkg/inst-api-semconv/instrumenter/http/http_attrs_extractor_test.go @@ -0,0 +1,325 @@ +package http + +import ( + "context" + "github.com/alibaba/opentelemetry-go-auto-instrumentation/pkg/inst-api-semconv/instrumenter/net" + "go.opentelemetry.io/otel/attribute" + semconv "go.opentelemetry.io/otel/semconv/v1.26.0" + "testing" +) + +type httpServerAttrsGetter struct { +} + +type httpClientAttrsGetter struct { +} + +type networkAttrsGetter struct { +} + +type urlAttrsGetter struct { +} + +func (u urlAttrsGetter) GetUrlScheme(request testRequest) string { + return "url-scheme" +} + +func (u urlAttrsGetter) GetUrlPath(request testRequest) string { + return "url-path" +} + +func (u urlAttrsGetter) GetUrlQuery(request testRequest) string { + return "url-query" +} + +func (n networkAttrsGetter) GetNetworkType(request testRequest, response testResponse) string { + return "network-type" +} + +func (n networkAttrsGetter) GetNetworkTransport(request testRequest, response testResponse) string { + return "network-transport" +} + +func (n networkAttrsGetter) GetNetworkProtocolName(request testRequest, response testResponse) string { + return "network-protocol-name" +} + +func (n networkAttrsGetter) GetNetworkProtocolVersion(request testRequest, response testResponse) string { + return "network-protocol-version" +} + +func (n networkAttrsGetter) GetNetworkLocalInetAddress(request testRequest, response testResponse) string { + return "network-local-inet-address" +} + +func (n networkAttrsGetter) GetNetworkLocalPort(request testRequest, response testResponse) int { + return 8080 +} + +func (n networkAttrsGetter) GetNetworkPeerInetAddress(request testRequest, response testResponse) string { + return "network-peer-inet-address" +} + +func (n networkAttrsGetter) GetNetworkPeerPort(request testRequest, response testResponse) int { + return 8080 +} + +func (h httpClientAttrsGetter) GetRequestMethod(request testRequest) string { + return "GET" +} + +func (h httpClientAttrsGetter) GetHttpRequestHeader(request testRequest, name string) []string { + return []string{"request-header"} +} + +func (h httpClientAttrsGetter) GetHttpResponseStatusCode(request testRequest, response testResponse, err error) int { + return 200 +} + +func (h httpClientAttrsGetter) GetHttpResponseHeader(request testRequest, response testResponse, name string) []string { + return []string{"response-header"} +} + +func (h httpClientAttrsGetter) GetErrorType(request testRequest, response testResponse, err error) string { + return "" +} + +func (h httpClientAttrsGetter) GetNetworkType(request testRequest, response testResponse) string { + return "ipv4" +} + +func (h httpClientAttrsGetter) GetNetworkTransport(request testRequest, response testResponse) string { + return "TCP" +} + +func (h httpClientAttrsGetter) GetNetworkProtocolName(request testRequest, response testResponse) string { + return "HTTP" +} + +func (h httpClientAttrsGetter) GetNetworkProtocolVersion(request testRequest, response testResponse) string { + return "HTTP/1.1" +} + +func (h httpClientAttrsGetter) GetNetworkLocalInetAddress(request testRequest, response testResponse) string { + return "127.0.0.1" +} + +func (h httpClientAttrsGetter) GetNetworkLocalPort(request testRequest, response testResponse) int { + return 8080 +} + +func (h httpClientAttrsGetter) GetNetworkPeerInetAddress(request testRequest, response testResponse) string { + return "127.0.0.1" +} + +func (h httpClientAttrsGetter) GetNetworkPeerPort(request testRequest, response testResponse) int { + return 8080 +} + +func (h httpClientAttrsGetter) GetUrlFull(request testRequest) string { + return "url-full" +} + +func (h httpClientAttrsGetter) GetServerAddress(request testRequest) string { + return "server-address" +} + +func (h httpServerAttrsGetter) GetRequestMethod(request testRequest) string { + return "GET" +} + +func (h httpServerAttrsGetter) GetHttpRequestHeader(request testRequest, name string) []string { + return []string{"request-header"} +} + +func (h httpServerAttrsGetter) GetHttpResponseStatusCode(request testRequest, response testResponse, err error) int { + return 200 +} + +func (h httpServerAttrsGetter) GetHttpResponseHeader(request testRequest, response testResponse, name string) []string { + return []string{"response-header"} +} + +func (h httpServerAttrsGetter) GetErrorType(request testRequest, response testResponse, err error) string { + return "error-type" +} + +func (h httpServerAttrsGetter) GetUrlScheme(request testRequest) string { + return "url-scheme" +} + +func (h httpServerAttrsGetter) GetUrlPath(request testRequest) string { + return "url-path" +} + +func (h httpServerAttrsGetter) GetUrlQuery(request testRequest) string { + return "url-query" +} + +func (h httpServerAttrsGetter) GetNetworkType(request testRequest, response testResponse) string { + return "network-type" +} + +func (h httpServerAttrsGetter) GetNetworkTransport(request testRequest, response testResponse) string { + return "network-transport" +} + +func (h httpServerAttrsGetter) GetNetworkProtocolName(request testRequest, response testResponse) string { + return "network-protocol-name" +} + +func (h httpServerAttrsGetter) GetNetworkProtocolVersion(request testRequest, response testResponse) string { + return "network-protocol-version" +} + +func (h httpServerAttrsGetter) GetNetworkLocalInetAddress(request testRequest, response testResponse) string { + return "127.0.0.1" +} + +func (h httpServerAttrsGetter) GetNetworkLocalPort(request testRequest, response testResponse) int { + return 8080 +} + +func (h httpServerAttrsGetter) GetNetworkPeerInetAddress(request testRequest, response testResponse) string { + return "127.0.0.1" +} + +func (h httpServerAttrsGetter) GetNetworkPeerPort(request testRequest, response testResponse) int { + return 8080 +} + +func (h httpServerAttrsGetter) GetHttpRoute(request testRequest) string { + return "http-route" +} + +func TestHttpClientExtractorStart(t *testing.T) { + httpClientExtractor := HttpClientAttrsExtractor[testRequest, testResponse, httpClientAttrsGetter, networkAttrsGetter]{ + base: HttpCommonAttrsExtractor[testRequest, testResponse, httpClientAttrsGetter, networkAttrsGetter]{}, + networkExtractor: net.NetworkAttrsExtractor[testRequest, testResponse, networkAttrsGetter]{}, + } + attrs := make([]attribute.KeyValue, 0) + parentContext := context.Background() + attrs = httpClientExtractor.OnStart(attrs, parentContext, testRequest{}) + if attrs[0].Key != semconv.HTTPRequestMethodKey || attrs[0].Value.AsString() != "GET" { + t.Fatalf("http method should be GET") + } + if attrs[1].Key != semconv.URLFullKey || attrs[1].Value.AsString() != "url-full" { + t.Fatalf("urlfull should be url-full") + } +} + +func TestHttpClientExtractorEnd(t *testing.T) { + httpClientExtractor := HttpClientAttrsExtractor[testRequest, testResponse, httpClientAttrsGetter, networkAttrsGetter]{ + base: HttpCommonAttrsExtractor[testRequest, testResponse, httpClientAttrsGetter, networkAttrsGetter]{}, + networkExtractor: net.NetworkAttrsExtractor[testRequest, testResponse, networkAttrsGetter]{}, + } + attrs := make([]attribute.KeyValue, 0) + parentContext := context.Background() + attrs = httpClientExtractor.OnEnd(attrs, parentContext, testRequest{}, testResponse{}, nil) + if attrs[0].Key != semconv.HTTPResponseStatusCodeKey || attrs[0].Value.AsInt64() != 200 { + t.Fatalf("status code should be 200") + } + if attrs[1].Key != semconv.NetworkProtocolNameKey || attrs[1].Value.AsString() != "network-protocol-name" { + t.Fatalf("wrong network protocol name") + } + if attrs[2].Key != semconv.NetworkProtocolVersionKey || attrs[2].Value.AsString() != "network-protocol-version" { + t.Fatalf("wrong network protocol version") + } + if attrs[3].Key != semconv.NetworkTransportKey || attrs[3].Value.AsString() != "network-transport" { + t.Fatalf("wrong network transport") + } + if attrs[4].Key != semconv.NetworkTypeKey || attrs[4].Value.AsString() != "network-type" { + t.Fatalf("wrong network type") + } + if attrs[5].Key != semconv.NetworkProtocolNameKey || attrs[5].Value.AsString() != "network-protocol-name" { + t.Fatalf("wrong network protocol name") + } + if attrs[6].Key != semconv.NetworkProtocolVersionKey || attrs[6].Value.AsString() != "network-protocol-version" { + t.Fatalf("wrong network protocol version") + } + if attrs[7].Key != semconv.NetworkLocalAddressKey || attrs[7].Value.AsString() != "network-local-inet-address" { + t.Fatalf("wrong network protocol inet address") + } + if attrs[8].Key != semconv.NetworkPeerAddressKey || attrs[8].Value.AsString() != "network-peer-inet-address" { + t.Fatalf("wrong network peer address") + } + if attrs[9].Key != semconv.NetworkLocalPortKey || attrs[9].Value.AsInt64() != 8080 { + t.Fatalf("wrong network local port") + } + if attrs[10].Key != semconv.NetworkPeerPortKey || attrs[10].Value.AsInt64() != 8080 { + t.Fatalf("wrong network peer port") + } +} + +func TestHttpServerExtractorStart(t *testing.T) { + httpServerExtractor := HttpServerAttrsExtractor[testRequest, testResponse, httpServerAttrsGetter, networkAttrsGetter, urlAttrsGetter]{ + base: HttpCommonAttrsExtractor[testRequest, testResponse, httpServerAttrsGetter, networkAttrsGetter]{}, + networkExtractor: net.NetworkAttrsExtractor[testRequest, testResponse, networkAttrsGetter]{}, + urlExtractor: net.UrlAttrsExtractor[testRequest, testResponse, urlAttrsGetter]{}, + } + attrs := make([]attribute.KeyValue, 0) + parentContext := context.Background() + attrs = httpServerExtractor.OnStart(attrs, parentContext, testRequest{}) + if attrs[0].Key != semconv.HTTPRequestMethodKey || attrs[0].Value.AsString() != "GET" { + t.Fatalf("http method should be GET") + } + if attrs[1].Key != semconv.URLSchemeKey || attrs[1].Value.AsString() != "url-scheme" { + t.Fatalf("urlscheme should be url-scheme") + } + if attrs[2].Key != semconv.URLPathKey || attrs[2].Value.AsString() != "url-path" { + t.Fatalf("urlpath should be url-path") + } + if attrs[3].Key != semconv.URLQueryKey || attrs[3].Value.AsString() != "url-query" { + t.Fatalf("urlquery should be url-query") + } + if attrs[4].Key != semconv.HTTPRouteKey || attrs[4].Value.AsString() != "http-route" { + t.Fatalf("httproute should be http-route") + } + if attrs[5].Key != semconv.UserAgentOriginalKey || attrs[5].Value.AsString() != "request-header" { + t.Fatalf("user agent original should be request-header") + } +} + +func TestHttpServerExtractorEnd(t *testing.T) { + httpServerExtractor := HttpServerAttrsExtractor[testRequest, testResponse, httpServerAttrsGetter, networkAttrsGetter, urlAttrsGetter]{ + base: HttpCommonAttrsExtractor[testRequest, testResponse, httpServerAttrsGetter, networkAttrsGetter]{}, + networkExtractor: net.NetworkAttrsExtractor[testRequest, testResponse, networkAttrsGetter]{}, + urlExtractor: net.UrlAttrsExtractor[testRequest, testResponse, urlAttrsGetter]{}, + } + attrs := make([]attribute.KeyValue, 0) + parentContext := context.Background() + attrs = httpServerExtractor.OnEnd(attrs, parentContext, testRequest{}, testResponse{}, nil) + if attrs[0].Key != semconv.HTTPResponseStatusCodeKey || attrs[0].Value.AsInt64() != 200 { + t.Fatalf("status code should be 200") + } + if attrs[1].Key != semconv.NetworkProtocolNameKey || attrs[1].Value.AsString() != "network-protocol-name" { + t.Fatalf("wrong network protocol name") + } + if attrs[2].Key != semconv.NetworkProtocolVersionKey || attrs[2].Value.AsString() != "network-protocol-version" { + t.Fatalf("wrong network protocol version") + } + if attrs[3].Key != semconv.NetworkTransportKey || attrs[3].Value.AsString() != "network-transport" { + t.Fatalf("wrong network transport") + } + if attrs[4].Key != semconv.NetworkTypeKey || attrs[4].Value.AsString() != "network-type" { + t.Fatalf("wrong network type") + } + if attrs[5].Key != semconv.NetworkProtocolNameKey || attrs[5].Value.AsString() != "network-protocol-name" { + t.Fatalf("wrong network protocol name") + } + if attrs[6].Key != semconv.NetworkProtocolVersionKey || attrs[6].Value.AsString() != "network-protocol-version" { + t.Fatalf("wrong network protocol version") + } + if attrs[7].Key != semconv.NetworkLocalAddressKey || attrs[7].Value.AsString() != "network-local-inet-address" { + t.Fatalf("wrong network protocol inet address") + } + if attrs[8].Key != semconv.NetworkPeerAddressKey || attrs[8].Value.AsString() != "network-peer-inet-address" { + t.Fatalf("wrong network peer address") + } + if attrs[9].Key != semconv.NetworkLocalPortKey || attrs[9].Value.AsInt64() != 8080 { + t.Fatalf("wrong network local port") + } + if attrs[10].Key != semconv.NetworkPeerPortKey || attrs[10].Value.AsInt64() != 8080 { + t.Fatalf("wrong network peer port") + } +} diff --git a/pkg/inst-api-semconv/instrumenter/http/http_span_name_extractor.go b/pkg/inst-api-semconv/instrumenter/http/http_span_name_extractor.go index 84d022db..7acaea1f 100644 --- a/pkg/inst-api-semconv/instrumenter/http/http_span_name_extractor.go +++ b/pkg/inst-api-semconv/instrumenter/http/http_span_name_extractor.go @@ -1,10 +1,10 @@ package http -type HttpClientSpanNameExtractor[REQUEST any] struct { - getter HttpClientAttrsGetter[REQUEST, any] +type HttpClientSpanNameExtractor[REQUEST any, RESPONSE any] struct { + getter HttpClientAttrsGetter[REQUEST, RESPONSE] } -func (h *HttpClientSpanNameExtractor[REQUEST]) Extract(request REQUEST) string { +func (h *HttpClientSpanNameExtractor[REQUEST, RESPONSE]) Extract(request REQUEST) string { method := h.getter.GetRequestMethod(request) if method == "" { return "HTTP" @@ -12,11 +12,11 @@ func (h *HttpClientSpanNameExtractor[REQUEST]) Extract(request REQUEST) string { return method } -type HttpServerSpanNameExtractor[REQUEST any] struct { - getter HttpServerAttrsGetter[REQUEST, any] +type HttpServerSpanNameExtractor[REQUEST any, RESPONSE any] struct { + getter HttpServerAttrsGetter[REQUEST, RESPONSE] } -func (h *HttpServerSpanNameExtractor[REQUEST]) Extract(request REQUEST) string { +func (h *HttpServerSpanNameExtractor[REQUEST, RESPONSE]) Extract(request REQUEST) string { method := h.getter.GetRequestMethod(request) route := h.getter.GetHttpRoute(request) if method == "" { diff --git a/pkg/inst-api-semconv/instrumenter/http/http_span_name_extractor_test.go b/pkg/inst-api-semconv/instrumenter/http/http_span_name_extractor_test.go new file mode 100644 index 00000000..ada7975b --- /dev/null +++ b/pkg/inst-api-semconv/instrumenter/http/http_span_name_extractor_test.go @@ -0,0 +1,68 @@ +package http + +import "testing" + +type testRequest struct { + Method string + Route string +} + +type testResponse struct { +} + +type testClientGetter struct { + HttpClientAttrsGetter[testRequest, testResponse] +} + +type testServerGetter struct { + HttpServerAttrsGetter[testRequest, testResponse] +} + +func (t testClientGetter) GetRequestMethod(request testRequest) string { + if request.Method != "" { + return request.Method + } + return "" +} + +func (t testServerGetter) GetRequestMethod(request testRequest) string { + if request.Method != "" { + return request.Method + } + return "" +} + +func (t testServerGetter) GetHttpRoute(request testRequest) string { + if request.Route != "" { + return request.Route + } + return "" +} + +func TestHttpClientExtractSpanName(t *testing.T) { + r := HttpClientSpanNameExtractor[testRequest, testResponse]{getter: testClientGetter{}} + spanName := r.Extract(testRequest{Method: "GET"}) + if spanName != "GET" { + t.Errorf("want GET, got %s", spanName) + } + spanName = r.Extract(testRequest{}) + if spanName != "HTTP" { + t.Errorf("want HTTP, got %s", spanName) + } +} + +func TestHttpServerExtractSpanName(t *testing.T) { + r := HttpServerSpanNameExtractor[testRequest, testResponse]{getter: testServerGetter{}} + spanName := r.Extract(testRequest{Method: "GET"}) + if spanName != "GET" { + t.Errorf("want GET, got %s", spanName) + } + spanName = r.Extract(testRequest{}) + if spanName != "HTTP" { + t.Errorf("want HTTP, got %s", spanName) + } + spanName = r.Extract(testRequest{Method: "GET", Route: "/a/b"}) + if spanName != "GET /a/b" { + t.Errorf("want GET /a/b, got %s", spanName) + } +} diff --git a/pkg/inst-api-semconv/instrumenter/http/http_status_code_converter_test.go b/pkg/inst-api-semconv/instrumenter/http/http_status_code_converter_test.go new file mode 100644 index 00000000..a81a51b2 --- /dev/null +++ b/pkg/inst-api-semconv/instrumenter/http/http_status_code_converter_test.go @@ -0,0 +1,23 @@ +package http + +import "testing" + +func TestClientHttpStatusCodeConverter(t *testing.T) { + c := ClientHttpStatusCodeConverter{} + if c.IsError(200) { + t.Fatalf("200 should not be an error") + } + if !c.IsError(600) || !c.IsError(90) { + t.Fatalf("600 and 90 should be an error") + } +} + +func TestServerHttpStatusCodeConverter(t *testing.T) { + c := ServerHttpStatusCodeConverter{} + if c.IsError(200) { + t.Fatalf("200 should not be an error") + } + if !c.IsError(500) || !c.IsError(90) { + t.Fatalf("600 and 90 should be an error") + } +} diff --git a/pkg/inst-api-semconv/instrumenter/message/message_attrs_extractor.go b/pkg/inst-api-semconv/instrumenter/message/message_attrs_extractor.go index 68115648..4b08f60c 100644 --- a/pkg/inst-api-semconv/instrumenter/message/message_attrs_extractor.go +++ b/pkg/inst-api-semconv/instrumenter/message/message_attrs_extractor.go @@ -4,6 +4,7 @@ import ( "context" "github.com/alibaba/opentelemetry-go-auto-instrumentation/pkg/inst-api/utils" "go.opentelemetry.io/otel/attribute" + semconv "go.opentelemetry.io/otel/semconv/v1.26.0" ) type MessageOperation string @@ -12,19 +13,6 @@ const PUBLISH MessageOperation = "publish" const RECEIVE MessageOperation = "receive" const PROCESS MessageOperation = "process" -const messaging_batch_message_count = attribute.Key("messaging.batch.message_count") -const messaging_client_id = attribute.Key("messaging.client_id") -const messaging_destination_anoymous = attribute.Key("messaging.destination.anonymous") -const messaging_destination_name = attribute.Key("messaging.destination.name") -const messaging_destination_template = attribute.Key("messaging.destination.template") -const messaging_destination_temporary = attribute.Key("messaging.destination.temporary") -const messaging_message_body_size = attribute.Key("messaging.message.body.size") -const messaging_message_conversation_id = attribute.Key("messaging.message.conversation_id") -const messaging_message_envelope_size = attribute.Key("messaging.message.envelope.size") -const messaging_message_id = attribute.Key("messaging.message.id") -const messaging_operation = attribute.Key("messaging.operation") -const messaging_system = attribute.Key("messaging.system") - type MessageAttrsExtractor[REQUEST any, RESPONSE any, GETTER MessageAttrsGetter[REQUEST, RESPONSE]] struct { getter GETTER operation MessageOperation @@ -47,45 +35,45 @@ func (m *MessageAttrsExtractor[REQUEST, RESPONSE, GETTER]) OnStart(attributes [] isTemporaryDestination := m.getter.IsTemporaryDestination(request) if isTemporaryDestination { attributes = append(attributes, attribute.KeyValue{ - Key: messaging_destination_temporary, + Key: semconv.MessagingDestinationTemporaryKey, Value: attribute.BoolValue(true), }, attribute.KeyValue{ - Key: messaging_destination_name, + Key: semconv.MessagingDestinationNameKey, Value: attribute.StringValue("(temporary)"), }) } else { attributes = append(attributes, attribute.KeyValue{ - Key: messaging_destination_name, + Key: semconv.MessagingDestinationNameKey, Value: attribute.StringValue(m.getter.GetDestination(request)), }, attribute.KeyValue{ - Key: messaging_destination_template, + Key: semconv.MessagingDestinationTemplateKey, Value: attribute.StringValue(m.getter.GetDestinationTemplate(request)), }) } isAnonymousDestination := m.getter.isAnonymousDestination(request) if isAnonymousDestination { attributes = append(attributes, attribute.KeyValue{ - Key: messaging_destination_anoymous, + Key: semconv.MessagingDestinationAnonymousKey, Value: attribute.BoolValue(true), }) } attributes = append(attributes, attribute.KeyValue{ - Key: messaging_message_conversation_id, + Key: semconv.MessagingMessageConversationIDKey, Value: attribute.StringValue(m.getter.GetConversationId(request)), }, attribute.KeyValue{ - Key: messaging_message_body_size, + Key: semconv.MessagingMessageBodySizeKey, Value: attribute.Int64Value(m.getter.GetMessageBodySize(request)), }, attribute.KeyValue{ - Key: messaging_message_envelope_size, + Key: semconv.MessagingMessageEnvelopeSizeKey, Value: attribute.Int64Value(m.getter.GetMessageEnvelopSize(request)), }, attribute.KeyValue{ - Key: messaging_client_id, + Key: semconv.MessagingClientIDKey, Value: attribute.StringValue(m.getter.GetClientId(request)), }, attribute.KeyValue{ - Key: messaging_operation, + Key: semconv.MessagingOperationNameKey, Value: attribute.StringValue(string(m.operation)), }, attribute.KeyValue{ - Key: messaging_system, + Key: semconv.MessagingSystemKey, Value: attribute.StringValue(messageAttrSystem), }) return attributes @@ -93,10 +81,10 @@ func (m *MessageAttrsExtractor[REQUEST, RESPONSE, GETTER]) OnStart(attributes [] func (m *MessageAttrsExtractor[REQUEST, RESPONSE, GETTER]) OnEnd(attributes []attribute.KeyValue, context context.Context, request REQUEST, response RESPONSE, err error) []attribute.KeyValue { attributes = append(attributes, attribute.KeyValue{ - Key: messaging_message_id, + Key: semconv.MessagingMessageIDKey, Value: attribute.StringValue(m.getter.GetMessageId(request, response)), }, attribute.KeyValue{ - Key: messaging_batch_message_count, + Key: semconv.MessagingBatchMessageCountKey, Value: attribute.Int64Value(m.getter.GetBatchMessageCount(request, response)), }) // TODO: add custom captured headers attributes diff --git a/pkg/inst-api-semconv/instrumenter/message/message_attrs_extractor_test.go b/pkg/inst-api-semconv/instrumenter/message/message_attrs_extractor_test.go new file mode 100644 index 00000000..0c3fa763 --- /dev/null +++ b/pkg/inst-api-semconv/instrumenter/message/message_attrs_extractor_test.go @@ -0,0 +1,156 @@ +package message + +import ( + "context" + "github.com/alibaba/opentelemetry-go-auto-instrumentation/pkg/inst-api/utils" + "go.opentelemetry.io/otel/attribute" + semconv "go.opentelemetry.io/otel/semconv/v1.26.0" + "testing" +) + +type messageAttrsGetter struct { +} + +func (m messageAttrsGetter) GetSystem(request testRequest) string { + return "system" +} + +func (m messageAttrsGetter) GetDestination(request testRequest) string { + return "destination" +} + +func (m messageAttrsGetter) GetDestinationTemplate(request testRequest) string { + return "destination-template" +} + +func (m messageAttrsGetter) IsTemporaryDestination(request testRequest) bool { + return request.IsTemporaryDestination +} + +func (m messageAttrsGetter) isAnonymousDestination(request testRequest) bool { + return request.IsAnonymousDestination +} + +func (m messageAttrsGetter) GetConversationId(request testRequest) string { + return "conversation-id" +} + +func (m messageAttrsGetter) GetMessageBodySize(request testRequest) int64 { + return 2024 +} + +func (m messageAttrsGetter) GetMessageEnvelopSize(request testRequest) int64 { + return 2024 +} + +func (m messageAttrsGetter) GetMessageId(request testRequest, response testResponse) string { + return "message-id" +} + +func (m messageAttrsGetter) GetClientId(request testRequest) string { + return "client-id" +} + +func (m messageAttrsGetter) GetBatchMessageCount(request testRequest, response testResponse) int64 { + return 2024 +} + +func (m messageAttrsGetter) GetMessageHeader(request testRequest, name string) []string { + return []string{"header1", "header2"} +} + +func TestMessageGetSpanKey(t *testing.T) { + messageExtractor := &MessageAttrsExtractor[testRequest, testResponse, messageAttrsGetter]{operation: PUBLISH} + if messageExtractor.GetSpanKey() != utils.PRODUCER_KEY { + t.Fatalf("Should have returned producer key") + } + messageExtractor.operation = RECEIVE + if messageExtractor.GetSpanKey() != utils.CONSUMER_RECEIVE_KEY { + t.Fatalf("Should have returned consumer receive key") + } + messageExtractor.operation = PROCESS + if messageExtractor.GetSpanKey() != utils.CONSUMER_PROCESS_KEY { + t.Fatalf("Should have returned consumer process key") + } +} + +func TestMessageClientExtractorStartWithTemporaryDestination(t *testing.T) { + messageExtractor := MessageAttrsExtractor[testRequest, testResponse, messageAttrsGetter]{operation: PUBLISH} + attrs := make([]attribute.KeyValue, 0) + parentContext := context.Background() + attrs = messageExtractor.OnStart(attrs, parentContext, testRequest{IsTemporaryDestination: true, IsAnonymousDestination: true}) + if attrs[0].Key != semconv.MessagingDestinationTemporaryKey || attrs[0].Value.AsBool() != true { + t.Fatalf("temporary should be true") + } + if attrs[1].Key != semconv.MessagingDestinationNameKey || attrs[1].Value.AsString() != "(temporary)" { + t.Fatalf("destination name should be temporary") + } + if attrs[2].Key != semconv.MessagingDestinationAnonymousKey || attrs[2].Value.AsBool() != true { + t.Fatalf("destination anoymous should be true") + } + if attrs[3].Key != semconv.MessagingMessageConversationIDKey || attrs[3].Value.AsString() != "conversation-id" { + t.Fatalf("conversation should be conversation-id") + } + if attrs[4].Key != semconv.MessagingMessageBodySizeKey || attrs[4].Value.AsInt64() != 2024 { + t.Fatalf("message body size should be 2024") + } + if attrs[5].Key != semconv.MessagingMessageEnvelopeSizeKey || attrs[5].Value.AsInt64() != 2024 { + t.Fatalf("messsage envelope size should be 2024") + } + if attrs[6].Key != semconv.MessagingClientIDKey || attrs[6].Value.AsString() != "client-id" { + t.Fatalf("messsage client id should be client-id") + } + if attrs[7].Key != semconv.MessagingOperationNameKey || attrs[7].Value.AsString() != "publish" { + t.Fatalf("messsage operation should be publish") + } + if attrs[8].Key != semconv.MessagingSystemKey || attrs[8].Value.AsString() != "system" { + t.Fatalf("messsage system should be system") + } +} + +func TestMessageClientExtractorStartWithoutTemporaryDestination(t *testing.T) { + messageExtractor := MessageAttrsExtractor[testRequest, testResponse, messageAttrsGetter]{operation: PUBLISH} + attrs := make([]attribute.KeyValue, 0) + parentContext := context.Background() + attrs = messageExtractor.OnStart(attrs, parentContext, testRequest{IsTemporaryDestination: false, IsAnonymousDestination: true}) + if attrs[0].Key != semconv.MessagingDestinationNameKey || attrs[0].Value.AsString() != "destination" { + t.Fatalf("destination name should be destination") + } + if attrs[1].Key != semconv.MessagingDestinationTemplateKey || attrs[1].Value.AsString() != "destination-template" { + t.Fatalf("destination template should be destination-template") + } + if attrs[2].Key != semconv.MessagingDestinationAnonymousKey || attrs[2].Value.AsBool() != true { + t.Fatalf("destination anoymous should be true") + } + if attrs[3].Key != semconv.MessagingMessageConversationIDKey || attrs[3].Value.AsString() != "conversation-id" { + t.Fatalf("conversation should be conversation-id") + } + if attrs[4].Key != semconv.MessagingMessageBodySizeKey || attrs[4].Value.AsInt64() != 2024 { + t.Fatalf("message body size should be 2024") + } + if attrs[5].Key != semconv.MessagingMessageEnvelopeSizeKey || attrs[5].Value.AsInt64() != 2024 { + t.Fatalf("messsage envelope size should be 2024") + } + if attrs[6].Key != semconv.MessagingClientIDKey || attrs[6].Value.AsString() != "client-id" { + t.Fatalf("messsage client id should be client-id") + } + if attrs[7].Key != semconv.MessagingOperationNameKey || attrs[7].Value.AsString() != "publish" { + t.Fatalf("messsage operation should be publish") + } + if attrs[8].Key != semconv.MessagingSystemKey || attrs[8].Value.AsString() != "system" { + t.Fatalf("messsage system should be system") + } +} + +func TestMessageClientExtractorEnd(t *testing.T) { + messageExtractor := MessageAttrsExtractor[testRequest, testResponse, messageAttrsGetter]{} + attrs := make([]attribute.KeyValue, 0) + parentContext := context.Background() + attrs = messageExtractor.OnEnd(attrs, parentContext, testRequest{}, testResponse{}, nil) + if attrs[0].Key != semconv.MessagingMessageIDKey || attrs[0].Value.AsString() != "message-id" { + t.Fatalf("message id should be message-id") + } + if attrs[1].Key != semconv.MessagingBatchMessageCountKey || attrs[1].Value.AsInt64() != 2024 { + t.Fatalf("messaging batch message count should be 2024") + } +} diff --git a/pkg/inst-api-semconv/instrumenter/message/message_span_name_extractor.go b/pkg/inst-api-semconv/instrumenter/message/message_span_name_extractor.go index 59744ca3..cfbfb661 100644 --- a/pkg/inst-api-semconv/instrumenter/message/message_span_name_extractor.go +++ b/pkg/inst-api-semconv/instrumenter/message/message_span_name_extractor.go @@ -17,5 +17,8 @@ func (m *MessageSpanNameExtractor[REQUEST, RESPONSE]) Extract(request REQUEST) s if destinationName == "" { destinationName = "unknown" } - return destinationName + " " + string(m.operationName) + if m.operationName != "" { + destinationName = destinationName + " " + string(m.operationName) + } + return destinationName } diff --git a/pkg/inst-api-semconv/instrumenter/message/message_span_name_extractor_test.go b/pkg/inst-api-semconv/instrumenter/message/message_span_name_extractor_test.go new file mode 100644 index 00000000..16df2b3a --- /dev/null +++ b/pkg/inst-api-semconv/instrumenter/message/message_span_name_extractor_test.go @@ -0,0 +1,98 @@ +package message + +import ( + "testing" +) + +type testRequest struct { + IsTemporaryDestination bool + IsAnonymousDestination bool + Destination string +} + +type testResponse struct { +} + +type testGetter struct { +} + +func (t testGetter) GetSystem(request testRequest) string { + //TODO implement me + panic("implement me") +} + +func (t testGetter) GetDestination(request testRequest) string { + if request.Destination != "" { + return request.Destination + } + return "" +} + +func (t testGetter) GetDestinationTemplate(request testRequest) string { + //TODO implement me + panic("implement me") +} + +func (t testGetter) IsTemporaryDestination(request testRequest) bool { + return request.IsTemporaryDestination +} + +func (t testGetter) isAnonymousDestination(request testRequest) bool { + //TODO implement me + panic("implement me") +} + +func (t testGetter) GetConversationId(request testRequest) string { + //TODO implement me + panic("implement me") +} + +func (t testGetter) GetMessageBodySize(request testRequest) int64 { + //TODO implement me + panic("implement me") +} + +func (t testGetter) GetMessageEnvelopSize(request testRequest) int64 { + //TODO implement me + panic("implement me") +} + +func (t testGetter) GetMessageId(request testRequest, response testResponse) string { + //TODO implement me + panic("implement me") +} + +func (t testGetter) GetClientId(request testRequest) string { + //TODO implement me + panic("implement me") +} + +func (t testGetter) GetBatchMessageCount(request testRequest, response testResponse) int64 { + //TODO implement me + panic("implement me") +} + +func (t testGetter) GetMessageHeader(request testRequest, name string) []string { + //TODO implement me + panic("implement me") +} + +func TestExtractSpanName(t *testing.T) { + r := MessageSpanNameExtractor[testRequest, testResponse]{getter: testGetter{}} + spanName := r.Extract(testRequest{IsTemporaryDestination: true, Destination: "Destination"}) + if spanName != "(temporary)" { + t.Fatalf("extract span name failed: expected (temporary) but got %s", spanName) + } + spanName = r.Extract(testRequest{IsTemporaryDestination: false, Destination: ""}) + if spanName != "unknown" { + t.Fatalf("extract span name failed: expected unknown but got %s", spanName) + } +} + +func TestExtractSpanNameWithOperationName(t *testing.T) { + r := MessageSpanNameExtractor[testRequest, testResponse]{getter: testGetter{}, operationName: PUBLISH} + spanName := r.Extract(testRequest{IsTemporaryDestination: true, Destination: "Destination"}) + if spanName != "(temporary) publish" { + t.Fatalf("extract span name failed: expected (temporary) publish but got %s", spanName) + } +} diff --git a/pkg/inst-api-semconv/instrumenter/net/network_attrs_extractor.go b/pkg/inst-api-semconv/instrumenter/net/network_attrs_extractor.go index 83d8717e..ecec5de8 100644 --- a/pkg/inst-api-semconv/instrumenter/net/network_attrs_extractor.go +++ b/pkg/inst-api-semconv/instrumenter/net/network_attrs_extractor.go @@ -3,22 +3,10 @@ package net import ( "context" "go.opentelemetry.io/otel/attribute" + semconv "go.opentelemetry.io/otel/semconv/v1.26.0" "strings" ) -const network_transport = attribute.Key("network.transport") -const network_protocol_name = attribute.Key("network.protocol.name") -const network_local_address = attribute.Key("network.local.address") -const network_local_port = attribute.Key("network.local.port") -const network_peer_address = attribute.Key("network.peer.address") -const network_peer_port = attribute.Key("network.peer.port") -const network_protocol_version = attribute.Key("network.protocol.version") -const network_type = attribute.Key("network.type") - -const url_scheme = attribute.Key("url.scheme") -const url_query = attribute.Key("url.query") -const url_path = attribute.Key("url.path") - type NetworkAttrsExtractor[REQUEST any, RESPONSE any, GETTER NetworkAttrsGetter[REQUEST, RESPONSE]] struct { getter GETTER } @@ -29,35 +17,35 @@ func (i *NetworkAttrsExtractor[REQUEST, RESPONSE, GETTER]) OnStart(attributes [] func (i *NetworkAttrsExtractor[REQUEST, RESPONSE, GETTER]) OnEnd(attributes []attribute.KeyValue, context context.Context, request REQUEST, response RESPONSE, err error) []attribute.KeyValue { attributes = append(attributes, attribute.KeyValue{ - Key: network_transport, + Key: semconv.NetworkTransportKey, Value: attribute.StringValue(i.getter.GetNetworkTransport(request, response)), }, attribute.KeyValue{ - Key: network_type, + Key: semconv.NetworkTypeKey, Value: attribute.StringValue(strings.ToLower(i.getter.GetNetworkType(request, response))), }, attribute.KeyValue{ - Key: network_protocol_name, + Key: semconv.NetworkProtocolNameKey, Value: attribute.StringValue(strings.ToLower(i.getter.GetNetworkProtocolName(request, response))), }, attribute.KeyValue{ - Key: network_protocol_version, + Key: semconv.NetworkProtocolVersionKey, Value: attribute.StringValue(strings.ToLower(i.getter.GetNetworkProtocolVersion(request, response))), }, attribute.KeyValue{ - Key: network_local_address, + Key: semconv.NetworkLocalAddressKey, Value: attribute.StringValue(i.getter.GetNetworkLocalInetAddress(request, response)), }, attribute.KeyValue{ - Key: network_peer_address, + Key: semconv.NetworkPeerAddressKey, Value: attribute.StringValue(i.getter.GetNetworkPeerInetAddress(request, response)), }) localPort := i.getter.GetNetworkLocalPort(request, response) if localPort > 0 { attributes = append(attributes, attribute.KeyValue{ - Key: network_local_port, + Key: semconv.NetworkLocalPortKey, Value: attribute.IntValue(localPort), }) } peerPort := i.getter.GetNetworkPeerPort(request, response) if peerPort > 0 { attributes = append(attributes, attribute.KeyValue{ - Key: network_peer_port, + Key: semconv.NetworkPeerPortKey, Value: attribute.IntValue(peerPort), }) } @@ -71,13 +59,13 @@ type UrlAttrsExtractor[REQUEST any, RESPONSE any, GETTER UrlAttrsGetter[REQUEST] func (u *UrlAttrsExtractor[REQUEST, RESPONSE, GETTER]) OnStart(attributes []attribute.KeyValue, parentContext context.Context, request REQUEST) []attribute.KeyValue { attributes = append(attributes, attribute.KeyValue{ - Key: url_scheme, + Key: semconv.URLSchemeKey, Value: attribute.StringValue(u.getter.GetUrlScheme(request)), }, attribute.KeyValue{ - Key: url_path, + Key: semconv.URLPathKey, Value: attribute.StringValue(u.getter.GetUrlPath(request)), }, attribute.KeyValue{ - Key: url_query, + Key: semconv.URLQueryKey, Value: attribute.StringValue(u.getter.GetUrlQuery(request)), }) return attributes diff --git a/pkg/inst-api-semconv/instrumenter/net/network_attrs_extractor_test.go b/pkg/inst-api-semconv/instrumenter/net/network_attrs_extractor_test.go new file mode 100644 index 00000000..f92b0535 --- /dev/null +++ b/pkg/inst-api-semconv/instrumenter/net/network_attrs_extractor_test.go @@ -0,0 +1,91 @@ +package net + +import ( + "context" + "go.opentelemetry.io/otel/attribute" + semconv "go.opentelemetry.io/otel/semconv/v1.26.0" + "log" + "testing" +) + +type testRequest struct { +} + +type testResponse struct { +} + +type netAttrsGetter struct { +} + +func (n netAttrsGetter) GetNetworkType(request testRequest, response testResponse) string { + return "test" +} + +func (n netAttrsGetter) GetNetworkTransport(request testRequest, response testResponse) string { + return "test" +} + +func (n netAttrsGetter) GetNetworkProtocolName(request testRequest, response testResponse) string { + return "test" +} + +func (n netAttrsGetter) GetNetworkProtocolVersion(request testRequest, response testResponse) string { + return "test" +} + +func (n netAttrsGetter) GetNetworkLocalInetAddress(request testRequest, response testResponse) string { + return "test" +} + +func (n netAttrsGetter) GetNetworkLocalPort(request testRequest, response testResponse) int { + return 8080 +} + +func (n netAttrsGetter) GetNetworkPeerInetAddress(request testRequest, response testResponse) string { + return "test" +} + +func (n netAttrsGetter) GetNetworkPeerPort(request testRequest, response testResponse) int { + return 8080 +} + +func TestNetClientExtractorStart(t *testing.T) { + netExtractor := NetworkAttrsExtractor[testRequest, testResponse, netAttrsGetter]{} + attrs := make([]attribute.KeyValue, 0) + parentContext := context.Background() + attrs = netExtractor.OnStart(attrs, parentContext, testRequest{}) + if len(attrs) != 0 { + log.Fatal("attrs should be empty") + } +} + +func TestNetClientExtractorEnd(t *testing.T) { + netExtractor := NetworkAttrsExtractor[testRequest, testResponse, netAttrsGetter]{} + attrs := make([]attribute.KeyValue, 0) + parentContext := context.Background() + attrs = netExtractor.OnEnd(attrs, parentContext, testRequest{}, testResponse{}, nil) + if attrs[0].Key != semconv.NetworkTransportKey || attrs[0].Value.AsString() != "test" { + t.Fatalf("network transport key should be test") + } + if attrs[1].Key != semconv.NetworkTypeKey || attrs[1].Value.AsString() != "test" { + t.Fatalf("network type should be test") + } + if attrs[2].Key != semconv.NetworkProtocolNameKey || attrs[2].Value.AsString() != "test" { + t.Fatalf("network protocol name should be test") + } + if attrs[3].Key != semconv.NetworkProtocolVersionKey || attrs[3].Value.AsString() != "test" { + t.Fatalf("network protocol version should be test") + } + if attrs[4].Key != semconv.NetworkLocalAddressKey || attrs[4].Value.AsString() != "test" { + t.Fatalf("network local address should be test") + } + if attrs[5].Key != semconv.NetworkPeerAddressKey || attrs[5].Value.AsString() != "test" { + t.Fatalf("network peer address should be test") + } + if attrs[6].Key != semconv.NetworkLocalPortKey || attrs[6].Value.AsInt64() != 8080 { + t.Fatalf("network local port should be empty") + } + if attrs[7].Key != semconv.NetworkPeerPortKey || attrs[7].Value.AsInt64() != 8080 { + t.Fatalf("network peer port should be empty") + } +} diff --git a/pkg/inst-api-semconv/instrumenter/rpc/rpc_attrs_extractor.go b/pkg/inst-api-semconv/instrumenter/rpc/rpc_attrs_extractor.go index 7ac29557..85f08d61 100644 --- a/pkg/inst-api-semconv/instrumenter/rpc/rpc_attrs_extractor.go +++ b/pkg/inst-api-semconv/instrumenter/rpc/rpc_attrs_extractor.go @@ -4,25 +4,22 @@ import ( "context" "github.com/alibaba/opentelemetry-go-auto-instrumentation/pkg/inst-api/utils" "go.opentelemetry.io/otel/attribute" + semconv "go.opentelemetry.io/otel/semconv/v1.26.0" ) -const rpc_method = attribute.Key("rpc.method") -const rpc_service = attribute.Key("rpc.service") -const rpc_system = attribute.Key("rpc.system") - type RpcAttrsExtractor[REQUEST any, RESPONSE any, GETTER RpcAttrsGetter[REQUEST]] struct { getter GETTER } func (r *RpcAttrsExtractor[REQUEST, RESPONSE, GETTER]) OnStart(attributes []attribute.KeyValue, parentContext context.Context, request REQUEST) []attribute.KeyValue { attributes = append(attributes, attribute.KeyValue{ - Key: rpc_system, + Key: semconv.RPCSystemKey, Value: attribute.StringValue(r.getter.GetSystem(request)), }, attribute.KeyValue{ - Key: rpc_service, + Key: semconv.RPCServiceKey, Value: attribute.StringValue(r.getter.GetService(request)), }, attribute.KeyValue{ - Key: rpc_method, + Key: semconv.RPCMethodKey, Value: attribute.StringValue(r.getter.GetMethod(request)), }) return attributes @@ -53,7 +50,7 @@ type ClientRpcAttrsExtractor[REQUEST any, RESPONSE any, GETTER RpcAttrsGetter[RE } func (s *ClientRpcAttrsExtractor[REQUEST, RESPONSE, GETTER]) GetSpanKey() attribute.Key { - return utils.RPC_SERVER_KEY + return utils.RPC_CLIENT_KEY } func (s *ClientRpcAttrsExtractor[REQUEST, RESPONSE, GETTER]) OnStart(attributes []attribute.KeyValue, parentContext context.Context, request REQUEST) []attribute.KeyValue { diff --git a/pkg/inst-api-semconv/instrumenter/rpc/rpc_attrs_extractor_test.go b/pkg/inst-api-semconv/instrumenter/rpc/rpc_attrs_extractor_test.go new file mode 100644 index 00000000..3dd0ae99 --- /dev/null +++ b/pkg/inst-api-semconv/instrumenter/rpc/rpc_attrs_extractor_test.go @@ -0,0 +1,94 @@ +package rpc + +import ( + "context" + "github.com/alibaba/opentelemetry-go-auto-instrumentation/pkg/inst-api/utils" + "go.opentelemetry.io/otel/attribute" + semconv "go.opentelemetry.io/otel/semconv/v1.26.0" + "log" + "testing" +) + +type testResponse struct { +} + +type rpcAttrsGetter struct { +} + +func (h rpcAttrsGetter) GetSystem(request testRequest) string { + return "system" +} + +func (h rpcAttrsGetter) GetService(request testRequest) string { + return "service" +} + +func (h rpcAttrsGetter) GetMethod(request testRequest) string { + return "method" +} + +func TestClientGetSpanKey(t *testing.T) { + rpcExtractor := &ClientRpcAttrsExtractor[testRequest, any, rpcAttrsGetter]{} + if rpcExtractor.GetSpanKey() != utils.RPC_CLIENT_KEY { + t.Fatalf("Should have returned RPC_CLIENT_KEY") + } +} + +func TestServerGetSpanKey(t *testing.T) { + rpcExtractor := &ServerRpcAttrsExtractor[testRequest, any, rpcAttrsGetter]{} + if rpcExtractor.GetSpanKey() != utils.RPC_SERVER_KEY { + t.Fatalf("Should have returned RPC_SERVER_KEY") + } +} + +func TestRpcClientExtractorStart(t *testing.T) { + rpcExtractor := ClientRpcAttrsExtractor[testRequest, testResponse, rpcAttrsGetter]{} + attrs := make([]attribute.KeyValue, 0) + parentContext := context.Background() + attrs = rpcExtractor.OnStart(attrs, parentContext, testRequest{}) + if attrs[0].Key != semconv.RPCSystemKey || attrs[0].Value.AsString() != "system" { + t.Fatalf("rpc system should be system") + } + if attrs[1].Key != semconv.RPCServiceKey || attrs[1].Value.AsString() != "service" { + t.Fatalf("rpc service should be service") + } + if attrs[2].Key != semconv.RPCMethodKey || attrs[2].Value.AsString() != "method" { + t.Fatalf("rpc method should be method") + } +} + +func TestRpcClientExtractorEnd(t *testing.T) { + rpcExtractor := ClientRpcAttrsExtractor[testRequest, testResponse, rpcAttrsGetter]{} + attrs := make([]attribute.KeyValue, 0) + parentContext := context.Background() + attrs = rpcExtractor.OnEnd(attrs, parentContext, testRequest{}, testResponse{}, nil) + if len(attrs) != 0 { + log.Fatal("attrs should be empty") + } +} + +func TestRpcServerExtractorStart(t *testing.T) { + rpcExtractor := ServerRpcAttrsExtractor[testRequest, testResponse, rpcAttrsGetter]{} + attrs := make([]attribute.KeyValue, 0) + parentContext := context.Background() + attrs = rpcExtractor.OnStart(attrs, parentContext, testRequest{}) + if attrs[0].Key != semconv.RPCSystemKey || attrs[0].Value.AsString() != "system" { + t.Fatalf("rpc system should be system") + } + if attrs[1].Key != semconv.RPCServiceKey || attrs[1].Value.AsString() != "service" { + t.Fatalf("rpc service should be service") + } + if attrs[2].Key != semconv.RPCMethodKey || attrs[2].Value.AsString() != "method" { + t.Fatalf("rpc method should be method") + } +} + +func TestRpcServerExtractorEnd(t *testing.T) { + rpcExtractor := ServerRpcAttrsExtractor[testRequest, testResponse, rpcAttrsGetter]{} + attrs := make([]attribute.KeyValue, 0) + parentContext := context.Background() + attrs = rpcExtractor.OnEnd(attrs, parentContext, testRequest{}, testResponse{}, nil) + if len(attrs) != 0 { + log.Fatal("attrs should be empty") + } +} diff --git a/pkg/inst-api-semconv/instrumenter/rpc/rpc_span_name_extractor_test.go b/pkg/inst-api-semconv/instrumenter/rpc/rpc_span_name_extractor_test.go new file mode 100644 index 00000000..66fd1b5a --- /dev/null +++ b/pkg/inst-api-semconv/instrumenter/rpc/rpc_span_name_extractor_test.go @@ -0,0 +1,45 @@ +package rpc + +import "testing" + +type testRequest struct { + System string + Service string + Method string +} + +type testGetter struct { +} + +func (t testGetter) GetSystem(request testRequest) string { + if request.System != "" { + return request.System + } + return "" +} + +func (t testGetter) GetService(request testRequest) string { + if request.Service != "" { + return request.Service + } + return "" +} + +func (t testGetter) GetMethod(request testRequest) string { + if request.Method != "" { + return request.Method + } + return "" +} + +func TestExtractSpanName(t *testing.T) { + r := RpcSpanNameExtractor[testRequest]{getter: testGetter{}} + spanName := r.Extract(testRequest{Method: "method", Service: "service"}) + if spanName != "service/method" { + t.Fatalf("extract span name extractor failed, expected 'service/method', got '%s'", spanName) + } + spanName = r.Extract(testRequest{}) + if spanName != "RPC request" { + t.Fatalf("extract span name extractor failed, expected 'RPC request', got '%s'", spanName) + } +} diff --git a/pkg/verifier/access.go b/pkg/verifier/access.go index 8256cdaa..4f357dc8 100644 --- a/pkg/verifier/access.go +++ b/pkg/verifier/access.go @@ -20,7 +20,3 @@ func GetTestSpans() *tracetest.SpanStubs { func ResetTestSpans() { spanExporter.Reset() } - -func ClearSpan() { - spanExporter.Reset() -} diff --git a/pkg/verifier/access_test.go b/pkg/verifier/access_test.go new file mode 100644 index 00000000..43d08471 --- /dev/null +++ b/pkg/verifier/access_test.go @@ -0,0 +1,84 @@ +package verifier + +import ( + "context" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/sdk/instrumentation" + "go.opentelemetry.io/otel/sdk/resource" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/trace" + "testing" + "time" +) + +type testSpan struct { + // Embed the interface to implement the private method. + sdktrace.ReadOnlySpan + ID string + name string + spanContext trace.SpanContext + parent trace.SpanContext + spanKind trace.SpanKind + startTime time.Time + endTime time.Time + attributes []attribute.KeyValue + events []sdktrace.Event + links []sdktrace.Link + status sdktrace.Status + droppedAttributes int + droppedEvents int + droppedLinks int + childSpanCount int + resource *resource.Resource + instrumentationScope instrumentation.Scope +} + +func (s testSpan) Name() string { return s.name } +func (s testSpan) SpanContext() trace.SpanContext { return s.spanContext } +func (s testSpan) Parent() trace.SpanContext { return s.parent } +func (s testSpan) SpanKind() trace.SpanKind { return s.spanKind } +func (s testSpan) StartTime() time.Time { return s.startTime } +func (s testSpan) EndTime() time.Time { return s.endTime } +func (s testSpan) Attributes() []attribute.KeyValue { return s.attributes } +func (s testSpan) Links() []sdktrace.Link { return s.links } +func (s testSpan) Events() []sdktrace.Event { return s.events } +func (s testSpan) Status() sdktrace.Status { return s.status } +func (s testSpan) DroppedAttributes() int { return s.droppedAttributes } +func (s testSpan) DroppedLinks() int { return s.droppedLinks } +func (s testSpan) DroppedEvents() int { return s.droppedEvents } +func (s testSpan) ChildSpanCount() int { return s.childSpanCount } +func (s testSpan) Resource() *resource.Resource { return s.resource } +func (s testSpan) InstrumentationScope() instrumentation.Scope { + return s.instrumentationScope +} + +func (s testSpan) InstrumentationLibrary() instrumentation.Library { + return s.instrumentationScope +} + +func TestResetSpan(t *testing.T) { + err := spanExporter.ExportSpans(context.Background(), []sdktrace.ReadOnlySpan{ + testSpan{ID: "1"}, + testSpan{ID: "2"}, + }) + if err != nil { + t.Fatal(err) + } + ResetTestSpans() + if len(spanExporter.GetSpans()) != 0 { + t.Fatal("expected no all the spans are cleared") + } +} + +func TestGetTestSpans(t *testing.T) { + err := GetSpanExporter().ExportSpans(context.Background(), []sdktrace.ReadOnlySpan{ + testSpan{ID: "1"}, + testSpan{ID: "2"}, + }) + if err != nil { + t.Fatal(err) + } + if len(spanExporter.GetSpans()) != 2 { + t.Fatalf("expected 2 spans, got %d", len(spanExporter.GetSpans())) + } +} diff --git a/pkg/verifier/runner.go b/pkg/verifier/runner.go index 5dbc4429..8c0f06cc 100644 --- a/pkg/verifier/runner.go +++ b/pkg/verifier/runner.go @@ -22,6 +22,7 @@ func WaitAndAssertTraces(traceVerifiers ...func([]tracetest.SpanStubs)) { } func waitForTraces(numberOfTraces int) []tracetest.SpanStubs { + defer ResetTestSpans() // 最多等20s finish := false var traces []tracetest.SpanStubs @@ -47,7 +48,6 @@ func waitForTraces(numberOfTraces int) []tracetest.SpanStubs { func groupAndSortTrace() []tracetest.SpanStubs { spans := GetTestSpans() - defer ResetTestSpans() traceMap := make(map[string][]tracetest.SpanStub) for _, span := range *spans { if span.SpanContext.HasTraceID() && span.SpanContext.TraceID().IsValid() { diff --git a/pkg/verifier/runner_test.go b/pkg/verifier/runner_test.go new file mode 100644 index 00000000..5d22ea63 --- /dev/null +++ b/pkg/verifier/runner_test.go @@ -0,0 +1,97 @@ +package verifier + +import ( + "context" + "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" + oteltrace "go.opentelemetry.io/otel/trace" + "testing" +) + +func TestWaitAndAssertTracesOneTrace(t *testing.T) { + err := spanExporter.ExportSpans(context.Background(), []trace.ReadOnlySpan{ + testSpan{ID: "1", spanContext: oteltrace.NewSpanContext(oteltrace.SpanContextConfig{ + TraceID: oteltrace.TraceID{0x010}, + SpanID: oteltrace.SpanID{0x01}, + })}, + testSpan{ID: "2", spanContext: oteltrace.NewSpanContext(oteltrace.SpanContextConfig{ + TraceID: oteltrace.TraceID{0x010}, + SpanID: oteltrace.SpanID{0x01}, + })}, + testSpan{ID: "3", spanContext: oteltrace.NewSpanContext(oteltrace.SpanContextConfig{ + TraceID: oteltrace.TraceID{0x010}, + SpanID: oteltrace.SpanID{0x01}, + })}, + testSpan{ID: "4", spanContext: oteltrace.NewSpanContext(oteltrace.SpanContextConfig{ + TraceID: oteltrace.TraceID{0x010}, + SpanID: oteltrace.SpanID{0x01}, + })}, + }) + if err != nil { + t.Fatal(err) + } + defer ResetTestSpans() + WaitAndAssertTraces(func(stubs []tracetest.SpanStubs) { + if len(stubs) != 1 { + t.Fatalf("expecting 1 traces but got %d", len(stubs)) + } + }) +} + +func TestWaitAndAssertTracesMultipleTrace(t *testing.T) { + err := spanExporter.ExportSpans(context.Background(), []trace.ReadOnlySpan{ + testSpan{ID: "1", spanContext: oteltrace.NewSpanContext(oteltrace.SpanContextConfig{ + TraceID: oteltrace.TraceID{0x010}, + SpanID: oteltrace.SpanID{0x01}, + })}, + testSpan{ID: "2", spanContext: oteltrace.NewSpanContext(oteltrace.SpanContextConfig{ + TraceID: oteltrace.TraceID{0x020}, + SpanID: oteltrace.SpanID{0x01}, + })}, + testSpan{ID: "3", spanContext: oteltrace.NewSpanContext(oteltrace.SpanContextConfig{ + TraceID: oteltrace.TraceID{0x030}, + SpanID: oteltrace.SpanID{0x01}, + })}, + testSpan{ID: "4", spanContext: oteltrace.NewSpanContext(oteltrace.SpanContextConfig{ + TraceID: oteltrace.TraceID{0x040}, + SpanID: oteltrace.SpanID{0x01}, + })}, + }) + if err != nil { + t.Fatal(err) + } + defer ResetTestSpans() + WaitAndAssertTraces(func(stubs []tracetest.SpanStubs) { + if len(stubs) != 4 { + t.Fatalf("expecting 4 traces but got %d", len(stubs)) + } + }) +} + +func TestWaitAndAssertTraceLink(t *testing.T) { + err := spanExporter.ExportSpans(context.Background(), []trace.ReadOnlySpan{ + testSpan{ID: "1", spanContext: oteltrace.NewSpanContext(oteltrace.SpanContextConfig{ + TraceID: oteltrace.TraceID{0x010}, + SpanID: oteltrace.SpanID{0x01}, + })}, + testSpan{ID: "2", spanContext: oteltrace.NewSpanContext(oteltrace.SpanContextConfig{ + TraceID: oteltrace.TraceID{0x010}, + SpanID: oteltrace.SpanID{0x02}, + }), parent: oteltrace.NewSpanContext(oteltrace.SpanContextConfig{ + TraceID: oteltrace.TraceID{0x010}, + SpanID: oteltrace.SpanID{0x01}, + })}, + }) + if err != nil { + t.Fatal(err) + } + defer ResetTestSpans() + WaitAndAssertTraces(func(stubs []tracetest.SpanStubs) { + if len(stubs) != 1 { + t.Fatalf("expecting 1 traces but got %d", len(stubs)) + } + if stubs[0].Snapshots()[1].Parent().SpanID() != stubs[0].Snapshots()[0].SpanContext().SpanID() { + t.Fatalf("expecting parent span id to be equal") + } + }) +} diff --git a/pkg/verifier/util.go b/pkg/verifier/util.go index faa069db..0e9f5234 100644 --- a/pkg/verifier/util.go +++ b/pkg/verifier/util.go @@ -39,7 +39,7 @@ func GetFreePort() (int, error) { return cli.Addr().(*net.TCPAddr).Port, nil } -func GetServer(ctx context.Context, url string) { +func GetServer(ctx context.Context, url string) (string, error) { req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { panic(err) @@ -50,6 +50,7 @@ func GetServer(ctx context.Context, url string) { panic(err) } defer resp.Body.Close() + return resp.Status, nil } func IsInTest() bool { diff --git a/pkg/verifier/util_test.go b/pkg/verifier/util_test.go index d0471d60..3ec30f20 100644 --- a/pkg/verifier/util_test.go +++ b/pkg/verifier/util_test.go @@ -1,6 +1,9 @@ package verifier -import "testing" +import ( + "go.opentelemetry.io/otel/attribute" + "testing" +) func TestGetFreePort(t *testing.T) { port, err := GetFreePort() @@ -11,3 +14,30 @@ func TestGetFreePort(t *testing.T) { t.Fatal("port is 0") } } + +func TestGetAttribute(t *testing.T) { + attrs := []attribute.KeyValue{ + attribute.Key("key").String("value"), + attribute.Key("key1").String("value1"), + } + if GetAttribute(attrs, "key").AsString() != "value" { + t.Fatal("key should exist") + } + if GetAttribute(attrs, "key2").Type() != attribute.INVALID { + t.Fatal("key 2 should not exist") + } +} + +func TestAssert(t *testing.T) { + defer func() { + pass := false + if r := recover(); r != nil { + pass = true + } + if !pass { + t.Fatal("Should be recovered from panic") + } + }() + Assert(1 == 1, "1 should equal to 1") + Assert(1 == 2, "1 should equal to 1") +} diff --git a/pkg/verifier/verifier.go b/pkg/verifier/verifier.go index 5b4bab80..60b34ded 100644 --- a/pkg/verifier/verifier.go +++ b/pkg/verifier/verifier.go @@ -6,6 +6,7 @@ import ( "strings" ) +// VerifyNoSqlAttributes TODO: make attribute name to semconv attribute func VerifyNoSqlAttributes(span tracetest.SpanStub, name, dbName, system, user, connString, statement, operation string) { Assert(span.SpanKind == trace.SpanKindClient, "Expect to be client span, got %d", span.SpanKind) Assert(span.Name == name, "Except client span name to be %s, got %s", name, span.Name) diff --git a/pkg/verifier/verifier_test.go b/pkg/verifier/verifier_test.go new file mode 100644 index 00000000..524a80b6 --- /dev/null +++ b/pkg/verifier/verifier_test.go @@ -0,0 +1,40 @@ +package verifier + +import ( + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/sdk/trace/tracetest" + semconv "go.opentelemetry.io/otel/semconv/v1.19.0" + "go.opentelemetry.io/otel/trace" + "testing" +) + +func TestNoSqlAttributesPass(t *testing.T) { + VerifyNoSqlAttributes(tracetest.SpanStub{SpanKind: trace.SpanKindClient, Name: "name", Attributes: []attribute.KeyValue{ + {Key: semconv.DBNameKey, Value: attribute.StringValue("dbname")}, + {Key: semconv.DBSystemKey, Value: attribute.StringValue("system")}, + {Key: semconv.DBUserKey, Value: attribute.StringValue("user")}, + {Key: semconv.DBConnectionStringKey, Value: attribute.StringValue("connString")}, + {Key: semconv.DBStatementKey, Value: attribute.StringValue("statement")}, + {Key: semconv.DBOperationKey, Value: attribute.StringValue("operation")}, + }}, "name", "dbname", "system", "user", "connString", "statement", "operation") +} + +func TestNoSqlAttributesFail(t *testing.T) { + defer func() { + pass := false + if r := recover(); r != nil { + pass = true + } + if !pass { + t.Fatal("Should be recovered from panic") + } + }() + VerifyNoSqlAttributes(tracetest.SpanStub{SpanKind: trace.SpanKindClient, Name: "name", Attributes: []attribute.KeyValue{ + {Key: semconv.DBNameKey, Value: attribute.StringValue("dbname")}, + {Key: semconv.DBSystemKey, Value: attribute.StringValue("system")}, + {Key: semconv.DBUserKey, Value: attribute.StringValue("user")}, + {Key: semconv.DBConnectionStringKey, Value: attribute.StringValue("connString")}, + {Key: semconv.DBStatementKey, Value: attribute.StringValue("wrong statement")}, + {Key: semconv.DBOperationKey, Value: attribute.StringValue("operation")}, + }}, "name", "dbname", "system", "user", "connString", "statement", "operation") +}