Skip to content

Commit

Permalink
chore: review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Sidddddarth committed Feb 12, 2025
1 parent d1cba2d commit faa40c2
Show file tree
Hide file tree
Showing 10 changed files with 135 additions and 142 deletions.
3 changes: 2 additions & 1 deletion enterprise/reporting/reporting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/rudderlabs/rudder-go-kit/config"
"github.com/rudderlabs/rudder-go-kit/logger"
"github.com/rudderlabs/rudder-go-kit/stats"
"github.com/rudderlabs/rudder-server/utils/misc"
"github.com/rudderlabs/rudder-server/utils/types"
)

Expand Down Expand Up @@ -545,7 +546,7 @@ func TestSanitizeJSONForReports(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := sanitizeJSONForReports(tt.input)
got, err := misc.SanitizeJSON(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("sanitizeJSONForReports() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down
2 changes: 1 addition & 1 deletion enterprise/reporting/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ func TestFloorFactor(t *testing.T) {
}

func TestGetSampleWithEventSampling(t *testing.T) {
sampleEvent := json.RawMessage(`{"event": "2"}`)
sampleEvent := json.RawMessage(`{"event":"2"}`)
sampleResponse := "sample response"
emptySampleEvent := json.RawMessage(`{}`)
emptySampleResponse := ""
Expand Down
30 changes: 3 additions & 27 deletions enterprise/reporting/utils.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
package reporting

import (
"bytes"
"encoding/json"
"sort"
"strings"

jsoniter "github.com/json-iterator/go"

"github.com/rudderlabs/rudder-go-kit/config"
"github.com/rudderlabs/rudder-server/enterprise/reporting/event_sampler"
"github.com/rudderlabs/rudder-server/utils/misc"
"github.com/rudderlabs/rudder-server/utils/types"
)

Expand Down Expand Up @@ -60,7 +58,7 @@ func getSampleWithEventSampling(metric types.PUReportedMetric, reportedAt int64,
if !eventSamplingEnabled || eventSampler == nil {
// Sanitize both sample event and response before returning
if sampleEvent != nil {
sampleEvent, err = sanitizeJSONForReports(sampleEvent)
sampleEvent, err = misc.SanitizeJSON(sampleEvent)
if err != nil {
return []byte(`{}`), "", err
}
Expand Down Expand Up @@ -91,7 +89,7 @@ func getSampleWithEventSampling(metric types.PUReportedMetric, reportedAt int64,
}
// Sanitize both sample event and response before returning
if sampleEvent != nil {
sampleEvent, err = sanitizeJSONForReports(sampleEvent)
sampleEvent, err = misc.SanitizeJSON(sampleEvent)
if err != nil {
return []byte(`{}`), "", err
}

Check warning on line 95 in enterprise/reporting/utils.go

View check run for this annotation

Codecov / codecov/patch

enterprise/reporting/utils.go#L94-L95

Added lines #L94 - L95 were not covered by tests
Expand Down Expand Up @@ -131,25 +129,3 @@ func getPIIColumnsToExclude() []string {
}
return piiColumnsToExclude
}

func sanitizeJSONForReports(input json.RawMessage) (json.RawMessage, error) {
// Remove null characters
v := bytes.ReplaceAll(input, []byte(`\u0000`), []byte(""))

if len(v) == 0 {
return []byte(`{}`), nil
}

// Validate JSON structure by unmarshaling and marshaling
var a any
err := jsoniter.ConfigCompatibleWithStandardLibrary.Unmarshal(v, &a)
if err != nil {
return nil, err
}
v, err = jsoniter.ConfigCompatibleWithStandardLibrary.Marshal(a)
if err != nil {
return nil, err
}

return v, nil
}
38 changes: 5 additions & 33 deletions jobsdb/jobsdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package jobsdb
//go:generate mockgen -destination=../mocks/jobsdb/mock_jobsdb.go -package=mocks_jobsdb github.com/rudderlabs/rudder-server/jobsdb JobsDB

import (
"bytes"
"context"
"crypto/sha256"
"database/sql"
Expand All @@ -37,7 +36,6 @@ import (
"unicode/utf8"

"github.com/google/uuid"
jsoniter "github.com/json-iterator/go"
"github.com/lib/pq"
"github.com/samber/lo"
"github.com/tidwall/gjson"
Expand All @@ -58,10 +56,7 @@ import (
. "github.com/rudderlabs/rudder-server/utils/tx" //nolint:staticcheck
)

var (
errStaleDsList = errors.New("stale dataset list")
jsonfast = jsoniter.ConfigCompatibleWithStandardLibrary
)
var errStaleDsList = errors.New("stale dataset list")

const (
pgReadonlyTableExceptionFuncName = "readonly_table_exception()"
Expand Down Expand Up @@ -383,12 +378,12 @@ type ConnectionDetails struct {

func (r *JobStatusT) sanitizeJson() error {
var err error
r.ErrorResponse, err = sanitizeJSON(r.ErrorResponse)
r.ErrorResponse, err = misc.SanitizeJSON(r.ErrorResponse)
if err != nil {
return err
}

r.Parameters, err = sanitizeJSON(r.Parameters)
r.Parameters, err = misc.SanitizeJSON(r.Parameters)
if err != nil {
return err
}
Expand Down Expand Up @@ -421,11 +416,11 @@ func (job *JobT) String() string {

func (job *JobT) sanitizeJSON() error {
var err error
job.EventPayload, err = sanitizeJSON(job.EventPayload)
job.EventPayload, err = misc.SanitizeJSON(job.EventPayload)
if err != nil {
return err
}
job.Parameters, err = sanitizeJSON(job.Parameters)
job.Parameters, err = misc.SanitizeJSON(job.Parameters)
if err != nil {
return err
}
Expand Down Expand Up @@ -3292,29 +3287,6 @@ func (jd *Handle) GetLastJob(ctx context.Context) *JobT {
return &job
}

// sanitizeJSON makes a json payload safe for writing into postgres.
// 1. Removes any \u0000 string from the payload
// ~2. Replaces any invalid utf8 characters using github.com/rudderlabs/rudder-go-kit/utf8~
// 3. unmashals and marshals the payload to remove any extra keys
func sanitizeJSON(input json.RawMessage) (json.RawMessage, error) {
v := bytes.ReplaceAll(input, []byte(`\u0000`), []byte(""))
if len(v) == 0 {
v = []byte(`{}`)
}

var a any
err := jsonfast.Unmarshal(v, &a)
if err != nil {
return nil, err
}
v, err = jsonfast.Marshal(a)
if err != nil {
return nil, err
}

return v, nil
}

func (jd *Handle) withDistributedLock(ctx context.Context, tx *Tx, operation string, f func() error) error {
advisoryLock := jd.getAdvisoryLockForOperation(operation)
_, err := tx.ExecContext(ctx, fmt.Sprintf(`SELECT pg_advisory_xact_lock(%d);`, advisoryLock))
Expand Down
34 changes: 34 additions & 0 deletions utils/misc/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"crypto/md5"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"hash/fnv"
Expand All @@ -28,6 +29,7 @@ import (
"github.com/araddon/dateparse"
"github.com/cenkalti/backoff"
"github.com/google/uuid"
jsoniter "github.com/json-iterator/go"
"github.com/tidwall/sjson"

"github.com/rudderlabs/rudder-go-kit/config"
Expand Down Expand Up @@ -965,3 +967,35 @@ func GetInstanceID() string {
}
return ""
}

var jsonfast = jsoniter.ConfigCompatibleWithStandardLibrary

// SanitizeJSON makes a json payload safe for writing into postgres.
// 1. Removes any \u0000 string from the payload
// ~2. Replaces any invalid utf8 characters using github.com/rudderlabs/rudder-go-kit/utf8~
// 3. unmarshals and marshals the payload to remove any extra keys
func SanitizeJSON(input json.RawMessage) (json.RawMessage, error) {
// Remove null characters
v := bytes.ReplaceAll(input, []byte(`\u0000`), []byte(""))

if len(v) == 0 {
return []byte(`{}`), nil
}

// Validate JSON structure by unmarshaling and marshaling
var a any
err := jsonfast.Unmarshal(v, &a)
if err != nil {
return nil, err
}
v, err = jsonfast.Marshal(a)
if err != nil {
return nil, err
}

Check warning on line 994 in utils/misc/misc.go

View check run for this annotation

Codecov / codecov/patch

utils/misc/misc.go#L993-L994

Added lines #L993 - L994 were not covered by tests

return v, nil
}

func SanitizeString(input string) string {
return strings.ReplaceAll(input, "\u0000", "")
}
87 changes: 87 additions & 0 deletions utils/misc/misc_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package misc

import (
"bytes"
"encoding/json"
"errors"
"fmt"
"os"
Expand Down Expand Up @@ -788,3 +790,88 @@ func Test_GetInstanceID(t *testing.T) {
t.Setenv("INSTANCE_ID", "prousmtusmt-v0-rs-gw-ha-12-234234-10")
require.Equal(t, "12", GetInstanceID())
}

func TestSanitizeJSON(t *testing.T) {
tests := []struct {
name string
input json.RawMessage
want json.RawMessage
wantErr bool
}{
{
name: "empty input",
input: json.RawMessage(``),
want: json.RawMessage(`{}`),
wantErr: false,
},
{
name: "valid json",
input: json.RawMessage(`{"key":"value"}`),
want: json.RawMessage(`{"key":"value"}`),
wantErr: false,
},
{
name: "json with null characters",
input: json.RawMessage(`{"key":"\u0000value\u0000"}`),
want: json.RawMessage(`{"key":"value"}`),
wantErr: false,
},
{
name: "json with html entities",
input: json.RawMessage(`{"key":"\u0026\u0000value\u003ctest\u003e"}`),
want: json.RawMessage(`{"key":"\u0026value\u003ctest\u003e"}`),
wantErr: false,
},
{
name: "invalid json",
input: json.RawMessage(`{"key":"value`),
want: nil,
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := SanitizeJSON(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("sanitizeJSONForReports() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && !bytes.Equal(got, tt.want) {
t.Errorf("sanitizeJSONForReports() = %s, want %s", got, tt.want)
}
})
}
}

func TestSanitizeString(t *testing.T) {
testCases := []struct {
name string
input string
expected string
}{
{
name: "empty string",
input: "",
},
{
name: "with unicode characters",
input: "Start: \u0000\u0000\u0000\u0000\u0000\u0000\u0000 : End",
expected: "Start: : End",
},
{
name: "without unicode characters",
input: "Start: : End",
expected: "Start: : End",
},
}

for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

require.Equal(t, tc.expected, SanitizeString(tc.input))
})
}
}
3 changes: 2 additions & 1 deletion warehouse/internal/repo/table_upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/lib/pq"

"github.com/rudderlabs/rudder-server/utils/misc"
"github.com/rudderlabs/rudder-server/utils/timeutil"
sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper"
"github.com/rudderlabs/rudder-server/warehouse/internal/model"
Expand Down Expand Up @@ -297,7 +298,7 @@ func (tu *TableUploads) Set(ctx context.Context, uploadId int64, tableName strin
}
if options.Error != nil {
setQuery.WriteString(fmt.Sprintf(`error = $%d,`, len(queryArgs)+1))
sanitizedError := warehouseutils.SanitizeString(*options.Error)
sanitizedError := misc.SanitizeString(*options.Error)
queryArgs = append(queryArgs, sanitizedError)
}
if options.LastExecTime != nil {
Expand Down
2 changes: 1 addition & 1 deletion warehouse/router/upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ func (job *UploadJob) setUploadError(statusError error, state string) (string, e
}

serializedErr, _ := json.Marshal(&uploadErrors)
serializedErr = whutils.SanitizeJSON(serializedErr)
serializedErr, _ = misc.SanitizeJSON(serializedErr)

txn, err := job.db.BeginTx(job.ctx, &sql.TxOptions{})
if err != nil {
Expand Down
13 changes: 0 additions & 13 deletions warehouse/utils/utils.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package warehouseutils

import (
"bytes"
"crypto/sha512"
"database/sql"
"encoding/hex"
Expand Down Expand Up @@ -317,18 +316,6 @@ func GetObjectLocation(provider, location string) (objectLocation string) {
return
}

func SanitizeJSON(input json.RawMessage) json.RawMessage {
v := bytes.ReplaceAll(input, []byte(`\u0000`), []byte(""))
if len(v) == 0 {
v = []byte(`{}`)
}
return v
}

func SanitizeString(input string) string {
return strings.ReplaceAll(input, "\u0000", "")
}

// GetObjectName extracts object/key objectName from different buckets locations
// ex: https://bucket-endpoint/bucket-name/object -> object
func GetObjectName(location string, providerConfig interface{}, objectProvider string) (objectName string, err error) {
Expand Down
Loading

0 comments on commit faa40c2

Please sign in to comment.