diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 69816c05..ff89a08a 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -34,7 +34,7 @@ jobs: - name: Coverage uses: gwatts/go-coverage-action@v1 - if: github.ref == 'refs/heads/master' + # if: github.ref == 'refs/heads/master' # continue-on-error: true with: add-comment: true diff --git a/.gitignore b/.gitignore index 0dc82900..6c793efe 100644 --- a/.gitignore +++ b/.gitignore @@ -35,11 +35,11 @@ Gopkg.toml # Version file (used in CI) version -#plugins +# Myrtea Plugins /plugin/*.plugin -/certs - -plugin/*.plugin +# Exports directory (for local testing) +exports/ +# GoLang Linter executable golangci-lint.exe \ No newline at end of file diff --git a/config/engine-api.toml b/config/engine-api.toml index be3e07fc..9a9cb12a 100644 --- a/config/engine-api.toml +++ b/config/engine-api.toml @@ -240,4 +240,33 @@ AUTHENTICATION_OIDC_FRONT_END_URL = "http://127.0.0.1:4200" # Note: The key length is critical for the AES encryption algorithm used for state encryption/decryption. # It must be exactly 16, 24 or 32 characters long. # Default value: "thisis24characterslongs." (24 characters) -AUTHENTICATION_OIDC_ENCRYPTION_KEY = "thisis24characterslongs." \ No newline at end of file +AUTHENTICATION_OIDC_ENCRYPTION_KEY = "thisis24characterslongs." + +# NOTIFICATION_LIFETIME: The lifetime of a notification in the database. +# Default value: "168h" +NOTIFICATION_LIFETIME = "168h" # 168h = 7 days, available units are "ns", "us" (or "µs"), "ms", "s", "m", "h" + +# Path to directory where the resulting export files will be stored. +# Default value: "exports/" +EXPORT_BASE_PATH = "exports/" + +# Number of days before one export file will be auto deleted +# Default value: 4 +EXPORT_DISK_RETENTION_DAYS = 4 + +# Export queue max size, any export request that is made when queue is full will be refused. +# Default value: 30 +EXPORT_QUEUE_MAX_SIZE = 30 + +# Number of concurrent export workers +# Default value: 4 +EXPORT_WORKERS_COUNT = 4 + +# Whether download must be directly streamed through http or handled by an external web server +# Default value: true +EXPORT_DIRECT_DOWNLOAD = true + +# Reverse proxy like nginx, apache gives direct access to the exports directory at a specific path +# Full URL to the exports directory +# Default value: "" +EXPORT_INDIRECT_DOWNLOAD_URL = "" \ No newline at end of file diff --git a/go.mod b/go.mod index 7efff45d..60ad25f2 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.20 require ( github.com/Masterminds/squirrel v1.5.3 - github.com/alexmullins/zip v0.0.0-20180717182244-4affb64b04d0 github.com/coreos/go-oidc/v3 v3.6.0 github.com/crewjam/saml v0.4.6 github.com/dgrijalva/jwt-go v3.2.0+incompatible @@ -19,7 +18,7 @@ require ( github.com/gorilla/context v1.1.1 github.com/gorilla/websocket v1.5.0 github.com/hashicorp/go-hclog v1.3.1 - github.com/hashicorp/go-plugin v1.3.0 + github.com/hashicorp/go-plugin v1.5.2 github.com/jmoiron/sqlx v1.2.0 github.com/json-iterator/go v1.1.12 github.com/lestrrat-go/jwx v1.2.6 @@ -38,7 +37,7 @@ require ( golang.org/x/net v0.12.0 golang.org/x/oauth2 v0.6.0 google.golang.org/grpc v1.40.0 - google.golang.org/protobuf v1.28.1 + google.golang.org/protobuf v1.28.2-0.20230222093303-bc1253ad3743 ) require ( diff --git a/go.sum b/go.sum index 71725cca..19569d75 100644 --- a/go.sum +++ b/go.sum @@ -43,15 +43,11 @@ github.com/PaesslerAG/gval v1.2.2 h1:Y7iBzhgE09IGTt5QgGQ2IdaYYYOU134YGHBThD+wm9E github.com/PaesslerAG/gval v1.2.2/go.mod h1:XRFLwvmkTEdYziLdaCeCa5ImcGVrfQbeNUbVR+C6xac= github.com/PaesslerAG/jsonpath v0.1.0 h1:gADYeifvlqK3R3i2cR5B4DGgxLXIPb3TRTH1mGi0jPI= github.com/PaesslerAG/jsonpath v0.1.0/go.mod h1:4BzmtoM/PI8fPO4aQGIusjGxGir2BzcV0grWtFzq1Y8= -github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= -github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= -github.com/alexmullins/zip v0.0.0-20180717182244-4affb64b04d0 h1:BVts5dexXf4i+JX8tXlKT0aKoi38JwTXSe+3WUneX0k= -github.com/alexmullins/zip v0.0.0-20180717182244-4affb64b04d0/go.mod h1:FDIQmoMNJJl5/k7upZEnGvgWVZfFeE6qHeN7iCMbCsA= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= github.com/beevik/etree v1.1.0 h1:T0xke/WvNtMoCqgzPhkX2r4rjY3GDZFi+FjpRZY2Jbs= @@ -62,6 +58,7 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bufbuild/protocompile v0.4.0 h1:LbFKd2XowZvQ/kajzguUp2DC9UEIQhIq77fZZlaQsNA= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= @@ -224,11 +221,10 @@ github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51 github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= -github.com/hashicorp/go-hclog v0.0.0-20180709165350-ff2cf002a8dd/go.mod h1:9bjs9uLqI8l75knNv3lV1kA55veR+WUPSiKIWcQHudI= github.com/hashicorp/go-hclog v1.3.1 h1:vDwF1DFNZhntP4DAjuTpOw3uEgMUpXh1pB5fW9DqHpo= github.com/hashicorp/go-hclog v1.3.1/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= -github.com/hashicorp/go-plugin v1.3.0 h1:4d/wJojzvHV1I4i/rrjVaeuyxWrLzDE1mDCyDy8fXS8= -github.com/hashicorp/go-plugin v1.3.0/go.mod h1:F9eH4LrE/ZsRdbwhfjs9k9HoDUwAHnYtXdgmf1AVNs0= +github.com/hashicorp/go-plugin v1.5.2 h1:aWv8eimFqWlsEiMrYZdPYl+FdHaBJSN4AWwGWfT1G2Y= +github.com/hashicorp/go-plugin v1.5.2/go.mod h1:w1sAEES3g3PuV/RzUrgow20W2uErMly84hhD3um1WL4= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= @@ -236,8 +232,7 @@ github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb h1:b5rjCoWHc7eqmAS4/qyk21ZsHyb6Mxv/jykxvNTkU4M= github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb/go.mod h1:+NfK9FKeTrX5uv1uIXGdwYDTeHna2qgaIlx54MXqjAM= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= -github.com/jhump/protoreflect v1.6.0 h1:h5jfMVslIg6l29nsMs0D8Wj17RDVdNYti0vDN/PZZoE= -github.com/jhump/protoreflect v1.6.0/go.mod h1:eaTn3RZAmMBcV0fifFvlm6VHNz3wSkYyXYWUh7ymB74= +github.com/jhump/protoreflect v1.15.1 h1:HUMERORf3I3ZdX05WaQ6MIpd/NJ434hTp5YiKgfCL6c= github.com/jmoiron/sqlx v1.2.0 h1:41Ip0zITnmWNR/vHV+S4m+VoUivnWY5E4OJfLZjCJMA= github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks= github.com/jonboulle/clockwork v0.2.2 h1:UOGuzwb1PwsrDAObMuhUnj0p5ULPj8V/xJ7Kx9qUBdQ= @@ -312,7 +307,6 @@ github.com/mattn/go-sqlite3 v1.9.0 h1:pDRiWfl+++eC2FEFRy6jXmQlvp4Yh3z1MJKg4UeYM/ github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/mitchellh/go-testing-interface v0.0.0-20171004221916-a61a99592b77/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= github.com/mitchellh/go-testing-interface v1.0.0 h1:fzU/JVNcaqHQEcVFAKeR41fkiLdIPrefOvVG1VZ96U0= github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= @@ -327,8 +321,6 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= -github.com/myrteametrics/myrtea-sdk/v4 v4.4.5 h1:8dbIWpNLzvOq9/fQTiJUpZd56NRVAEoAmcRQoC+uD5c= -github.com/myrteametrics/myrtea-sdk/v4 v4.4.5/go.mod h1:wa9nwNcFGpGbZeqXXqhTLp7sXERbCrRhhcASGY6H0QA= github.com/myrteametrics/myrtea-sdk/v4 v4.4.7 h1:cIn6+hCgzGAaWGjtAm0rFPdXX/cl6z4wWQuWI+KG9eQ= github.com/myrteametrics/myrtea-sdk/v4 v4.4.7/go.mod h1:wa9nwNcFGpGbZeqXXqhTLp7sXERbCrRhhcASGY6H0QA= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= @@ -494,7 +486,6 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= -golang.org/x/net v0.0.0-20180530234432-1e491301e022/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -692,7 +683,6 @@ google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCID google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/genproto v0.0.0-20170818010345-ee236bd376b0/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= @@ -725,7 +715,6 @@ google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6D google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210917145530-b395a37504d4 h1:ysnBoUyeL/H6RCvNRhWHjKoDEmguI+mPU+qHgK8qv/w= google.golang.org/genproto v0.0.0-20210917145530-b395a37504d4/go.mod h1:eFjDcFEctNawg4eG61bRv87N7iHBWyVhJu7u1kqDUXY= -google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -755,8 +744,8 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= -google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.28.2-0.20230222093303-bc1253ad3743 h1:yqElulDvOF26oZ2O+2/aoX7mQ8DY/6+p39neytrycd8= +google.golang.org/protobuf v1.28.2-0.20230222093303-bc1253ad3743/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internals/app/app.go b/internals/app/app.go index 70ce6d1f..b5e25c21 100644 --- a/internals/app/app.go +++ b/internals/app/app.go @@ -5,7 +5,7 @@ import ( "github.com/spf13/viper" ) -// Init initialiaze all the app configuration and components +// Init initialize all the app configuration and components func Init() { docs.SwaggerInfo.Host = viper.GetString("SWAGGER_HOST") diff --git a/internals/app/services.go b/internals/app/services.go index c319b757..ebf19d20 100644 --- a/internals/app/services.go +++ b/internals/app/services.go @@ -2,6 +2,7 @@ package app import ( "errors" + "github.com/myrteametrics/myrtea-engine-api/v5/internals/export" "strings" "github.com/myrteametrics/myrtea-engine-api/v5/internals/calendar" @@ -78,6 +79,11 @@ func stopServices() { } func initNotifier() { + notificationLifetime := viper.GetDuration("NOTIFICATION_LIFETIME") + handler := notification.NewHandler(notificationLifetime) + handler.RegisterNotificationType(notification.MockNotification{}) + handler.RegisterNotificationType(export.ExportNotification{}) + notification.ReplaceHandlerGlobals(handler) notifier.ReplaceGlobals(notifier.NewNotifier()) } @@ -90,7 +96,6 @@ func initScheduler() { if viper.GetBool("ENABLE_CRONS_ON_START") { scheduler.S().C.Start() } - } } func initTasker() { @@ -100,7 +105,6 @@ func initTasker() { func initCalendars() { calendar.Init() - } func initCoordinator() { @@ -114,7 +118,7 @@ func initCoordinator() { instanceName := viper.GetString("INSTANCE_NAME") if err = coordinator.InitInstance(instanceName, models); err != nil { - zap.L().Fatal("Intialisation of coordinator master", zap.Error(err)) + zap.L().Fatal("Initialization of coordinator master", zap.Error(err)) } if viper.GetBool("ENABLE_CRONS_ON_START") { for _, li := range coordinator.GetInstance().LogicalIndices { @@ -127,13 +131,11 @@ func initCoordinator() { } func initEmailSender() { - username := viper.GetString("SMTP_USERNAME") password := viper.GetString("SMTP_PASSWORD") host := viper.GetString("SMTP_HOST") port := viper.GetString("SMTP_PORT") email.InitSender(username, password, host, port) - } func initOidcAuthentication() { diff --git a/internals/export/csv.go b/internals/export/csv.go index 7e746006..584cec86 100644 --- a/internals/export/csv.go +++ b/internals/export/csv.go @@ -6,36 +6,43 @@ import ( "fmt" "strings" "time" + "unicode/utf8" "github.com/myrteametrics/myrtea-engine-api/v5/internals/reader" "go.uber.org/zap" ) -func ConvertHitsToCSV(hits []reader.Hit, columns []string, columnsLabel []string, formatColumnsData map[string]string, separator rune) ([]byte, error) { - b := new(bytes.Buffer) - w := csv.NewWriter(b) - w.Comma = separator +// WriteConvertHitsToCSV writes hits to CSV +func WriteConvertHitsToCSV(w *csv.Writer, hits []reader.Hit, params CSVParameters, writeHeader bool) error { + if len(params.Separator) == 1 { + w.Comma, _ = utf8.DecodeRune([]byte(params.Separator)) + if w.Comma == utf8.RuneError { + w.Comma = ',' + } + } else { + w.Comma = ',' + } // avoid to print header when labels are empty - if len(columnsLabel) > 0 { - w.Write(columnsLabel) + if writeHeader && len(params.Columns) > 0 { + w.Write(params.GetColumnsLabel()) } for _, hit := range hits { record := make([]string, 0) - for _, column := range columns { - value, err := nestedMapLookup(hit.Fields, strings.Split(column, ".")...) + for _, column := range params.Columns { + value, err := nestedMapLookup(hit.Fields, strings.Split(column.Name, ".")...) if err != nil { value = "" - } else if format, ok := formatColumnsData[column]; ok { + } else if column.Format != "" { if date, ok := value.(time.Time); ok { - value = date.Format(format) + value = date.Format(column.Format) } else if dateStr, ok := value.(string); ok { date, err := parseDate(dateStr) if err != nil { zap.L().Error("Failed to parse date string:", zap.Any(":", dateStr), zap.Error(err)) } else { - value = date.Format(format) + value = date.Format(column.Format) } } } @@ -45,12 +52,23 @@ func ConvertHitsToCSV(hits []reader.Hit, columns []string, columnsLabel []string } w.Flush() - if err := w.Error(); err != nil { + return w.Error() +} + +// ConvertHitsToCSV converts hits to CSV +func ConvertHitsToCSV(hits []reader.Hit, params CSVParameters, writeHeader bool) ([]byte, error) { + b := new(bytes.Buffer) + w := csv.NewWriter(b) + err := WriteConvertHitsToCSV(w, hits, params, writeHeader) + + if err != nil { return nil, err } + return b.Bytes(), nil } +// nestedMapLookup looks up a nested map item func nestedMapLookup(m map[string]interface{}, ks ...string) (rval interface{}, err error) { var ok bool if len(ks) == 0 { @@ -67,6 +85,7 @@ func nestedMapLookup(m map[string]interface{}, ks ...string) (rval interface{}, } } +// parseDate parses a date string func parseDate(dateStr string) (time.Time, error) { formats := []string{ "2006-01-02T15:04:05.999", diff --git a/internals/export/csv_test.go b/internals/export/csv_test.go index f27d08a9..e1bcc644 100644 --- a/internals/export/csv_test.go +++ b/internals/export/csv_test.go @@ -1,6 +1,8 @@ package export import ( + "bytes" + csv2 "encoding/csv" "testing" "github.com/myrteametrics/myrtea-engine-api/v5/internals/reader" @@ -13,15 +15,99 @@ func TestConvertHitsToCSV(t *testing.T) { {ID: "3", Fields: map[string]interface{}{"a": "hello", "b": 20, "c": 3.123456, "date": "2023-06-30T10:42:59.500"}}, {ID: "1", Fields: map[string]interface{}{"a": "hello", "b": 20, "c": 3.123456, "d": map[string]interface{}{"zzz": "nested"}, "date": "2023-06-30T10:42:59.500"}}, } - columns := []string{"a", "b", "c", "d.e", "date"} - columnsLabel := []string{"Label A", "Label B", "Label C", "Label D.E", "Date"} - formatColumnsData := map[string]string{ - "date": "02/01/2006", + params := CSVParameters{ + Columns: []Column{ + {Name: "a", Label: "Label A", Format: ""}, + {Name: "b", Label: "Label B", Format: ""}, + {Name: "c", Label: "Label C", Format: ""}, + {Name: "d.e", Label: "Label D.E", Format: ""}, + {Name: "date", Label: "Date", Format: "02/01/2006"}, + }, + Separator: ",", } - csv, err := ConvertHitsToCSV(hits, columns, columnsLabel, formatColumnsData, ',') + csv, err := ConvertHitsToCSV(hits, params, true) if err != nil { t.Log(err) t.FailNow() } t.Log("\n" + string(csv)) } + +func TestWriteConvertHitsToCSV(t *testing.T) { + hits := []reader.Hit{ + {ID: "1", Fields: map[string]interface{}{"a": "hello", "b": 20, "c": 3.123456, "d": map[string]interface{}{"e": "nested"}, "date": "2023-06-30T10:42:59.500"}}, + {ID: "2", Fields: map[string]interface{}{"b": 20, "c": 3.123456, "d": map[string]interface{}{"e": "nested"}, "date": "2023-06-30T10:42:59.500"}}, + {ID: "3", Fields: map[string]interface{}{"a": "hello", "b": 20, "c": 3.123456, "date": "2023-06-30T10:42:59.500"}}, + {ID: "1", Fields: map[string]interface{}{"a": "hello", "b": 20, "c": 3.123456, "d": map[string]interface{}{"zzz": "nested"}, "date": "2023-06-30T10:42:59.500"}}, + } + params := CSVParameters{ + Columns: []Column{ + {Name: "a", Label: "Label A", Format: ""}, + {Name: "b", Label: "Label B", Format: ""}, + {Name: "c", Label: "Label C", Format: ""}, + {Name: "d.e", Label: "Label D.E", Format: ""}, + {Name: "date", Label: "Date", Format: "02/01/2006"}, + }, + Separator: ",", + } + b := new(bytes.Buffer) + w := csv2.NewWriter(b) + err := WriteConvertHitsToCSV(w, hits, params, true) + if err != nil { + t.Log(err) + t.FailNow() + } + t.Log("\n" + string(b.Bytes())) +} + +func TestNestedMapLookup_WithEmptyKeys(t *testing.T) { + _, err := nestedMapLookup(map[string]interface{}{}, "") + if err == nil { + t.FailNow() + } +} + +func TestNestedMapLookup_WithNonExistentKey(t *testing.T) { + _, err := nestedMapLookup(map[string]interface{}{"a": "hello"}, "b") + if err == nil { + t.FailNow() + } +} + +func TestNestedMapLookup_WithNestedNonExistentKey(t *testing.T) { + _, err := nestedMapLookup(map[string]interface{}{"a": map[string]interface{}{"b": "hello"}}, "a", "c") + if err == nil { + t.FailNow() + } +} + +func TestNestedMapLookup_WithNestedKey(t *testing.T) { + val, err := nestedMapLookup(map[string]interface{}{"a": map[string]interface{}{"b": "hello"}}, "a", "b") + if err != nil || val != "hello" { + t.Error(err) + t.FailNow() + } +} + +func TestParseDate_WithInvalidFormat(t *testing.T) { + _, err := parseDate("2023-06-30") + if err == nil { + t.FailNow() + } +} + +func TestParseDate_WithValidFormat(t *testing.T) { + _, err := parseDate("2023-06-30T10:42:59.500") + if err != nil { + t.Error(err) + t.FailNow() + } +} + +func TestConvertHitsToCSV_WithEmptyHits(t *testing.T) { + _, err := ConvertHitsToCSV([]reader.Hit{}, CSVParameters{}, true) + if err != nil { + t.Error(err) + t.FailNow() + } +} diff --git a/internals/export/notification.go b/internals/export/notification.go new file mode 100644 index 00000000..99e3695d --- /dev/null +++ b/internals/export/notification.go @@ -0,0 +1,88 @@ +package export + +import ( + "encoding/json" + "github.com/myrteametrics/myrtea-engine-api/v5/internals/notifier/notification" + "reflect" +) + +const ( + ExportNotificationStarted = 0 + ExportNotificationArchived = 1 // happens when + ExportNotificationDeleted = 2 // happens when the export is deleted from archive +) + +type ExportNotification struct { + notification.BaseNotification + Export WrapperItem `json:"export"` + Status int `json:"status"` +} + +func NewExportNotification(id int64, export WrapperItem, status int) *ExportNotification { + return &ExportNotification{ + BaseNotification: notification.BaseNotification{ + Id: id, + Type: "ExportNotification", + Persistent: false, + }, + Export: export, + Status: status, + } +} + +// ToBytes convert a notification in a json byte slice to be sent through any required channel +func (e ExportNotification) ToBytes() ([]byte, error) { + b, err := json.Marshal(e) + if err != nil { + return nil, err + } + return b, nil +} + +// NewInstance returns a new instance of a ExportNotification +func (e ExportNotification) NewInstance(id int64, data []byte, isRead bool) (notification.Notification, error) { + var notif ExportNotification + err := json.Unmarshal(data, ¬if) + if err != nil { + return nil, err + } + notif.Id = id + notif.IsRead = isRead + notif.Notification = notif + return notif, nil +} + +// Equals returns true if the two notifications are equals +func (e ExportNotification) Equals(notification notification.Notification) bool { + notif, ok := notification.(ExportNotification) + if !ok { + return ok + } + if !notif.BaseNotification.Equals(e.BaseNotification) { + return false + } + if !reflect.DeepEqual(notif.Export, e.Export) { + return false + } + if notif.Status != e.Status { + return false + } + return true +} + +// SetId set the notification ID +func (e ExportNotification) SetId(id int64) notification.Notification { + e.Id = id + return e +} + +// SetPersistent sets whether the notification is persistent (saved to a database) +func (e ExportNotification) SetPersistent(persistent bool) notification.Notification { + e.Persistent = persistent + return e +} + +// IsPersistent returns whether the notification is persistent (saved to a database) +func (e ExportNotification) IsPersistent() bool { + return e.Persistent +} diff --git a/internals/export/notification_test.go b/internals/export/notification_test.go new file mode 100644 index 00000000..ca390224 --- /dev/null +++ b/internals/export/notification_test.go @@ -0,0 +1,119 @@ +package export + +import ( + "github.com/google/uuid" + "github.com/myrteametrics/myrtea-engine-api/v5/internals/notifier/notification" + "github.com/myrteametrics/myrtea-sdk/v4/expression" + "testing" +) + +func TestExportNotification(t *testing.T) { + // init handler + handler := notification.NewHandler(0) + handler.RegisterNotificationType(notification.MockNotification{}) + handler.RegisterNotificationType(ExportNotification{}) + notification.ReplaceHandlerGlobals(handler) + + notif := ExportNotification{ + Export: WrapperItem{ + Id: uuid.New().String(), + }, + Status: 1, + } + notif.Id = 1 + notif.IsRead = false + + bytes, err := notif.ToBytes() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if bytes == nil { + t.Fatalf("Expected bytes, got nil") + } + + t.Log(string(bytes)) + + // find type and create new instance + notifType, ok := notification.H().GetNotificationByType("ExportNotification") + if !ok { + t.Fatalf("ExportNotification type does not exist") + } + + instance, err := notifType.NewInstance(1, bytes, false) + if err != nil { + t.Fatalf("ExportNotification couldn't be instanced") + } + bt, _ := instance.ToBytes() + t.Log(string(bt)) + + expression.AssertEqual(t, string(bytes), string(bt)) +} + +func TestExportNotification_Equals(t *testing.T) { + id := uuid.New().String() + exportNotification := ExportNotification{ + BaseNotification: notification.BaseNotification{ + Id: 1, + Type: "Test", + IsRead: true, + }, + Export: WrapperItem{ + Id: id, + }, + Status: 1, + } + + expression.AssertEqual(t, exportNotification.Equals(ExportNotification{ + BaseNotification: notification.BaseNotification{ + Id: 1, + Type: "Test", + IsRead: true, + }, + Status: 1, + Export: WrapperItem{Id: id}, + }), true) + + expression.AssertEqual(t, exportNotification.Equals(ExportNotification{ + BaseNotification: notification.BaseNotification{ + Id: 2, + Type: "Test", + IsRead: true, + }, + Status: 1, + Export: WrapperItem{Id: id}, + }), false) + + expression.AssertEqual(t, exportNotification.Equals(ExportNotification{ + BaseNotification: notification.BaseNotification{ + Id: 1, + Type: "Test", + IsRead: true, + }, + Status: 2, + Export: WrapperItem{Id: id}, + }), false) + + expression.AssertEqual(t, exportNotification.Equals(ExportNotification{ + BaseNotification: notification.BaseNotification{ + Id: 1, + Type: "Test", + IsRead: true, + }, + Status: 1, + Export: WrapperItem{Id: uuid.New().String()}, + }), false) + +} + +func TestExportNotification_SetId(t *testing.T) { + notif, err := ExportNotification{}.NewInstance(1, []byte(`{}`), true) + if err != nil { + t.Fatalf("Error: %v", err) + } + + notif = notif.SetId(2) + exportNotification, ok := notif.(ExportNotification) + expression.AssertEqual(t, ok, true) + expression.AssertEqual(t, exportNotification.Id, int64(2)) +} diff --git a/internals/export/utils.go b/internals/export/utils.go new file mode 100644 index 00000000..741c4817 --- /dev/null +++ b/internals/export/utils.go @@ -0,0 +1,68 @@ +package export + +import "github.com/myrteametrics/myrtea-engine-api/v5/internals/notifier/notification" + +type CSVParameters struct { + Columns []Column `json:"columns"` + Separator string `json:"separator"` + Limit int64 `json:"limit"` +} + +type Column struct { + Name string `json:"name"` + Label string `json:"label"` + Format string `json:"format" default:""` +} + +// Equals compares two Column +func (p Column) Equals(column Column) bool { + if p.Name != column.Name { + return false + } + if p.Label != column.Label { + return false + } + if p.Format != column.Format { + return false + } + return true +} + +// Equals compares two CSVParameters +func (p CSVParameters) Equals(params CSVParameters) bool { + if p.Separator != params.Separator { + return false + } + if p.Limit != params.Limit { + return false + } + for i, column := range p.Columns { + if !column.Equals(params.Columns[i]) { + return false + } + } + return true +} + +// GetColumnsLabel returns the label of the columns +func (p CSVParameters) GetColumnsLabel() []string { + columns := make([]string, 0) + for _, column := range p.Columns { + columns = append(columns, column.Label) + } + return columns +} + +// createExportNotification creates an export notification using given parameters +func createExportNotification(status int, item *WrapperItem) ExportNotification { + return ExportNotification{ + BaseNotification: notification.BaseNotification{ + Id: 0, + IsRead: false, + Type: "ExportNotification", + Persistent: false, + }, + Export: *item, + Status: status, + } +} diff --git a/internals/export/utils_test.go b/internals/export/utils_test.go new file mode 100644 index 00000000..8be6a793 --- /dev/null +++ b/internals/export/utils_test.go @@ -0,0 +1,68 @@ +package export + +import ( + "github.com/myrteametrics/myrtea-sdk/v4/expression" + "testing" +) + +func TestColumnEquals_WithDifferentName(t *testing.T) { + column1 := Column{Name: "name1", Label: "label", Format: "format"} + column2 := Column{Name: "name2", Label: "label", Format: "format"} + expression.AssertEqual(t, column1.Equals(column2), false) +} + +func TestColumnEquals_WithDifferentLabel(t *testing.T) { + column1 := Column{Name: "name", Label: "label1", Format: "format"} + column2 := Column{Name: "name", Label: "label2", Format: "format"} + expression.AssertEqual(t, column1.Equals(column2), false) +} + +func TestColumnEquals_WithDifferentFormat(t *testing.T) { + column1 := Column{Name: "name", Label: "label", Format: "format1"} + column2 := Column{Name: "name", Label: "label", Format: "format2"} + expression.AssertEqual(t, column1.Equals(column2), false) +} + +func TestColumnEquals_WithSameValues(t *testing.T) { + column1 := Column{Name: "name", Label: "label", Format: "format"} + column2 := Column{Name: "name", Label: "label", Format: "format"} + expression.AssertEqual(t, column1.Equals(column2), true) +} + +func TestCSVParametersEquals_WithDifferentSeparator(t *testing.T) { + params1 := CSVParameters{Separator: ",", Limit: 10, Columns: []Column{{Name: "name", Label: "label", Format: "format"}}} + params2 := CSVParameters{Separator: ";", Limit: 10, Columns: []Column{{Name: "name", Label: "label", Format: "format"}}} + expression.AssertEqual(t, params1.Equals(params2), false) +} + +func TestCSVParametersEquals_WithDifferentLimit(t *testing.T) { + params1 := CSVParameters{Separator: ",", Limit: 10, Columns: []Column{{Name: "name", Label: "label", Format: "format"}}} + params2 := CSVParameters{Separator: ",", Limit: 20, Columns: []Column{{Name: "name", Label: "label", Format: "format"}}} + expression.AssertEqual(t, params1.Equals(params2), false) +} + +func TestCSVParametersEquals_WithDifferentColumns(t *testing.T) { + params1 := CSVParameters{Separator: ",", Limit: 10, Columns: []Column{{Name: "name1", Label: "label", Format: "format"}}} + params2 := CSVParameters{Separator: ",", Limit: 10, Columns: []Column{{Name: "name2", Label: "label", Format: "format"}}} + expression.AssertEqual(t, params1.Equals(params2), false) +} + +func TestCSVParametersEquals_WithSameValues(t *testing.T) { + params1 := CSVParameters{Separator: ",", Limit: 10, Columns: []Column{{Name: "name", Label: "label", Format: "format"}}} + params2 := CSVParameters{Separator: ",", Limit: 10, Columns: []Column{{Name: "name", Label: "label", Format: "format"}}} + expression.AssertEqual(t, params1.Equals(params2), true) +} + +func TestGetColumnsLabel_WithNoColumns(t *testing.T) { + params := CSVParameters{Separator: ",", Limit: 10, Columns: []Column{}} + labels := params.GetColumnsLabel() + expression.AssertEqual(t, len(labels), 0) +} + +func TestGetColumnsLabel_WithColumns(t *testing.T) { + params := CSVParameters{Separator: ",", Limit: 10, Columns: []Column{{Name: "name1", Label: "label1", Format: "format1"}, {Name: "name2", Label: "label2", Format: "format2"}}} + labels := params.GetColumnsLabel() + expression.AssertEqual(t, len(labels), 2) + expression.AssertEqual(t, labels[0], "label1") + expression.AssertEqual(t, labels[1], "label2") +} diff --git a/internals/export/worker.go b/internals/export/worker.go new file mode 100644 index 00000000..eb051d2d --- /dev/null +++ b/internals/export/worker.go @@ -0,0 +1,231 @@ +package export + +import ( + "compress/gzip" + "context" + "encoding/csv" + "fmt" + "github.com/myrteametrics/myrtea-engine-api/v5/internals/notifier" + "go.uber.org/zap" + "os" + "path/filepath" + "sync" +) + +type ExportWorker struct { + Mutex sync.Mutex + Id int + Success chan<- int + Cancel chan bool // channel to cancel the worker + BasePath string // base path where the file will be saved + // critical fields + Available bool + QueueItem WrapperItem +} + +func NewExportWorker(id int, basePath string, success chan<- int) *ExportWorker { + return &ExportWorker{ + Id: id, + Available: true, + BasePath: basePath, + Cancel: make(chan bool, 3), // buffered channel to avoid blocking + Success: success, + } +} + +// SetError sets the error and the status of the worker +func (e *ExportWorker) SetError(error error) { + e.Mutex.Lock() + defer e.Mutex.Unlock() + e.QueueItem.Status = StatusError + if error == nil { + e.QueueItem.Error = "" + } else { + e.QueueItem.Error = error.Error() + } +} + +// SetStatus sets the status of the worker +func (e *ExportWorker) SetStatus(status int) { + e.Mutex.Lock() + defer e.Mutex.Unlock() + e.QueueItem.Status = status +} + +// SwapAvailable swaps the availability of the worker +func (e *ExportWorker) SwapAvailable(available bool) (old bool) { + e.Mutex.Lock() + defer e.Mutex.Unlock() + old = e.Available + e.Available = available + return old +} + +// IsAvailable returns the availability of the worker +func (e *ExportWorker) IsAvailable() bool { + e.Mutex.Lock() + defer e.Mutex.Unlock() + return e.Available +} + +// DrainCancelChannel drains the cancel channel +func (e *ExportWorker) DrainCancelChannel() { + for { + select { + case <-e.Cancel: + default: + return + } + } +} + +// finalise sets the worker availability to true and clears the queueItem +func (e *ExportWorker) finalise() { + e.Mutex.Lock() + + // set status to error if error occurred + if e.QueueItem.Error != "" { + e.QueueItem.Status = StatusError + } + // set status to done if no error occurred + if e.QueueItem.Status != StatusError && e.QueueItem.Status != StatusCanceled { + e.QueueItem.Status = StatusDone + } + e.Mutex.Unlock() + + // clear Cancel channel, to avoid blocking + e.DrainCancelChannel() + + // notify to the dispatcher that this worker is now available + e.Success <- e.Id +} + +// Start starts the export task +// It handles one queueItem at a time and when finished it stops the goroutine +func (e *ExportWorker) Start(item WrapperItem, ctx context.Context) { + defer e.finalise() + item.Status = StatusRunning + + e.Mutex.Lock() + e.QueueItem = item + e.Mutex.Unlock() + + // send notification to user (non-blocking) + go func(wrapperItem WrapperItem) { + _ = notifier.C().SendToUserLogins( + createExportNotification(ExportNotificationStarted, &item), + wrapperItem.Users) + }(item) + + // create file + path := filepath.Join(e.BasePath, item.FileName) + // check if file not already exists + if _, err := os.Stat(path); err == nil { + e.SetError(fmt.Errorf("file with same name already exists")) + return + } + + file, err := os.Create(path) + if err != nil { + e.SetError(err) + return + } + defer file.Close() + + // opens a gzip writer + gzipWriter := gzip.NewWriter(file) + defer gzipWriter.Close() + + csvWriter := csv.NewWriter(gzipWriter) + streamedExport := NewStreamedExport() + var wg sync.WaitGroup + var writerErr error + + // local context handling + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // Increment the WaitGroup counter + wg.Add(1) + + /** + * How streamed export works: + * - Export goroutine: each fact is processed one by one + * Each bulk of data is sent through a channel to the receiver + * - The receiver handles the incoming channel data and converts them to the CSV format + * After the conversion, the data is written and gzipped to a local file + */ + + go func() { + defer wg.Done() + defer close(streamedExport.Data) + + for _, f := range item.Facts { + writerErr = streamedExport.StreamedExportFactHitsFull(ctx, f, item.Params.Limit) + if writerErr != nil { + break // break here when error occurs? + } + } + }() + + // Chunk handler + first := true + +loop: + for { + select { + case hits, ok := <-streamedExport.Data: + if !ok { // channel closed + break loop + } + + err = WriteConvertHitsToCSV(csvWriter, hits, item.Params, first) + + if err != nil { + zap.L().Error("WriteConvertHitsToCSV error during export", zap.Error(err)) + cancel() + break loop + } + + // Flush data + csvWriter.Flush() + + if first { + first = false + } + case <-ctx.Done(): + break loop + case <-e.Cancel: + cancel() + break loop + } + } + + wg.Wait() + + // error occurred, close file and delete + if writerErr != nil || err != nil { + if ctx.Err() != nil { + e.SetStatus(StatusCanceled) + zap.L().Warn("Export worker: canceled, deleting file...", zap.String("filePath", path)) + } else { + if err != nil { // priority to err + e.SetError(err) + } else { + e.SetError(writerErr) + } + zap.L().Error("Export worker: error, deleting file...", zap.String("filePath", path), + zap.NamedError("err", err), zap.NamedError("writerErr", writerErr)) + } + + // close writer and file access before trying to delete file + _ = gzipWriter.Close() + _ = file.Close() + + err = os.Remove(path) + if err != nil { + zap.L().Error("Export worker: couldn't delete file", zap.String("filePath", path), zap.Error(err)) + } + } + +} diff --git a/internals/export/worker_test.go b/internals/export/worker_test.go new file mode 100644 index 00000000..fd1b50b4 --- /dev/null +++ b/internals/export/worker_test.go @@ -0,0 +1,48 @@ +package export + +import ( + "github.com/myrteametrics/myrtea-sdk/v4/expression" + "testing" +) + +func TestNewExportWorker(t *testing.T) { + worker := NewExportWorker(0, "/tmp", make(chan<- int)) + expression.AssertEqual(t, worker.BasePath, "/tmp") + expression.AssertEqual(t, worker.Available, true) + expression.AssertEqual(t, worker.Id, 0) +} + +func TestExportWorker_SetError(t *testing.T) { + worker := NewExportWorker(0, "/tmp", make(chan<- int)) + worker.SetError(nil) + expression.AssertEqual(t, worker.QueueItem.Status, StatusError) + expression.AssertEqual(t, worker.QueueItem.Error, "") +} + +func TestExportWorker_SetStatus(t *testing.T) { + worker := NewExportWorker(0, "/tmp", make(chan<- int)) + worker.SetStatus(StatusPending) + expression.AssertEqual(t, worker.QueueItem.Status, StatusPending) +} + +func TestExportWorker_SwapAvailable(t *testing.T) { + worker := NewExportWorker(0, "/tmp", make(chan<- int)) + expression.AssertEqual(t, worker.SwapAvailable(false), true) + expression.AssertEqual(t, worker.Available, false) + expression.AssertEqual(t, worker.SwapAvailable(true), false) + expression.AssertEqual(t, worker.Available, true) +} + +func TestExportWorker_IsAvailable(t *testing.T) { + worker := NewExportWorker(0, "/tmp", make(chan<- int)) + expression.AssertEqual(t, worker.IsAvailable(), true) + worker.SwapAvailable(false) + expression.AssertEqual(t, worker.IsAvailable(), false) +} + +func TestExportWorker_DrainCancelChannel(t *testing.T) { + worker := NewExportWorker(0, "/tmp", make(chan<- int)) + worker.Cancel <- true + worker.DrainCancelChannel() + expression.AssertEqual(t, len(worker.Cancel), 0) +} diff --git a/internals/export/wrapper.go b/internals/export/wrapper.go new file mode 100644 index 00000000..d5d26dcd --- /dev/null +++ b/internals/export/wrapper.go @@ -0,0 +1,550 @@ +package export + +import ( + "context" + "github.com/google/uuid" + "github.com/myrteametrics/myrtea-engine-api/v5/internals/notifier" + "github.com/myrteametrics/myrtea-engine-api/v5/internals/security" + "github.com/myrteametrics/myrtea-engine-api/v5/internals/security/users" + "github.com/myrteametrics/myrtea-sdk/v4/engine" + "go.uber.org/zap" + "os" + "path/filepath" + "strings" + "sync" + "time" +) + +const ( + CodeUserAdded = 1 + CodeAdded = 0 + CodeUserExists = -1 + CodeQueueFull = -2 + + // WrapperItem statuses + StatusPending = 0 + StatusRunning = 1 + StatusDone = 2 + StatusError = 3 + StatusCanceled = 4 + StatusCanceling = 5 + + // Delete return codes + DeleteExportNotFound = 0 + DeleteExportDeleted = 1 + DeleteExportUserDeleted = 2 + DeleteExportCanceled = 3 + + randCharSet = "abcdefghijklmnopqrstuvwxyz0123456789" +) + +// WrapperItem represents an export demand +type WrapperItem struct { + Id string `json:"id"` // unique id that represents an export demand + FactIDs []int64 `json:"factIds"` // list of fact ids that are part of the export (for archive and json) + Facts []engine.Fact `json:"-"` + Error string `json:"error"` + Status int `json:"status"` + FileName string `json:"fileName"` + Title string `json:"title"` + Date time.Time `json:"date"` + Users []string `json:"-"` + Params CSVParameters `json:"-"` +} + +type Wrapper struct { + // Queue handling + queueMutex sync.RWMutex + queue []*WrapperItem // stores queue to handle duplicates, state + + // contains also current handled items + // workers is final, its only instanced once and thus does not change size (ExportWorker have there indexes in this slice stored) + workers []*ExportWorker + + // success is passed to all workers, they write on this channel when they've finished with there export + success chan int + + // Archived WrapperItem's + archive sync.Map // map of all exports that have been done, key is the id of the export + + // Non-critical fields + // Read-only parameters + diskRetentionDays int + BasePath string // public for export_handlers + queueMaxSize int + workerCount int +} + +// NewWrapperItem creates a new export wrapper item +func NewWrapperItem(facts []engine.Fact, title string, params CSVParameters, user users.User) *WrapperItem { + var factIDs []int64 + for _, fact := range facts { + factIDs = append(factIDs, fact.ID) + } + + // file extension should be gz + // add random string to avoid multiple files with same name + fileName := security.RandStringWithCharset(5, randCharSet) + "_" + + strings.ReplaceAll(title, " ", "_") + ".csv.gz" + + return &WrapperItem{ + Users: append([]string{}, user.Login), + Id: uuid.New().String(), + Facts: facts, + FactIDs: factIDs, + Date: time.Now(), + Status: StatusPending, + Error: "", + FileName: fileName, + Title: title, + Params: params, + } +} + +// NewWrapper creates a new export wrapper +func NewWrapper(basePath string, workersCount, diskRetentionDays, queueMaxSize int) *Wrapper { + wrapper := &Wrapper{ + workers: make([]*ExportWorker, 0), + queue: make([]*WrapperItem, 0), + success: make(chan int), + archive: sync.Map{}, + queueMaxSize: queueMaxSize, + BasePath: basePath, + diskRetentionDays: diskRetentionDays, + workerCount: workersCount, + } + + return wrapper +} + +// ContainsFact checks if fact is part of the WrapperItem data +func (it *WrapperItem) ContainsFact(factID int64) bool { + for _, d := range it.FactIDs { + if d == factID { + return true + } + } + return false +} + +// Init initializes the export wrapper +func (ew *Wrapper) Init(ctx context.Context) { + // instantiate workers + for i := 0; i < ew.workerCount; i++ { + ew.workers = append(ew.workers, NewExportWorker(i, ew.BasePath, ew.success)) + } + + // check if destination folder exists + _, err := os.Stat(ew.BasePath) + if err != nil { + + if os.IsNotExist(err) { + zap.L().Info("The export directory not exists, trying to create...", zap.String("EXPORT_BASE_PATH", ew.BasePath)) + + if err := os.MkdirAll(ew.BasePath, os.ModePerm); err != nil { + zap.L().Fatal("Couldn't create export directory", zap.String("EXPORT_BASE_PATH", ew.BasePath), zap.Error(err)) + } else { + zap.L().Info("The export directory has been successfully created.") + } + + } else { + zap.L().Fatal("Couldn't access to export directory", zap.String("EXPORT_BASE_PATH", ew.BasePath), zap.Error(err)) + } + + } + + go ew.startDispatcher(ctx) +} + +// factsEquals checks if two slices of facts are equal +func factsEquals(a, b []engine.Fact) bool { + if len(a) != len(b) { + return false + } + for _, fact := range a { + found := false + for _, fact2 := range b { + if fact.ID == fact2.ID { + found = true + break + } + } + if !found { + return false + } + } + return true +} + +// AddToQueue Adds a new export to the export worker queue +func (ew *Wrapper) AddToQueue(facts []engine.Fact, title string, params CSVParameters, user users.User) (*WrapperItem, int) { + ew.queueMutex.Lock() + defer ew.queueMutex.Unlock() + + for _, queueItem := range ew.queue { + if !factsEquals(queueItem.Facts, facts) || !queueItem.Params.Equals(params) || queueItem.Title != title { + continue + } + + // check if user not already in queue.users + for _, u := range queueItem.Users { + if u == user.Login { + return nil, CodeUserExists + } + } + + queueItem.Users = append(queueItem.Users, user.Login) + return nil, CodeUserAdded + } + + if len(ew.queue) >= ew.queueMaxSize { + return nil, CodeQueueFull + } + + item := NewWrapperItem(facts, title, params, user) + ew.queue = append(ew.queue, item) + return item, CodeAdded +} + +// startDispatcher starts the export tasks dispatcher & the expired files checker +func (ew *Wrapper) startDispatcher(context context.Context) { + zap.L().Info("Starting export tasks dispatcher") + // every 5 seconds check if there is a new task to process in queue then check if there is an available worker + // if yes, start the worker with the task + // if no, continue to check + ticker := time.NewTicker(5 * time.Second) + expiredFileTicker := time.NewTicker(24 * time.Hour) + defer ticker.Stop() + defer expiredFileTicker.Stop() + + for { + select { + case w := <-ew.success: + worker := ew.workers[w] + + // archive item when finished + worker.Mutex.Lock() + ew.workers[w].Available = true + item := worker.QueueItem + worker.QueueItem = WrapperItem{} + worker.Mutex.Unlock() + + // archive item + item.Facts = []engine.Fact{} // empty facts to avoid storing them in the archive + ew.archive.Store(item.Id, item) + + // send notification to user (non-blocking) + go func(wrapperItem WrapperItem) { + _ = notifier.C().SendToUserLogins( + createExportNotification(ExportNotificationArchived, &wrapperItem), + wrapperItem.Users) + }(item) + case <-ticker.C: + ew.dispatchExportQueue(context) + case <-expiredFileTicker.C: + err := ew.checkForExpiredFiles() + + if err != nil { + zap.L().Error("Error during expired files check", zap.Error(err)) + } + case <-context.Done(): + return + } + } +} + +// checkForExpiredFiles checks for expired files in the export directory and deletes them +// it also deletes the done tasks that are older than diskRetentionDays +func (ew *Wrapper) checkForExpiredFiles() error { + // Get all files in directory and check the last edit date + // if last edit date is older than diskRetentionDays, delete the file + zap.L().Info("Checking for expired files") + files, err := os.ReadDir(ew.BasePath) + if err != nil { + return err + } + + // delete all done archives of ew.archive that are older than diskRetentionDays + ew.archive.Range(func(key, value any) bool { + data, ok := value.(WrapperItem) + if !ok { + return true + } + if time.Since(data.Date).Hours() > float64(ew.diskRetentionDays*24) { + ew.archive.Delete(key) + + // send notification to user (non-blocking) + go func(wrapperItem WrapperItem) { + _ = notifier.C().SendToUserLogins( + createExportNotification(ExportNotificationDeleted, &wrapperItem), + wrapperItem.Users) + }(data) + + } + return true + }) + + // count the number of deleted files + count := 0 + + for _, file := range files { + if file.IsDir() { + continue + } + + filePath := filepath.Join(ew.BasePath, file.Name()) + + fi, err := os.Stat(filePath) + if err != nil { + zap.L().Error("Cannot get file info", zap.String("file", filePath), zap.Error(err)) + continue + } + + // skip if file is not a zip + //if filepath.Ext(file.Name()) != ".zip" { + // continue + //} + + if time.Since(fi.ModTime()).Hours() > float64(ew.diskRetentionDays*24) { + err = os.Remove(filePath) + if err != nil { + zap.L().Error("Cannot delete file", zap.String("file", filePath), zap.Error(err)) + continue + } + count++ + } + } + + zap.L().Info("Deleted expired files", zap.Int("count", count)) + return nil +} + +func (ew *Wrapper) GetUserExports(user users.User) []WrapperItem { + result := make([]WrapperItem, 0) + + // first, gather all exports that are in the workers if there are any + for _, worker := range ew.workers { + worker.Mutex.Lock() + if worker.QueueItem.ContainsUser(user) { + result = append(result, worker.QueueItem) + } + worker.Mutex.Unlock() + } + + // then, gather all exports that are archived + ew.archive.Range(func(key, value any) bool { + data, ok := value.(WrapperItem) + if !ok { + return true + } + if data.ContainsUser(user) { + result = append(result, data) + } + return true + }) + + // finally, gather all exports that are in the queue + ew.queueMutex.Lock() + defer ew.queueMutex.Unlock() + + for _, item := range ew.queue { + if item.ContainsUser(user) { + result = append(result, *item) + } + } + + return result +} + +// dequeueWrapperItem Dequeues an item, returns size of queue and true if item was found and dequeued +func (ew *Wrapper) dequeueWrapperItem(item *WrapperItem) (int, bool) { + ew.queueMutex.Lock() + defer ew.queueMutex.Unlock() + + for i, queueItem := range ew.queue { + if queueItem.Id != item.Id { + continue + } + + ew.queue = append(ew.queue[:i], ew.queue[i+1:]...) + return len(ew.queue), true + } + + return len(ew.queue), false +} + +// dispatchExportQueue dispatches the export queue to the available workers +func (ew *Wrapper) dispatchExportQueue(ctx context.Context) { + for _, worker := range ew.workers { + worker.Mutex.Lock() + if !worker.Available { + worker.Mutex.Unlock() + continue + } + // check if there is an item in the queue + ew.queueMutex.Lock() + + if len(ew.queue) == 0 { + ew.queueMutex.Unlock() + worker.Mutex.Unlock() + return // Nothing in queue + } + + item := *ew.queue[0] + ew.queue = append(ew.queue[:0], ew.queue[1:]...) + ew.queueMutex.Unlock() + + worker.Available = false + worker.Mutex.Unlock() + + go worker.Start(item, ctx) + + } +} + +// FindArchive returns the archive item for the given id and user +func (ew *Wrapper) FindArchive(id string, user users.User) (WrapperItem, bool) { + item, found := ew.archive.Load(id) + if found { + if data, ok := item.(WrapperItem); ok && data.ContainsUser(user) { + return data, true + } + } + return WrapperItem{}, false +} + +// GetUserExport returns the export item for the given id and user +// this function is similar to GetUserExports, but it avoids iterating over all exports, thus it is faster +func (ew *Wrapper) GetUserExport(id string, user users.User) (item WrapperItem, ok bool) { + // start with archived items + if item, ok = ew.FindArchive(id, user); ok { + return item, ok + } + + // then check the workers + for _, worker := range ew.workers { + worker.Mutex.Lock() + if worker.QueueItem.Id == id && worker.QueueItem.ContainsUser(user) { + item = worker.QueueItem + ok = true + } + worker.Mutex.Unlock() + if ok { + return item, ok + } + } + + // finally check the queue + ew.queueMutex.Lock() + defer ew.queueMutex.Unlock() + + for _, it := range ew.queue { + ok = it.ContainsUser(user) + if ok { + item = *it + break + } + } + + return item, ok +} + +// DeleteExport removes an export from the queue / archive, or cancels it if it is running +// returns : +// DeleteExportNotFound (0): if the export was not found +// DeleteExportDeleted (1): if the export was found and deleted +// DeleteExportUserDeleted (2): if the export was found and the user was removed +// DeleteExportCanceled (3): if the export was found and the cancellation request was made +// this function is similar to GetUserExport, but it avoids iterating over all exports, thus it is faster +func (ew *Wrapper) DeleteExport(id string, user users.User) int { + // start with archived items + if item, ok := ew.FindArchive(id, user); ok { + if len(item.Users) == 1 { + ew.archive.Delete(id) + return DeleteExportDeleted + } + // remove user from item + for i, u := range item.Users { + if u == user.Login { + item.Users = append(item.Users[:i], item.Users[i+1:]...) + break + } + } + ew.archive.Store(id, item) + return DeleteExportUserDeleted + } + + // then check the queue + ew.queueMutex.Lock() + for i, item := range ew.queue { + if item.Id == id && item.ContainsUser(user) { + // remove user from item + for j, u := range item.Users { + if u == user.Login { + item.Users = append(item.Users[:j], item.Users[j+1:]...) + break + } + } + if len(item.Users) == 0 { + ew.queue = append(ew.queue[:i], ew.queue[i+1:]...) + ew.queueMutex.Unlock() + return DeleteExportDeleted + } + + ew.queueMutex.Unlock() + return DeleteExportUserDeleted + } + } + ew.queueMutex.Unlock() + + // finally check the workers + for _, worker := range ew.workers { + worker.Mutex.Lock() + if worker.Available || worker.QueueItem.Id != id || !worker.QueueItem.ContainsUser(user) { + worker.Mutex.Unlock() + continue + } + + // worker found but already canceling + if worker.QueueItem.Status == StatusCanceling { + worker.Mutex.Unlock() + return DeleteExportNotFound + } + + // remove user from item + if len(worker.QueueItem.Users) == 1 { + // cancel worker by sending a message on the cancel channel + // the worker will check this channel and stop if it receives a message + // it can happen that the worker is already stopped, in this case, the message will be ignored + select { // non-blocking send + case worker.Cancel <- true: + default: + } + worker.QueueItem.Status = StatusCanceling + worker.Mutex.Unlock() + return DeleteExportCanceled + } + + for i, u := range worker.QueueItem.Users { + if u == user.Login { + worker.QueueItem.Users = append(worker.QueueItem.Users[:i], worker.QueueItem.Users[i+1:]...) + worker.Mutex.Unlock() + return DeleteExportUserDeleted + } + } + worker.Mutex.Unlock() + return DeleteExportNotFound + } + + return DeleteExportNotFound +} + +// ContainsUser checks if user is in item +func (it *WrapperItem) ContainsUser(user users.User) bool { + for _, u := range it.Users { + if u == user.Login { + return true + } + } + return false +} diff --git a/internals/export/wrapper_test.go b/internals/export/wrapper_test.go new file mode 100644 index 00000000..2b0f595b --- /dev/null +++ b/internals/export/wrapper_test.go @@ -0,0 +1,456 @@ +package export + +import ( + "context" + "fmt" + "github.com/google/uuid" + "github.com/myrteametrics/myrtea-engine-api/v5/internals/notifier" + "github.com/myrteametrics/myrtea-engine-api/v5/internals/security/users" + "github.com/myrteametrics/myrtea-sdk/v4/engine" + "github.com/myrteametrics/myrtea-sdk/v4/expression" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestNewWrapper(t *testing.T) { + wrapper := NewWrapper("/tmp", 1, 1, 1) + expression.AssertEqual(t, wrapper.BasePath, "/tmp") + expression.AssertEqual(t, wrapper.queueMaxSize, 1) + expression.AssertEqual(t, wrapper.diskRetentionDays, 1) + expression.AssertEqual(t, wrapper.queueMaxSize, 1) +} + +func TestFactsEquals(t *testing.T) { + expression.AssertEqual(t, factsEquals([]engine.Fact{{ID: 1}}, []engine.Fact{{ID: 1}}), true) + expression.AssertEqual(t, factsEquals([]engine.Fact{{ID: 1}}, []engine.Fact{{ID: 2}}), false) + expression.AssertEqual(t, factsEquals([]engine.Fact{{ID: 1}, {ID: 2}}, []engine.Fact{{ID: 2}, {ID: 1}}), true) + expression.AssertEqual(t, factsEquals([]engine.Fact{{ID: 1}, {ID: 2}}, []engine.Fact{{ID: 1}, {ID: 3}}), false) + expression.AssertEqual(t, factsEquals([]engine.Fact{{ID: 1}, {ID: 2}}, []engine.Fact{{ID: 1}, {ID: 2}, {ID: 3}}), false) + expression.AssertEqual(t, factsEquals([]engine.Fact{{ID: 2}, {ID: 1}, {ID: 3}}, []engine.Fact{{ID: 1}, {ID: 2}}), false) +} + +func TestNewWrapperItem(t *testing.T) { + item := NewWrapperItem([]engine.Fact{{ID: 1}}, "test", CSVParameters{}, users.User{Login: "test"}) + expression.AssertNotEqual(t, item.Id, "") + expression.AssertEqual(t, factsEquals(item.Facts, []engine.Fact{{ID: 1}}), true) + expression.AssertEqual(t, item.Params.Equals(CSVParameters{}), true) + expression.AssertEqual(t, item.Status, StatusPending) + expression.AssertEqual(t, strings.HasSuffix(item.FileName, "test.csv.gz"), true, "test.txt.gz") + expression.AssertNotEqual(t, len(item.Users), 0) + expression.AssertEqual(t, item.Users[0], "test") +} + +func TestWrapperItem_ContainsFact(t *testing.T) { + item := NewWrapperItem([]engine.Fact{{ID: 1}, {ID: 22}, {ID: 33}}, "test.txt", CSVParameters{}, users.User{Login: "test"}) + expression.AssertEqual(t, item.ContainsFact(1), true) + expression.AssertEqual(t, item.ContainsFact(22), true) + expression.AssertEqual(t, item.ContainsFact(3), false) +} + +func TestWrapper_Init(t *testing.T) { + wrapper := NewWrapper("/tmp", 1, 1, 1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + wrapper.Init(ctx) + time.Sleep(500 * time.Millisecond) + expression.AssertEqual(t, len(wrapper.workers), 1) + worker := wrapper.workers[0] + expression.AssertEqual(t, worker.Id, 0) + worker.Mutex.Lock() + defer worker.Mutex.Unlock() + expression.AssertEqual(t, worker.Available, true) +} + +func TestAddToQueue(t *testing.T) { + wrapper := NewWrapper("/tmp", 1, 1, 1) + user1 := users.User{Login: "bla"} + user2 := users.User{Login: "blabla"} + csvParams := CSVParameters{} + _, result := wrapper.AddToQueue([]engine.Fact{{ID: 1}}, "test.txt", csvParams, user1) + expression.AssertEqual(t, result, CodeAdded, "AddToQueue should return CodeAdded") + _, result = wrapper.AddToQueue([]engine.Fact{{ID: 1}}, "test.txt", csvParams, user1) + expression.AssertEqual(t, result, CodeUserExists, "AddToQueue should return CodeUserExists") + _, result = wrapper.AddToQueue([]engine.Fact{{ID: 1}}, "test.txt", csvParams, user2) + expression.AssertEqual(t, result, CodeUserAdded, "AddToQueue should return CodeUserAdded") + _, result = wrapper.AddToQueue([]engine.Fact{{ID: 2}}, "test.txt", csvParams, user2) + expression.AssertEqual(t, result, CodeQueueFull, "AddToQueue should return CodeQueueFull") +} + +func TestStartDispatcher(t *testing.T) { + // we don't want that the worker try to export data, therefore we will create a temporary directory with a temp file + // so that the worker will not be able to create the file and will return an error + dname, err := os.MkdirTemp("", "exportdispatcher") + if err != nil { + t.Error(err) + t.FailNow() + } + defer os.RemoveAll(dname) + + // create a file that is 2 days old + file, err := os.CreateTemp(dname, "exportdispatcher") + if err != nil { + t.Error(err) + t.FailNow() + } + fileName := filepath.Base(file.Name()) + _ = file.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + notifier.ReplaceGlobals(notifier.NewNotifier()) // for notifications + wrapper := NewWrapper(dname, 1, 1, 1) + wrapper.Init(ctx) + expression.AssertEqual(t, len(wrapper.workers), 1) + // sleep one second to let the goroutine start + fmt.Println("Sleeping 1 second to let the goroutine start") + time.Sleep(1 * time.Second) + + worker := wrapper.workers[0] + + // check if the worker is available + worker.Mutex.Lock() + expression.AssertEqual(t, worker.Available, true) + worker.Mutex.Unlock() + + // add a task to the queue and check if the task was added to queue + user := users.User{Login: "test"} + _, result := wrapper.AddToQueue([]engine.Fact{{ID: 1}}, fileName, CSVParameters{}, user) + expression.AssertEqual(t, result, CodeAdded, "AddToQueue should return CodeAdded") + wrapper.queueMutex.Lock() + expression.AssertEqual(t, len(wrapper.queue), 1) + itemId := wrapper.queue[0].Id + wrapper.queueMutex.Unlock() + + // sleep another 5 seconds to let the goroutine handle the task + fmt.Println("Sleeping 5 seconds to let the goroutine handle the task") + time.Sleep(5 * time.Second) + + wrapper.queueMutex.Lock() + expression.AssertEqual(t, len(wrapper.queue), 0) + wrapper.queueMutex.Unlock() + + worker.Mutex.Lock() + expression.AssertEqual(t, worker.Available, true) + worker.Mutex.Unlock() + + time.Sleep(50 * time.Millisecond) + + item, ok := wrapper.FindArchive(itemId, user) + expression.AssertEqual(t, ok, true) + expression.AssertEqual(t, item.Status, StatusError) // could not create file +} + +func TestCheckForExpiredFiles(t *testing.T) { + // first test : check if files are deleted + dname, err := os.MkdirTemp("", "export") + if err != nil { + t.Error(err) + t.FailNow() + } + defer os.RemoveAll(dname) + + // create a file that is 2 days old + file, err := os.CreateTemp(dname, "export") + if err != nil { + t.Error(err) + t.FailNow() + } + file1Name := file.Name() + _ = file.Close() + err = os.Chtimes(file1Name, time.Now().AddDate(0, 0, -2), time.Now().AddDate(0, 0, -2)) + if err != nil { + t.Error(err) + t.FailNow() + } + + // create a freshly created file + file2, err := os.CreateTemp(dname, "export") + if err != nil { + t.Error(err) + t.FailNow() + } + file2Name := file2.Name() + _ = file2.Close() + + wrapper := NewWrapper(dname, 1, 1, 1) + err = wrapper.checkForExpiredFiles() + if err != nil { + t.Error(err) + t.FailNow() + } + + // check that the file has been deleted + _, err = os.Stat(file1Name) + if !os.IsNotExist(err) { + t.Error("File1 should have been deleted") + t.FailNow() + } + + _, err = os.Stat(file2Name) + if os.IsNotExist(err) { + t.Error("File2 should not have been deleted") + t.FailNow() + } + + // second test : check if expired exports are deleted + goodDate := time.Now() + id1 := uuid.New() + id2 := uuid.New() + wrapper.archive.Store(id1, WrapperItem{Date: time.Now().AddDate(0, 0, -2)}) + wrapper.archive.Store(id2, WrapperItem{Date: goodDate}) + + _, found := wrapper.archive.Load(id1) + expression.AssertEqual(t, found, true) + _, found = wrapper.archive.Load(id2) + expression.AssertEqual(t, found, true) + + err = wrapper.checkForExpiredFiles() + if err != nil { + t.Error(err) + t.FailNow() + } + + _, found = wrapper.archive.Load(id1) + expression.AssertEqual(t, found, false) + _, found = wrapper.archive.Load(id2) + expression.AssertEqual(t, found, true) +} + +func TestWrapper_GetUserExports(t *testing.T) { + wrapper := NewWrapper("/tmp", 1, 1, 2) + user1 := users.User{Login: "bla"} + user2 := users.User{Login: "blabla"} + item1 := NewWrapperItem([]engine.Fact{{ID: 1}}, "test.txt", CSVParameters{}, user1) + item2 := NewWrapperItem([]engine.Fact{{ID: 2}}, "test.txt", CSVParameters{}, user1) + item3 := NewWrapperItem([]engine.Fact{{ID: 3}}, "test.txt", CSVParameters{}, user1) + item4 := NewWrapperItem([]engine.Fact{{ID: 4}}, "test.txt", CSVParameters{}, user2) + wrapper.archive.Store(item1.Id, *item1) + wrapper.archive.Store(item2.Id, *item2) + wrapper.archive.Store(item3.Id, *item3) + wrapper.archive.Store(item4.Id, *item4) + wrapper.AddToQueue([]engine.Fact{{ID: 5}}, "test.txt", CSVParameters{}, user1) + wrapper.AddToQueue([]engine.Fact{{ID: 6}}, "test.txt", CSVParameters{}, user2) + exports := wrapper.GetUserExports(user1) + expression.AssertEqual(t, len(exports), 4) + exports = wrapper.GetUserExports(user2) + expression.AssertEqual(t, len(exports), 2) +} + +func TestWrapper_DequeueWrapperItem(t *testing.T) { + wrapper := NewWrapper("/tmp", 1, 1, 2) + i, ok := wrapper.dequeueWrapperItem(&WrapperItem{}) + expression.AssertEqual(t, ok, false) + expression.AssertEqual(t, i, 0) + wrapper.AddToQueue([]engine.Fact{{ID: 5}}, "test.txt", CSVParameters{}, users.User{Login: "bla"}) + wrapper.AddToQueue([]engine.Fact{{ID: 6}}, "test.txt", CSVParameters{}, users.User{Login: "blabla"}) + + expression.AssertEqual(t, len(wrapper.queue), 2) + item1 := wrapper.queue[0] + item2 := wrapper.queue[1] + + i, ok = wrapper.dequeueWrapperItem(item1) + expression.AssertEqual(t, ok, true) + expression.AssertEqual(t, i, 1) + + i, ok = wrapper.dequeueWrapperItem(item2) + expression.AssertEqual(t, ok, true) + expression.AssertEqual(t, i, 0) +} + +func TestWrapper_dispatchExportQueue(t *testing.T) { + // we don't want that the worker try to export data, therefore we will create a temporary directory with a temp file + // so that the worker will not be able to create the file and will return an error + dname, err := os.MkdirTemp("", "exportdispatcher") + if err != nil { + t.Error(err) + t.FailNow() + } + defer os.RemoveAll(dname) + + // create a file that is 2 days old + file, err := os.CreateTemp(dname, "exportdispatcher") + if err != nil { + t.Error(err) + t.FailNow() + } + fileName := filepath.Base(file.Name()) + _ = file.Close() + + notifier.ReplaceGlobals(notifier.NewNotifier()) // for notifications + wrapper := NewWrapper(dname, 1, 1, 2) + ctx, cancel := context.WithCancel(context.Background()) + wrapper.Init(ctx) + cancel() // stop dispatcher since we don't want him to interact with the workers or the queue + + // wait until dispatcher stops + time.Sleep(50 * time.Millisecond) + + expression.AssertEqual(t, len(wrapper.workers), 1) + worker := wrapper.workers[0] + + // no items in queue -> nothing should happen + expression.AssertEqual(t, worker.IsAvailable(), true) + wrapper.dispatchExportQueue(context.Background()) + expression.AssertEqual(t, worker.IsAvailable(), true, "worker should still be available, because no items in queue") + + // we add an item to the queue + wrapper.AddToQueue([]engine.Fact{{ID: 1}}, fileName, CSVParameters{}, users.User{Login: "test"}) + + // we test if dispatchExportQueue will not dispatch the item, no worker available + worker.SwapAvailable(false) + + wrapper.dispatchExportQueue(context.Background()) + + // the item should still be in the queue + wrapper.queueMutex.Lock() + expression.AssertEqual(t, len(wrapper.queue), 1, "item should still be in the queue, since no worker is available") + wrapper.queueMutex.Unlock() + + // we test if dispatchExportQueue will dispatch the item, worker is now set to available + expression.AssertEqual(t, worker.SwapAvailable(true), false) + + wrapper.dispatchExportQueue(context.Background()) + + expression.AssertEqual(t, worker.IsAvailable(), false, "worker should not be available, because it is working on an item") + expression.AssertEqual(t, len(wrapper.queue), 0) + + // wait until worker has finished + time.Sleep(1 * time.Second) + + worker.Mutex.Lock() + defer worker.Mutex.Unlock() + + expression.AssertEqual(t, worker.QueueItem.Status, StatusError, fmt.Sprintf("worker processed item should have StatusError(%d) because the file already exists", StatusError)) // could not create file +} + +func TestWrapper_FindArchive(t *testing.T) { + wrapper := NewWrapper("/tmp", 1, 1, 2) + item := NewWrapperItem([]engine.Fact{{ID: 1}}, "test.txt", CSVParameters{}, users.User{Login: "bla"}) + wrapper.archive.Store(item.Id, *item) + + // testing with non-existing item in archive + _, ok := wrapper.FindArchive("test", users.User{Login: "bla"}) + expression.AssertEqual(t, ok, false) + + // testing with existing item but not good user in archive + _, ok = wrapper.FindArchive("test", users.User{Login: "blabla"}) + expression.AssertEqual(t, ok, false) + + // testing with existing item in archive + _, ok = wrapper.FindArchive(item.Id, users.User{Login: "bla"}) + expression.AssertEqual(t, ok, true) +} + +func TestWrapper_ContainsUser(t *testing.T) { + item := NewWrapperItem([]engine.Fact{{ID: 1}}, "test.txt", CSVParameters{}, users.User{Login: "bla"}) + expression.AssertEqual(t, item.ContainsUser(users.User{Login: "bla"}), true) + expression.AssertEqual(t, item.ContainsUser(users.User{Login: "blabla"}), false) +} + +func TestWrapper_DeleteExport(t *testing.T) { + wrapper := NewWrapper("/tmp", 1, 1, 2) + item := NewWrapperItem([]engine.Fact{{ID: 1}}, "test.txt", CSVParameters{}, users.User{Login: "bla"}) + + // test archive + wrapper.archive.Store(item.Id, *item) + expression.AssertEqual(t, wrapper.DeleteExport(item.Id, users.User{Login: "bla"}), DeleteExportDeleted, "item should have been deleted") + _, ok := wrapper.archive.Load(item.Id) + expression.AssertEqual(t, ok, false, "item should not be in archive anymore") + + // test archive multi-user + item.Users = []string{"bla", "blabla"} + wrapper.archive.Store(item.Id, *item) + expression.AssertEqual(t, wrapper.DeleteExport(item.Id, users.User{Login: "bla"}), DeleteExportUserDeleted, "user should have been deleted from existing export") + _, ok = wrapper.archive.Load(item.Id) + expression.AssertEqual(t, ok, true, "item should be in archive") + item.Users = []string{"bla"} + + // test queue + queueItem, code := wrapper.AddToQueue([]engine.Fact{{ID: 1}}, "test.txt", CSVParameters{}, users.User{Login: "bla"}) + expression.AssertEqual(t, code, CodeAdded, "item should have been added to queue") + wrapper.queueMutex.Lock() + expression.AssertEqual(t, len(wrapper.queue), 1, "item should be in queue") + wrapper.queueMutex.Unlock() + expression.AssertEqual(t, wrapper.DeleteExport(queueItem.Id, users.User{Login: "bla"}), DeleteExportDeleted, "item should have been deleted") + wrapper.queueMutex.Lock() + expression.AssertEqual(t, len(wrapper.queue), 0, "item should not be in queue anymore") + wrapper.queueMutex.Unlock() + + // test queue multi-user + queueItem, code = wrapper.AddToQueue([]engine.Fact{{ID: 1}}, "test.txt", CSVParameters{}, users.User{Login: "bla"}) + expression.AssertEqual(t, code, CodeAdded, "item should have been added to queue") + _, code = wrapper.AddToQueue([]engine.Fact{{ID: 1}}, "test.txt", CSVParameters{}, users.User{Login: "blabla"}) + expression.AssertEqual(t, code, CodeUserAdded, "user should have been added to existing item in queue") + wrapper.queueMutex.Lock() + expression.AssertEqual(t, len(wrapper.queue), 1, "item should be in queue") + wrapper.queueMutex.Unlock() + expression.AssertEqual(t, wrapper.DeleteExport(queueItem.Id, users.User{Login: "bla"}), DeleteExportUserDeleted, "user should have been deleted from existing export") + wrapper.queueMutex.Lock() + expression.AssertEqual(t, len(wrapper.queue), 1, "item should be in queue") + wrapper.queueMutex.Unlock() + + // test workers + item.Users = []string{"bla", "blibli"} + worker := NewExportWorker(0, "/tmp", make(chan<- int)) + wrapper.workers = append(wrapper.workers, worker) + worker.Mutex.Lock() + worker.QueueItem = *item + worker.Available = true + worker.Mutex.Unlock() + expression.AssertEqual(t, wrapper.DeleteExport(item.Id, users.User{Login: "bla"}), DeleteExportNotFound, "item should have not been deleted") + worker.SwapAvailable(false) + expression.AssertEqual(t, wrapper.DeleteExport(item.Id, users.User{Login: "blibli"}), DeleteExportUserDeleted, "user should have been deleted from export") + expression.AssertEqual(t, len(worker.Cancel), 0, "worker cancel channel should not have been filled") + expression.AssertEqual(t, wrapper.DeleteExport(item.Id, users.User{Login: "bla"}), DeleteExportCanceled, "item should have been deleted") + expression.AssertEqual(t, len(worker.Cancel), 1, "worker cancel channel should have been filled") + + // clean cancel channel (non-blocking) + worker.DrainCancelChannel() + worker.Mutex.Lock() + worker.QueueItem.Users = []string{"bla", "blabla"} + worker.Mutex.Unlock() + expression.AssertEqual(t, wrapper.DeleteExport(item.Id, users.User{Login: "bla"}), DeleteExportNotFound, "user should have been deleted from existing export") + expression.AssertEqual(t, len(worker.Cancel), 0, "worker cancel channel should not have been filled") +} + +func TestWrapper_GetUserExport(t *testing.T) { + wrapper := NewWrapper("/tmp", 1, 1, 2) + item := NewWrapperItem([]engine.Fact{{ID: 1}}, "test.txt", CSVParameters{}, users.User{Login: "bla"}) + + // test item in archive + wrapper.archive.Store(item.Id, *item) + export, ok := wrapper.GetUserExport(item.Id, users.User{Login: "bla"}) + expression.AssertEqual(t, ok, true) + expression.AssertEqual(t, export.Id, item.Id) + export, ok = wrapper.GetUserExport(item.Id, users.User{Login: "blabla"}) + expression.AssertEqual(t, ok, false) + wrapper.archive.Delete(item.Id) + + // test item in queue queue + queueItem, code := wrapper.AddToQueue([]engine.Fact{{ID: 1}}, "test.txt", CSVParameters{}, users.User{Login: "bla"}) + expression.AssertEqual(t, code, CodeAdded, "item should have been added to queue") + export, ok = wrapper.GetUserExport(queueItem.Id, users.User{Login: "bla"}) + expression.AssertEqual(t, ok, true) + expression.AssertEqual(t, export.Id, queueItem.Id) + _, ok = wrapper.GetUserExport(queueItem.Id, users.User{Login: "blabla"}) + expression.AssertEqual(t, ok, false) + _, ok = wrapper.dequeueWrapperItem(&export) + expression.AssertEqual(t, ok, true) + + // test worker + worker := NewExportWorker(0, "/tmp", make(chan<- int)) + wrapper.workers = append(wrapper.workers, worker) + worker.Mutex.Lock() + worker.QueueItem = *item + worker.Available = false + worker.Mutex.Unlock() + + _, ok = wrapper.GetUserExport(item.Id, users.User{Login: "blabla"}) + expression.AssertEqual(t, ok, false) + _, ok = wrapper.GetUserExport(item.Id, users.User{Login: "bla"}) + expression.AssertEqual(t, ok, true) +} diff --git a/internals/handlers/export_handlers.go b/internals/handlers/export_handlers.go index 662eb00d..be9b44b0 100644 --- a/internals/handlers/export_handlers.go +++ b/internals/handlers/export_handlers.go @@ -2,106 +2,75 @@ package handlers import ( "context" + "encoding/json" "errors" - "github.com/myrteametrics/myrtea-sdk/v4/engine" - "net/http" - "strconv" - "strings" - "sync" - "time" - "unicode/utf8" - + "fmt" "github.com/go-chi/chi/v5" "github.com/myrteametrics/myrtea-engine-api/v5/internals/export" - "github.com/myrteametrics/myrtea-engine-api/v5/internals/fact" "github.com/myrteametrics/myrtea-engine-api/v5/internals/handlers/render" "github.com/myrteametrics/myrtea-engine-api/v5/internals/security/permissions" "go.uber.org/zap" + "net/http" + "net/url" + "path/filepath" + "strconv" + "sync" ) -type CSVParameters struct { - columns []string - columnsLabel []string - formatColumnsData map[string]string - separator rune - limit int64 - chunkSize int64 +type ExportHandler struct { + exportWrapper *export.Wrapper + directDownload bool + indirectDownloadUrl string } -// ExportFact godoc -// @Summary Export facts -// @Description Get all action definitions -// @Tags ExportFact +// NewExportHandler returns a new ExportHandler +func NewExportHandler(exportWrapper *export.Wrapper, directDownload bool, indirectDownloadUrl string) *ExportHandler { + return &ExportHandler{ + exportWrapper: exportWrapper, + directDownload: directDownload, + indirectDownloadUrl: indirectDownloadUrl, + } +} + +// ExportRequest represents a request for an export +type ExportRequest struct { + export.CSVParameters + FactIDs []int64 `json:"factIDs"` + Title string `json:"title"` +} + +// ExportFactStreamed godoc +// @Summary CSV streamed export facts in chunks +// @Description CSV Streamed export for facts in chunks +// @Tags ExportFactStreamed // @Produce octet-stream // @Security Bearer +// @Param request body handlers.ExportRequest true "request (json)" // @Success 200 {file} Returns data to be saved into a file // @Failure 500 "internal server error" -// @Router /engine/export/facts/{id} [get] -func ExportFact(w http.ResponseWriter, r *http.Request) { - id := chi.URLParam(r, "id") - - idFact, err := strconv.ParseInt(id, 10, 64) - - if err != nil { - zap.L().Warn("Error on parsing fact id", zap.String("idFact", id), zap.Error(err)) - render.Error(w, r, render.ErrAPIParsingInteger, err) - return - } - - userCtx, _ := GetUserFromContext(r) // TODO: set the right permission - if !userCtx.HasPermission(permissions.New(permissions.TypeFact, strconv.FormatInt(idFact, 10), permissions.ActionGet)) { +// @Router /engine/facts/streamedexport [post] +func ExportFactStreamed(w http.ResponseWriter, r *http.Request) { + userCtx, _ := GetUserFromContext(r) + if !userCtx.HasPermission(permissions.New(permissions.TypeExport, permissions.All, permissions.ActionGet)) { render.Error(w, r, render.ErrAPISecurityNoPermissions, errors.New("missing permission")) return } - f, found, err := fact.R().Get(idFact) + var request ExportRequest + err := json.NewDecoder(r.Body).Decode(&request) if err != nil { - zap.L().Error("Cannot retrieve fact", zap.Int64("factID", idFact), zap.Error(err)) - render.Error(w, r, render.ErrAPIDBSelectFailed, err) - return - } - if !found { - zap.L().Warn("fact does not exist", zap.Int64("factID", idFact)) - render.Error(w, r, render.ErrAPIDBResourceNotFound, err) + zap.L().Warn("Decode export request json", zap.Error(err)) + render.Error(w, r, render.ErrAPIDecodeJSONBody, err) return } - var filename = r.URL.Query().Get("fileName") - if filename == "" { - filename = f.Name + "_export_" + time.Now().Format("02_01_2006_15-04") + ".csv" - } - - // suppose that type is csv - params := GetCSVParameters(r) - - var combineFacts []engine.Fact - combineFacts = append(combineFacts, f) - - // export multiple facts into one file - combineFactIds, err := QueryParamToOptionalInt64Array(r, "combineFactIds", ",", false, []int64{}) - if err != nil { - zap.L().Warn("Could not parse parameter combineFactIds", zap.Error(err)) - } else { - for _, factId := range combineFactIds { - // no duplicates - if factId == idFact { - continue - } - - combineFact, found, err := fact.R().Get(factId) - if err != nil { - zap.L().Error("Export combineFact cannot retrieve fact", zap.Int64("factID", factId), zap.Error(err)) - continue - } - if !found { - zap.L().Warn("Export combineFact fact does not exist", zap.Int64("factID", factId)) - continue - } - combineFacts = append(combineFacts, combineFact) - } + if len(request.FactIDs) == 0 { + zap.L().Warn("Missing factIDs in export request") + render.Error(w, r, render.ErrAPIMissingParam, errors.New("missing factIDs")) + return } - err = HandleStreamedExport(r.Context(), w, combineFacts, filename, params) + err = handleStreamedExport(r.Context(), w, request) if err != nil { render.Error(w, r, render.ErrAPIProcessError, err) } @@ -109,49 +78,19 @@ func ExportFact(w http.ResponseWriter, r *http.Request) { } -func GetCSVParameters(r *http.Request) CSVParameters { - result := CSVParameters{separator: ','} - - limit, err := QueryParamToOptionalInt64(r, "limit", -1) - if err != nil { - result.limit = -1 - } else { - result.limit = limit - } - - result.columns = QueryParamToOptionalStringArray(r, "columns", ",", []string{}) - result.columnsLabel = QueryParamToOptionalStringArray(r, "columnsLabel", ",", []string{}) - - formatColumnsData := QueryParamToOptionalStringArray(r, "formateColumns", ",", []string{}) - result.formatColumnsData = make(map[string]string) - for _, formatData := range formatColumnsData { - parts := strings.Split(formatData, ";") - if len(parts) != 2 { - continue - } - key := strings.TrimSpace(parts[0]) - result.formatColumnsData[key] = parts[1] - } - separator := r.URL.Query().Get("separator") - if separator != "" { - sep, size := utf8.DecodeRuneInString(separator) - if size != 1 { - result.separator = ',' - } else { - result.separator = sep - } - } - - return result -} - -// HandleStreamedExport actually only handles CSV -func HandleStreamedExport(requestContext context.Context, w http.ResponseWriter, facts []engine.Fact, fileName string, params CSVParameters) error { +// handleStreamedExport actually only handles CSV +func handleStreamedExport(requestContext context.Context, w http.ResponseWriter, request ExportRequest) error { w.Header().Set("Connection", "Keep-Alive") w.Header().Set("Transfer-Encoding", "chunked") w.Header().Set("X-Content-Type-Options", "nosniff") - w.Header().Set("Content-Disposition", "attachment; filename="+strconv.Quote(fileName)) + w.Header().Set("Content-Disposition", "attachment; filename="+strconv.Quote(request.Title+".csv")) w.Header().Set("Content-Type", "application/octet-stream") + + facts := findCombineFacts(request.FactIDs) + if len(facts) == 0 { + return errors.New("no fact found") + } + streamedExport := export.NewStreamedExport() var wg sync.WaitGroup @@ -182,7 +121,7 @@ func HandleStreamedExport(requestContext context.Context, w http.ResponseWriter, defer close(streamedExport.Data) for _, f := range facts { - writerErr = streamedExport.StreamedExportFactHitsFull(ctx, f, params.limit) + writerErr = streamedExport.StreamedExportFactHitsFull(ctx, f, request.Limit) if writerErr != nil { zap.L().Error("Error during export (StreamedExportFactHitsFullV8)", zap.Error(err)) break // break here when error occurs? @@ -195,7 +134,6 @@ func HandleStreamedExport(requestContext context.Context, w http.ResponseWriter, go func() { defer wg.Done() first := true - labels := params.columnsLabel for { select { @@ -204,7 +142,7 @@ func HandleStreamedExport(requestContext context.Context, w http.ResponseWriter, return } - data, err := export.ConvertHitsToCSV(hits, params.columns, labels, params.formatColumnsData, params.separator) + data, err := export.ConvertHitsToCSV(hits, request.CSVParameters, first) if err != nil { zap.L().Error("ConvertHitsToCSV error during export (StreamedExportFactHitsFullV8)", zap.Error(err)) @@ -224,7 +162,6 @@ func HandleStreamedExport(requestContext context.Context, w http.ResponseWriter, if first { first = false - labels = []string{} } case <-requestContext.Done(): @@ -245,3 +182,212 @@ func HandleStreamedExport(requestContext context.Context, w http.ResponseWriter, return err } + +// GetExports godoc +// @Summary Get user exports +// @Description Get in memory user exports +// @Produce json +// @Security Bearer +// @Success 200 {array} export.WrapperItem "Returns a list of exports" +// @Failure 403 "Status Forbidden: missing permission" +// @Failure 500 "internal server error" +// @Router /engine/exports [get] +func (e *ExportHandler) GetExports(w http.ResponseWriter, r *http.Request) { + userCtx, _ := GetUserFromContext(r) + if !userCtx.HasPermission(permissions.New(permissions.TypeExport, permissions.All, permissions.ActionList)) { + render.Error(w, r, render.ErrAPISecurityNoPermissions, errors.New("missing permission")) + return + } + render.JSON(w, r, e.exportWrapper.GetUserExports(userCtx.User)) +} + +// GetExport godoc +// @Summary Get single export from user +// @Description Get single export from user +// @Tags Exports +// @Produce json +// @Security Bearer +// @Success 200 {object} export.WrapperItem "Status OK" +// @Failure 400 "Bad Request: missing export id" +// @Failure 403 "Status Forbidden: missing permission" +// @Failure 404 "Status Not Found: export not found" +// @Failure 500 "internal server error" +// @Router /engine/exports/{id} [get] +func (e *ExportHandler) GetExport(w http.ResponseWriter, r *http.Request) { + id := chi.URLParam(r, "id") + if id == "" { + render.Error(w, r, render.ErrAPIMissingParam, errors.New("missing id")) + return + } + + userCtx, _ := GetUserFromContext(r) + if !userCtx.HasPermission(permissions.New(permissions.TypeExport, permissions.All, permissions.ActionGet)) { + render.Error(w, r, render.ErrAPISecurityNoPermissions, errors.New("missing permission")) + return + } + + item, ok := e.exportWrapper.GetUserExport(id, userCtx.User) + if !ok { + render.Error(w, r, render.ErrAPIDBResourceNotFound, errors.New("export not found")) + return + } + + render.JSON(w, r, item) +} + +// DeleteExport godoc +// @Summary Deletes a single export +// @Description Deletes a single export, when running it is canceled +// @Tags Exports +// @Produce json +// @Security Bearer +// @Success 202 "Status Accepted: export found & cancellation request has been taken into account & will be processed" +// @Success 204 "Status OK: export was found and deleted" +// @Failure 400 "Bad Request: missing export id" +// @Failure 403 "Status Forbidden: missing permission" +// @Failure 404 "Status Not Found: export not found" +// @Failure 500 "internal server error" +// @Router /engine/exports/{id} [delete] +func (e *ExportHandler) DeleteExport(w http.ResponseWriter, r *http.Request) { + id := chi.URLParam(r, "id") + if id == "" { + render.Error(w, r, render.ErrAPIMissingParam, errors.New("missing id")) + return + } + + userCtx, _ := GetUserFromContext(r) + if !userCtx.HasPermission(permissions.New(permissions.TypeExport, permissions.All, permissions.ActionDelete)) { + render.Error(w, r, render.ErrAPISecurityNoPermissions, errors.New("missing permission")) + return + } + + status := e.exportWrapper.DeleteExport(id, userCtx.User) + + switch status { + case export.DeleteExportDeleted: + fallthrough + case export.DeleteExportUserDeleted: + w.WriteHeader(http.StatusNoContent) + case export.DeleteExportCanceled: + w.WriteHeader(http.StatusAccepted) + default: + render.Error(w, r, render.ErrAPIDBResourceNotFound, errors.New("export not found")) + } + +} + +// ExportFact godoc +// @Summary Creates a new export request for a fact (or multiple facts) +// @Description Creates a new export request for a fact (or multiple facts) +// @Tags Exports +// @Produce json +// @Security Bearer +// @Param request body handlers.ExportRequest true "request (json)" +// @Success 200 {object} export.WrapperItem "Status OK: user was added to existing export in queue" +// @Success 201 {object} export.WrapperItem "Status Created: new export was added in queue" +// @Failure 400 "Bad Request: missing fact id / fact id is not an integer" +// @Failure 403 "Status Forbidden: missing permission" +// @Failure 409 {object} export.WrapperItem "Status Conflict: user already exists in export queue" +// @Failure 429 "Status Too Many Requests: export queue is full" +// @Failure 500 "internal server error" +// @Router /engine/exports/fact [post] +func (e *ExportHandler) ExportFact(w http.ResponseWriter, r *http.Request) { + userCtx, _ := GetUserFromContext(r) + if !userCtx.HasPermission(permissions.New(permissions.TypeExport, permissions.All, permissions.ActionCreate)) { + render.Error(w, r, render.ErrAPISecurityNoPermissions, errors.New("missing permission")) + return + } + + var request ExportRequest + err := json.NewDecoder(r.Body).Decode(&request) + if err != nil { + zap.L().Warn("Decode export request json", zap.Error(err)) + render.Error(w, r, render.ErrAPIDecodeJSONBody, err) + return + } + + if len(request.FactIDs) == 0 { + zap.L().Warn("Missing factIDs in export request") + render.Error(w, r, render.ErrAPIMissingParam, errors.New("missing factIDs")) + return + } + + if len(request.Title) == 0 { + zap.L().Warn("Missing title (len is 0) in export request") + render.Error(w, r, render.ErrAPIMissingParam, errors.New("missing title (len is 0)")) + return + } + + facts := findCombineFacts(request.FactIDs) + if len(facts) == 0 { + zap.L().Warn("No fact was found in export request") + render.Error(w, r, render.ErrAPIDBResourceNotFound, errors.New("no fact was found in export request")) + return + } + + item, status := e.exportWrapper.AddToQueue(facts, request.Title, request.CSVParameters, userCtx.User) + + switch status { + case export.CodeAdded: + w.WriteHeader(http.StatusCreated) + case export.CodeUserAdded: + w.WriteHeader(http.StatusOK) + case export.CodeUserExists: + w.WriteHeader(http.StatusConflict) + case export.CodeQueueFull: + render.Error(w, r, render.ErrAPIQueueFull, fmt.Errorf("export queue is full")) + return + default: + render.Error(w, r, render.ErrAPIProcessError, fmt.Errorf("unknown status code (%d)", status)) + return + } + + render.JSON(w, r, item) +} + +// DownloadExport godoc +// @Summary Download export +// @Description Download export +// @Tags Exports +// @Produce json +// @Security Bearer +// @Success 200 {file} Returns data to be saved into a file +// @Success 308 Redirects to the export file location +// @Failure 400 "Bad Request: missing export id" +// @Failure 403 "Status Forbidden: missing permission" +// @Failure 404 "Status Not Found: export not found" +// @Failure 500 "internal server error" +// @Router /engine/exports/{id}/download [get] +func (e *ExportHandler) DownloadExport(w http.ResponseWriter, r *http.Request) { + id := chi.URLParam(r, "id") + if id == "" { + render.Error(w, r, render.ErrAPIMissingParam, errors.New("missing id")) + return + } + + userCtx, _ := GetUserFromContext(r) + if !userCtx.HasPermission(permissions.New(permissions.TypeExport, permissions.All, permissions.ActionGet)) { + render.Error(w, r, render.ErrAPISecurityNoPermissions, errors.New("missing permission")) + return + } + + item, ok := e.exportWrapper.GetUserExport(id, userCtx.User) + if !ok { + render.Error(w, r, render.ErrAPIDBResourceNotFound, errors.New("export not found")) + return + } + + if e.directDownload { + path := filepath.Join(e.exportWrapper.BasePath, item.FileName) + render.StreamFile(path, item.FileName, w, r) + return + } + + path, err := url.JoinPath(e.indirectDownloadUrl, item.FileName) + if err != nil { + render.Error(w, r, render.ErrAPIProcessError, err) + return + } + + http.Redirect(w, r, path, http.StatusPermanentRedirect) +} diff --git a/internals/handlers/notification_handlers.go b/internals/handlers/notification_handlers.go index 6e9ab55c..d7b4290e 100644 --- a/internals/handlers/notification_handlers.go +++ b/internals/handlers/notification_handlers.go @@ -1,17 +1,15 @@ package handlers import ( - "net/http" - "strconv" - "github.com/go-chi/chi/v5" - "github.com/google/uuid" "github.com/myrteametrics/myrtea-engine-api/v5/internals/dbutils" "github.com/myrteametrics/myrtea-engine-api/v5/internals/handlers/render" "github.com/myrteametrics/myrtea-engine-api/v5/internals/models" "github.com/myrteametrics/myrtea-engine-api/v5/internals/notifier/notification" "github.com/myrteametrics/myrtea-engine-api/v5/internals/security/users" "go.uber.org/zap" + "net/http" + "strconv" ) // GetNotifications godoc @@ -23,7 +21,6 @@ import ( // @Param nhit query int false "Hit per page" // @Param offset query int false "Offset number for pagination" // @Security Bearer -// @Success 200 {array} notification.FrontNotification "list of notifications" // @Failure 500 "internal server error" // @Router /engine/notifications [get] func GetNotifications(w http.ResponseWriter, r *http.Request) { @@ -57,17 +54,12 @@ func GetNotifications(w http.ResponseWriter, r *http.Request) { } user := _user.(users.UserWithPermissions) - roleIDs := make([]uuid.UUID, 0) - for _, role := range user.Roles { - roleIDs = append(roleIDs, role.ID) - } - - queryOptionnal := dbutils.DBQueryOptionnal{ + queryOptional := dbutils.DBQueryOptionnal{ Limit: nhit, Offset: offset, MaxAge: maxAge, } - notifications, err := notification.R().GetAll(queryOptionnal) + notifications, err := notification.R().GetAll(queryOptional, user.Login) if err != nil { zap.L().Error("Error getting notifications", zap.Error(err)) render.Error(w, r, render.ErrAPIDBSelectFailed, err) @@ -100,11 +92,18 @@ func UpdateRead(w http.ResponseWriter, r *http.Request) { _status := r.URL.Query().Get("status") status := false + _user := r.Context().Value(models.ContextKeyUser) + if _user == nil { + zap.L().Warn("No context user provided") + return + } + user := _user.(users.UserWithPermissions) + if _status == "true" { status = true } - err = notification.R().UpdateRead(idNotif, status) + err = notification.R().UpdateRead(idNotif, status, user.Login) if err != nil { zap.L().Error("Error while updating notifications", zap.Error(err)) render.Error(w, r, render.ErrAPIDBUpdateFailed, err) diff --git a/internals/handlers/notifier_handlers.go b/internals/handlers/notifier_handlers.go index 7936954e..bf24c300 100644 --- a/internals/handlers/notifier_handlers.go +++ b/internals/handlers/notifier_handlers.go @@ -1,13 +1,12 @@ package handlers import ( - "net/http" - "github.com/myrteametrics/myrtea-engine-api/v5/internals/handlers/render" "github.com/myrteametrics/myrtea-engine-api/v5/internals/models" "github.com/myrteametrics/myrtea-engine-api/v5/internals/notifier" "github.com/myrteametrics/myrtea-engine-api/v5/internals/security/users" "go.uber.org/zap" + "net/http" ) // NotificationsWSRegister godoc @@ -42,6 +41,20 @@ func NotificationsWSRegister(w http.ResponseWriter, r *http.Request) { zap.L().Error("Add new WS Client to manager", zap.Error(err)) return } + //go func(client *notifier.WebsocketClient) { // temporary for tests + // zap.L().Info("starting notifier") + // ticker := time.NewTicker(1 * time.Second) + // after := time.After(30 * time.Second) + // for { + // select { + // case <-ticker.C: + // notifier.C().SendToUsers(notification.ExportNotification{Status: export.StatusPending, Export: export.WrapperItem{Id: uuid.New().String(), Title: "test.bla"}}, []users.UserWithPermissions{user}) + // zap.L().Info("send notification") + // case <-after: + // return + // } + // } + //}(client) go client.Write() // go client.Read() // Disabled until proper usage diff --git a/internals/handlers/oidc_handlers.go b/internals/handlers/oidc_handlers.go index d0066d81..8ce0c6bd 100644 --- a/internals/handlers/oidc_handlers.go +++ b/internals/handlers/oidc_handlers.go @@ -24,7 +24,7 @@ func HandleOIDCRedirect(w http.ResponseWriter, r *http.Request) { handleError(w, r, "", err, render.ErrAPIProcessError) return } - render.Redirect(w, r, instanceOidc.OidcConfig.AuthCodeURL(expectedState), http.StatusFound) + http.Redirect(w, r, instanceOidc.OidcConfig.AuthCodeURL(expectedState), http.StatusFound) } func HandleOIDCCallback(w http.ResponseWriter, r *http.Request) { @@ -64,5 +64,5 @@ func HandleOIDCCallback(w http.ResponseWriter, r *http.Request) { baseURL := viper.GetString("AUTHENTICATION_OIDC_FRONT_END_URL") redirectURL := fmt.Sprintf("%s/auth/oidc/callback?token=%s", baseURL, url.QueryEscape(rawIDToken)) - render.Redirect(w, r, redirectURL, http.StatusFound) + http.Redirect(w, r, redirectURL, http.StatusFound) } diff --git a/internals/handlers/processor_handlers.go b/internals/handlers/processor_handlers.go index 81a60c47..72726de9 100644 --- a/internals/handlers/processor_handlers.go +++ b/internals/handlers/processor_handlers.go @@ -3,15 +3,13 @@ package handlers import ( "encoding/json" "errors" - "github.com/myrteametrics/myrtea-engine-api/v5/internals/ingester" - "net/http" - "time" - "github.com/myrteametrics/myrtea-engine-api/v5/internals/handlers/render" + "github.com/myrteametrics/myrtea-engine-api/v5/internals/ingester" "github.com/myrteametrics/myrtea-engine-api/v5/internals/processor" "github.com/myrteametrics/myrtea-engine-api/v5/internals/scheduler" "github.com/myrteametrics/myrtea-sdk/v4/models" "go.uber.org/zap" + "net/http" ) // ProcessorHandler is a basic struct allowing to set up a single aggregateIngester instance for all handlers @@ -21,12 +19,8 @@ type ProcessorHandler struct { // NewProcessorHandler returns a pointer to an ProcessorHandler instance func NewProcessorHandler() *ProcessorHandler { - var aggregateIngester = ingester.NewAggregateIngester() - go aggregateIngester.Run() // Start ingester - time.Sleep(10 * time.Millisecond) // goroutine warm-up - return &ProcessorHandler{ - aggregateIngester: aggregateIngester, + aggregateIngester: ingester.NewAggregateIngester(), } } @@ -79,7 +73,7 @@ func PostObjects(w http.ResponseWriter, r *http.Request) { // @Success 200 "Status OK" // @Failure 429 "Processing queue is full please retry later" // @Failure 500 "internal server error" -// @Router /service/ingester [post] +// @Router /service/aggregates [post] func (handler *ProcessorHandler) PostAggregates(w http.ResponseWriter, r *http.Request) { var aggregates []scheduler.ExternalAggregate err := json.NewDecoder(r.Body).Decode(&aggregates) diff --git a/internals/handlers/render/render.go b/internals/handlers/render/render.go index ea031fff..f2c6f670 100644 --- a/internals/handlers/render/render.go +++ b/internals/handlers/render/render.go @@ -2,7 +2,10 @@ package render import ( "encoding/json" + "fmt" + "io" "net/http" + "os" "strconv" "github.com/go-chi/chi/v5/middleware" @@ -10,7 +13,7 @@ import ( "go.uber.org/zap" ) -// APIError wraps all informations required to investiguate a backend error +// APIError wraps all information required to investigate a backend error // It is mainly used to returns information to the API caller when the status is not 2xx. type APIError struct { RequestID string `json:"requestID"` @@ -44,6 +47,9 @@ var ( // ErrAPIResourceDuplicate must be used in case a duplicate resource has been identified ErrAPIResourceDuplicate = APIError{Status: http.StatusBadRequest, ErrType: "RessourceError", Code: 2002, Message: `Provided resource definition can be parsed, but is already exists`} + // ErrAPIQueueFull must be used in case an internal processing queue is full + ErrAPIQueueFull = APIError{Status: http.StatusServiceUnavailable, ErrType: "RessourceError", Code: 2003, Message: `The queue is full, please retry later`} + // ErrAPIDBResourceNotFound must be used in case a resource is not found in the backend storage system ErrAPIDBResourceNotFound = APIError{Status: http.StatusNotFound, ErrType: "RessourceError", Code: 3000, Message: `Ressource not found`} // ErrAPIDBSelectFailed must be used when a select query returns an error from the backend storage system @@ -159,11 +165,43 @@ func File(w http.ResponseWriter, filename string, data []byte) { } } -// Redirect is a helper function to redirect the user to a specified location -// -// func Redirect(w http.ResponseWriter, r *http.Request, location string, code int) { -// http.Redirect(w, r, location, code) -// } -func Redirect(w http.ResponseWriter, r *http.Request, location string, code int) { - http.Redirect(w, r, location, code) +// StreamFile handle files streamed response with allows the download of a file in chunks +func StreamFile(filePath, fileName string, w http.ResponseWriter, r *http.Request) { + file, err := os.Open(filePath) + if err != nil { + Error(w, r, ErrAPIDBResourceNotFound, fmt.Errorf("error opening file: %s", err)) + return + } + defer file.Close() + + // Set all necessary headers + w.Header().Set("Connection", "Keep-Alive") + w.Header().Set("Transfer-Encoding", "chunked") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("Content-Disposition", "attachment; filename="+strconv.Quote(fileName)) + w.Header().Set("Content-Type", "application/octet-stream") + + const bufferSize = 4096 + buffer := make([]byte, bufferSize) + + for { + // Read a chunk of the file + bytesRead, err := file.Read(buffer) + if err == io.EOF { + break + } else if err != nil { + Error(w, r, ErrAPIProcessError, fmt.Errorf("error reading file: %s", err)) + return + } + + // Write the chunk to the response writer + _, err = w.Write(buffer[:bytesRead]) + if err != nil { + // If writing to the response writer fails, log the error and stop streaming + Error(w, r, ErrAPIProcessError, fmt.Errorf("error writing to response writer: %s", err)) + break + } + + w.(http.Flusher).Flush() + } } diff --git a/internals/handlers/utils.go b/internals/handlers/utils.go index a163b132..27c6e9d8 100644 --- a/internals/handlers/utils.go +++ b/internals/handlers/utils.go @@ -6,6 +6,9 @@ import ( "crypto/rand" "encoding/base64" "fmt" + "github.com/myrteametrics/myrtea-engine-api/v5/internals/fact" + "github.com/myrteametrics/myrtea-engine-api/v5/internals/utils" + "github.com/myrteametrics/myrtea-sdk/v4/engine" "io" "regexp" "strconv" @@ -31,6 +34,7 @@ const ( parseGlobalVariables = false ) +// QueryParamToOptionalInt parse a string from a string func QueryParamToOptionalInt(r *http.Request, name string, orDefault int) (int, error) { param := r.URL.Query().Get(name) if param != "" { @@ -39,6 +43,7 @@ func QueryParamToOptionalInt(r *http.Request, name string, orDefault int) (int, return orDefault, nil } +// QueryParamToOptionalInt64 parse an int64 from a string func QueryParamToOptionalInt64(r *http.Request, name string, orDefault int64) (int64, error) { param := r.URL.Query().Get(name) if param != "" { @@ -47,6 +52,7 @@ func QueryParamToOptionalInt64(r *http.Request, name string, orDefault int64) (i return orDefault, nil } +// QueryParamToOptionalInt64Array parse multiple int64 entries separated by a separator from a string func QueryParamToOptionalInt64Array(r *http.Request, name, separator string, allowDuplicates bool, orDefault []int64) ([]int64, error) { param := r.URL.Query().Get(name) if param == "" { @@ -64,7 +70,7 @@ func QueryParamToOptionalInt64Array(r *http.Request, name, separator string, all } if !allowDuplicates { - return removeDuplicate(result), nil + return utils.RemoveDuplicates(result), nil } return result, nil @@ -217,25 +223,13 @@ func GetUserFromContext(r *http.Request) (users.UserWithPermissions, bool) { return user, true } -func removeDuplicate[T string | int | int64](sliceList []T) []T { - allKeys := make(map[T]bool) - var list []T - for _, item := range sliceList { - if _, value := allKeys[item]; !value { - allKeys[item] = true - list = append(list, item) - } - } - return list -} - // handleError is a helper function that logs the error and sends a response. func handleError(w http.ResponseWriter, r *http.Request, message string, err error, apiError render.APIError) { zap.L().Error(message, zap.Error(err)) render.Error(w, r, apiError, err) } -// Generate a State use by OIDC authentification +// generateRandomState Generate a State used by OIDC authentication func generateRandomState() (string, error) { b := make([]byte, 32) _, err := rand.Read(b) @@ -244,6 +238,8 @@ func generateRandomState() (string, error) { } return base64.StdEncoding.EncodeToString(b), nil } + +// generateEncryptedState Generate a State used by OIDC authentication func generateEncryptedState(key []byte) (string, error) { // Generate random state plainState, err := generateRandomState() @@ -269,6 +265,8 @@ func generateEncryptedState(key []byte) (string, error) { b64State := base64.StdEncoding.EncodeToString(ciphertext) return b64State, nil } + +// verifyEncryptedState Verify the State used by OIDC authentication func verifyEncryptedState(state string, key []byte) (string, error) { // Decode from base64 decodedState, err := base64.StdEncoding.DecodeString(state) @@ -292,3 +290,18 @@ func verifyEncryptedState(state string, key []byte) (string, error) { return string(decodedState), nil } + +// findCombineFacts returns the combine facts +func findCombineFacts(combineFactIds []int64) (combineFacts []engine.Fact) { + for _, factId := range utils.RemoveDuplicates(combineFactIds) { + combineFact, found, err := fact.R().Get(factId) + if err != nil { + continue + } + if !found { + continue + } + combineFacts = append(combineFacts, combineFact) + } + return combineFacts +} diff --git a/internals/handlers/utils_test.go b/internals/handlers/utils_test.go index 401a5992..d7489f6d 100644 --- a/internals/handlers/utils_test.go +++ b/internals/handlers/utils_test.go @@ -139,23 +139,6 @@ func TestQueryParamToOptionalInt64Array(t *testing.T) { } -func TestRemoveDuplicate(t *testing.T) { - sample := []int64{1, 1, 1, 2, 2, 3, 4} - expectedResult := []int64{1, 2, 3, 4} - result := removeDuplicate(sample) - - if len(result) != len(expectedResult) { - t.FailNow() - } - - for i := 0; i < len(expectedResult); i++ { - if expectedResult[i] != result[i] { - t.FailNow() - } - } - -} - func TestHandleError(t *testing.T) { // response writer and request w := httptest.NewRecorder() diff --git a/internals/handlers/variablesconfig_handlers.go b/internals/handlers/variablesconfig_handlers.go index 16805882..b5d6ecc2 100644 --- a/internals/handlers/variablesconfig_handlers.go +++ b/internals/handlers/variablesconfig_handlers.go @@ -12,7 +12,7 @@ import ( "go.uber.org/zap" ) -// GetVariablesConfigs godoc +// GetVariablesConfig godoc // @Summary Get all Variables Config definitions // @Description Get all VariableConfig definitions // @Tags VariablesConfig diff --git a/internals/ingester/aggregate.go b/internals/ingester/aggregate.go index 94831449..b8c4924a 100644 --- a/internals/ingester/aggregate.go +++ b/internals/ingester/aggregate.go @@ -13,8 +13,9 @@ import ( // AggregateIngester is a component which process scheduler.ExternalAggregate type AggregateIngester struct { - Data chan []scheduler.ExternalAggregate + data chan []scheduler.ExternalAggregate metricQueueGauge *stdprometheus.Gauge + running bool } var ( @@ -39,16 +40,17 @@ func _newRegisteredGauge() *stdprometheus.Gauge { // NewAggregateIngester returns a pointer to a new AggregateIngester instance func NewAggregateIngester() *AggregateIngester { return &AggregateIngester{ - Data: make(chan []scheduler.ExternalAggregate, viper.GetInt("AGGREGATEINGESTER_QUEUE_BUFFER_SIZE")), + data: make(chan []scheduler.ExternalAggregate, viper.GetInt("AGGREGATEINGESTER_QUEUE_BUFFER_SIZE")), metricQueueGauge: _aggregateIngesterGauge, + running: false, } } // Run is the main routine of a TypeIngester instance -func (ingester *AggregateIngester) Run() { +func (ai *AggregateIngester) Run() { zap.L().Info("Starting AggregateIngester") - for ir := range ingester.Data { + for ir := range ai.data { zap.L().Debug("Received ExternalAggregate", zap.Int("ExternalAggregate items count", len(ir))) err := HandleAggregates(ir) @@ -57,25 +59,31 @@ func (ingester *AggregateIngester) Run() { } // Update queue gauge - (*ingester.metricQueueGauge).Set(float64(len(ingester.Data))) + (*ai.metricQueueGauge).Set(float64(len(ai.data))) } } // Ingest process an array of scheduler.ExternalAggregate -func (ingester *AggregateIngester) Ingest(aggregates []scheduler.ExternalAggregate) error { - dataLen := len(ingester.Data) +func (ai *AggregateIngester) Ingest(aggregates []scheduler.ExternalAggregate) error { + dataLen := len(ai.data) + + // Start ingester if not running + if !ai.running { + go ai.Run() + ai.running = true + } zap.L().Debug("Ingesting data", zap.Any("aggregates", aggregates)) // Check for channel overloading - if dataLen+1 >= cap(ingester.Data) { + if dataLen+1 >= cap(ai.data) { zap.L().Debug("Buffered channel would be overloaded with incoming bulkIngestRequest") - (*ingester.metricQueueGauge).Set(float64(dataLen)) + (*ai.metricQueueGauge).Set(float64(dataLen)) return errors.New("channel overload") } - ingester.Data <- aggregates + ai.data <- aggregates return nil } diff --git a/internals/modeler/postgres_repository.go b/internals/modeler/postgres_repository.go index ac8fe71c..7e7b6d65 100644 --- a/internals/modeler/postgres_repository.go +++ b/internals/modeler/postgres_repository.go @@ -16,7 +16,7 @@ type PostgresRepository struct { conn *sqlx.DB } -//NewPostgresRepository returns a new instance of PostgresRepository +// NewPostgresRepository returns a new instance of PostgresRepository func NewPostgresRepository(dbClient *sqlx.DB) Repository { r := PostgresRepository{ conn: dbClient, @@ -174,7 +174,7 @@ func (r *PostgresRepository) Delete(id int64) error { // GetAll returns all models in the repository func (r *PostgresRepository) GetAll() (map[int64]modeler.Model, error) { - models := make(map[int64]modeler.Model, 0) + models := make(map[int64]modeler.Model) query := `SELECT id, definition FROM model_v1` rows, err := r.conn.Query(query) diff --git a/internals/notifier/manager.go b/internals/notifier/manager.go index 498c163d..aee6f0f1 100644 --- a/internals/notifier/manager.go +++ b/internals/notifier/manager.go @@ -12,7 +12,7 @@ type ClientManager struct { Clients map[Client]bool } -// NewClientManager renders a new manager responsible of every connection +// NewClientManager renders a new manager responsible for every connection func NewClientManager() *ClientManager { return &ClientManager{ Clients: make(map[Client]bool), diff --git a/internals/notifier/notification/handler.go b/internals/notifier/notification/handler.go new file mode 100644 index 00000000..55758710 --- /dev/null +++ b/internals/notifier/notification/handler.go @@ -0,0 +1,89 @@ +package notification + +import ( + "context" + "go.uber.org/zap" + "sync" + "time" +) + +var ( + _globalHandlerMu sync.RWMutex + _globalHandler *Handler +) + +// H is used to access the global notification handler singleton +func H() *Handler { + _globalHandlerMu.RLock() + defer _globalHandlerMu.RUnlock() + return _globalHandler +} + +// ReplaceHandlerGlobals affects a new repository to the global notification handler singleton +func ReplaceHandlerGlobals(handler *Handler) func() { + _globalHandlerMu.Lock() + defer _globalHandlerMu.Unlock() + + prev := _globalHandler + _globalHandler = handler + return func() { ReplaceHandlerGlobals(prev) } +} + +type Handler struct { + notificationTypes map[string]Notification + notificationLifetime time.Duration +} + +// NewHandler returns a pointer to a new instance of Handler +func NewHandler(notificationLifetime time.Duration) *Handler { + handler := &Handler{ + notificationTypes: make(map[string]Notification), + notificationLifetime: notificationLifetime, + } + + // useless to start cleaner if lifetime is less than 0 + if notificationLifetime > 0 { + go handler.startCleaner(context.Background()) + } else { + zap.L().Info("Notification cleaner will not be started", zap.Duration("notificationLifetime", notificationLifetime)) + } + + return handler +} + +// RegisterNotificationType register a new notification type +func (h *Handler) RegisterNotificationType(notification Notification) { + h.notificationTypes[getType(notification)] = notification +} + +// UnregisterNotificationType unregister a notification type +func (h *Handler) UnregisterNotificationType(notification Notification) { + delete(h.notificationTypes, getType(notification)) +} + +// GetNotificationByType gets notification interface by its type +func (h *Handler) GetNotificationByType(notificationType string) (notif Notification, ok bool) { + notif, ok = h.notificationTypes[notificationType] + return notif, ok +} + +// startCleaner start a ticker to clean expired notifications in database every 24 hours +func (h *Handler) startCleaner(context context.Context) { + cleanRate := time.Hour * 24 + zap.L().Info("Starting notification cleaner", zap.Duration("cleanRate", cleanRate), zap.Duration("notificationLifetime", h.notificationLifetime)) + ticker := time.NewTicker(cleanRate) + defer ticker.Stop() + for { + select { + case <-context.Done(): + return + case <-ticker.C: + affectedRows, err := R().CleanExpired(h.notificationLifetime) + if err != nil { + zap.L().Error("Error while cleaning expired notifications", zap.Error(err)) + } else { + zap.L().Debug("Cleaned expired notifications", zap.Int64("affectedRows", affectedRows)) + } + } + } +} diff --git a/internals/notifier/notification/handler_test.go b/internals/notifier/notification/handler_test.go new file mode 100644 index 00000000..da2b1289 --- /dev/null +++ b/internals/notifier/notification/handler_test.go @@ -0,0 +1,54 @@ +package notification + +import ( + "github.com/myrteametrics/myrtea-sdk/v4/expression" + "testing" +) + +func TestNewHandler(t *testing.T) { + handler := NewHandler(0) + expression.AssertNotEqual(t, handler, nil, "NewHandler() should not return nil") +} + +func TestHandler_RegisterNotificationType_AddsNewType(t *testing.T) { + handler := NewHandler(0) + notification := BaseNotification{} + handler.RegisterNotificationType(notification) + _, exists := handler.notificationTypes[getType(notification)] + expression.AssertEqual(t, exists, true, "RegisterNotificationType() should add new type") +} + +func TestHandler_RegisterNotificationType_OverwritesExistingType(t *testing.T) { + handler := NewHandler(0) + notification := BaseNotification{} + handler.RegisterNotificationType(notification) + notification2 := BaseNotification{} // Assuming this has the same type as the first one + handler.RegisterNotificationType(notification2) + expression.AssertEqual(t, handler.notificationTypes[getType(notification)], notification2, "RegisterNotificationType() should overwrite existing type") +} + +func TestHandler_UnregisterNotificationType_RemovesExistingType(t *testing.T) { + handler := NewHandler(0) + notification := BaseNotification{} + handler.RegisterNotificationType(notification) + handler.UnregisterNotificationType(notification) + _, exists := handler.notificationTypes[getType(notification)] + expression.AssertEqual(t, exists, false, "UnregisterNotificationType() should remove existing type") +} + +func TestHandler_UnregisterNotificationType_DoesNothingForNonExistingType(t *testing.T) { + handler := NewHandler(0) + notification := BaseNotification{} + handler.UnregisterNotificationType(notification) + _, exists := handler.notificationTypes[getType(notification)] + expression.AssertEqual(t, exists, false, "UnregisterNotificationType() should do nothing for non-existing type") +} + +func TestReplaceHandlerGlobals_ReplacesGlobalHandler(t *testing.T) { + handler := NewHandler(0) + prevHandler := H() + undo := ReplaceHandlerGlobals(handler) + expression.AssertEqual(t, H(), handler, "ReplaceHandlerGlobals() should replace global handler") + undo() + expression.AssertEqual(t, H(), prevHandler, "Undo function should restore previous global handler") +} diff --git a/internals/notifier/notification/notification.go b/internals/notifier/notification/notification.go index 46544ab6..c6515c56 100644 --- a/internals/notifier/notification/notification.go +++ b/internals/notifier/notification/notification.go @@ -1,12 +1,91 @@ package notification -//FrontNotification data structure represente the notification and her current state -type FrontNotification struct { - Notification - IsRead bool -} +import ( + "encoding/json" +) // Notification is a general interface for all notifications types type Notification interface { ToBytes() ([]byte, error) + NewInstance(id int64, data []byte, isRead bool) (Notification, error) + Equals(notification Notification) bool + SetId(id int64) Notification + SetPersistent(persistent bool) Notification + IsPersistent() bool +} + +// BaseNotification data structure represents a basic notification and her current state +type BaseNotification struct { + Notification `json:"-"` + Id int64 `json:"id"` + IsRead bool `json:"isRead"` + Type string `json:"type"` + Persistent bool `json:"persistent"` // is notification saved in db or not ? +} + +// NewBaseNotification returns a new instance of a BaseNotification +func NewBaseNotification(id int64, isRead bool, persistent bool) BaseNotification { + return BaseNotification{ + Id: id, + IsRead: isRead, + Persistent: persistent, + Type: "BaseNotification", + } +} + +// NewInstance returns a new instance of a BaseNotification +func (n BaseNotification) NewInstance(id int64, data []byte, isRead bool) (Notification, error) { + var notification BaseNotification + err := json.Unmarshal(data, ¬ification) + if err != nil { + return nil, err + } + notification.Id = id + notification.IsRead = isRead + notification.Notification = notification + return notification, nil +} + +// ToBytes convert a notification in a json byte slice to be sent through any required channel +func (n BaseNotification) ToBytes() ([]byte, error) { + b, err := json.Marshal(n) + if err != nil { + return nil, err + } + return b, nil +} + +// Equals returns true if the two notifications are equals +func (n BaseNotification) Equals(notification Notification) bool { + notif, ok := notification.(BaseNotification) + if !ok { + return ok + } + if n.Id != notif.Id { + return false + } + if n.IsRead != notif.IsRead { + return false + } + if n.Type != notif.Type { + return false + } + return true +} + +// SetId set the notification ID +func (n BaseNotification) SetId(id int64) Notification { + n.Id = id + return n +} + +// SetPersistent sets whether the notification is persistent (saved to a database) +func (n BaseNotification) SetPersistent(persistent bool) Notification { + n.Persistent = persistent + return n +} + +// IsPersistent returns whether the notification is persistent (saved to a database) +func (n BaseNotification) IsPersistent() bool { + return n.Persistent } diff --git a/internals/notifier/notification/notification_mock.go b/internals/notifier/notification/notification_mock.go index 4a8ca546..4d68eca4 100644 --- a/internals/notifier/notification/notification_mock.go +++ b/internals/notifier/notification/notification_mock.go @@ -7,30 +7,34 @@ import ( // MockNotification is an implementation of a notification main type type MockNotification struct { - ID int64 `json:"id"` - Type string `json:"type"` + BaseNotification CreationDate time.Time `json:"creationDate"` Groups []int64 `json:"groups"` Level string `json:"level"` Title string `json:"title"` SubTitle string `json:"subtitle"` Description string `json:"description"` + Target string `json:"target"` Context map[string]interface{} `json:"context,omitempty"` } // NewMockNotification renders a new MockNotification instance -func NewMockNotification(level string, title string, subTitle string, description string, creationDate time.Time, +func NewMockNotification(id int64, level string, title string, subTitle string, description string, creationDate time.Time, groups []int64, context map[string]interface{}) *MockNotification { return &MockNotification{ - Type: "mock", + BaseNotification: BaseNotification{ + Id: id, + Type: "MockNotification", + Persistent: true, + }, CreationDate: creationDate, - // Groups: groups, - Level: level, - Title: title, - SubTitle: subTitle, - Description: description, - Context: context, + Groups: groups, + Level: level, + Title: title, + SubTitle: subTitle, + Description: description, + Context: context, } } @@ -42,3 +46,80 @@ func (n MockNotification) ToBytes() ([]byte, error) { } return b, nil } + +// NewInstance returns a new instance of a MockNotification +func (n MockNotification) NewInstance(id int64, data []byte, isRead bool) (Notification, error) { + var notification MockNotification + err := json.Unmarshal(data, ¬ification) + if err != nil { + return nil, err + } + notification.Id = id + notification.IsRead = isRead + notification.Notification = notification + return notification, nil +} + +// Equals returns true if the two notifications are equals +func (n MockNotification) Equals(notification Notification) bool { + notif, ok := notification.(MockNotification) + if !ok { + return ok + } + if !notif.BaseNotification.Equals(n.BaseNotification) { + return false + } + if notif.CreationDate != n.CreationDate { + return false + } + if notif.Level != n.Level { + return false + } + if notif.Title != n.Title { + return false + } + if notif.SubTitle != n.SubTitle { + return false + } + if notif.Description != n.Description { + return false + } + if notif.Context != nil && n.Context != nil { + if len(notif.Context) != len(n.Context) { + return false + } + for k, v := range notif.Context { + if n.Context[k] != v { + return false + } + } + } else if notif.Context != nil || n.Context != nil { + return false + } + if len(notif.Groups) != len(n.Groups) { + return false + } + for i, v := range notif.Groups { + if n.Groups[i] != v { + return false + } + } + return true +} + +// SetId set the notification ID +func (n MockNotification) SetId(id int64) Notification { + n.Id = id + return n +} + +// SetPersistent sets whether the notification is persistent (saved to a database) +func (n MockNotification) SetPersistent(persistent bool) Notification { + n.Persistent = persistent + return n +} + +// IsPersistent returns whether the notification is persistent (saved to a database) +func (n MockNotification) IsPersistent() bool { + return n.Persistent +} diff --git a/internals/notifier/notification/notification_test.go b/internals/notifier/notification/notification_test.go new file mode 100644 index 00000000..d8c3d3d4 --- /dev/null +++ b/internals/notifier/notification/notification_test.go @@ -0,0 +1,236 @@ +package notification + +import ( + "github.com/myrteametrics/myrtea-sdk/v4/expression" + "testing" + "time" +) + +func TestBaseNotificationToBytes(t *testing.T) { + notification := BaseNotification{ + Id: 1, + Type: "Test", + IsRead: true, + } + + bytes, err := notification.ToBytes() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if bytes == nil { + t.Fatalf("Expected bytes, got nil") + } +} + +func TestBaseNotificationNewInstance(t *testing.T) { + s := BaseNotification{ + Id: 1, + Type: "Test", + IsRead: true, + } + se, e := s.ToBytes() + if e == nil { + t.Log(string(se)) + } + + data := []byte(`{"id":1,"type":"Test","isRead":true}`) + notification, err := BaseNotification{}.NewInstance(1, data, true) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + expected := BaseNotification{ + Id: 1, + Type: "Test", + IsRead: true, + } + + expression.AssertEqual(t, expected.Equals(notification), true) +} + +func TestBaseNotificationNewInstanceWithInvalidData(t *testing.T) { + data := []byte(`{"id":1,"type":"Test","isRead":"invalid"}`) + _, err := BaseNotification{}.NewInstance(1, data, true) + if err == nil { + t.Fatalf("Expected error, got nil") + } +} + +func TestBaseNotification_Equals(t *testing.T) { + notif := BaseNotification{ + Id: 1, + Type: "Test", + IsRead: true, + } + + expression.AssertEqual(t, notif.Equals(BaseNotification{ + Id: 1, + Type: "Test", + IsRead: true, + }), true) + + expression.AssertEqual(t, notif.Equals(BaseNotification{ + Id: 2, + Type: "Test", + IsRead: true, + }), false) + + expression.AssertEqual(t, notif.Equals(BaseNotification{ + Id: 1, + Type: "Test2", + IsRead: true, + }), false) + + expression.AssertEqual(t, notif.Equals(BaseNotification{ + Id: 1, + Type: "Test", + IsRead: false, + }), false) +} + +func TestMockNotification_Equals(t *testing.T) { + baseNotification := BaseNotification{ + Id: 1, + Type: "Test", + IsRead: true, + } + now := time.Now() + notif := MockNotification{ + BaseNotification: baseNotification, + CreationDate: now, + Level: "info", + Title: "title", + SubTitle: "subTitle", + Description: "description", + Context: map[string]interface{}{"test": "test"}, + Groups: []int64{1, 2}, + } + + expression.AssertEqual(t, notif.Equals(MockNotification{ + BaseNotification: baseNotification, + CreationDate: now, + Level: "info", + Title: "title", + SubTitle: "subTitle", + Description: "description", + Context: map[string]interface{}{"test": "test"}, + Groups: []int64{1, 2}, + }), true) + + expression.AssertEqual(t, notif.Equals(MockNotification{ + BaseNotification: BaseNotification{ + Id: 2, + Type: "Test", + IsRead: true, + }, + CreationDate: now, + Level: "info", + Title: "title", + SubTitle: "subTitle", + Description: "description", + Context: map[string]interface{}{"test": "test"}, + Groups: []int64{1, 2}, + }), false) + + expression.AssertEqual(t, notif.Equals(MockNotification{ + BaseNotification: baseNotification, + CreationDate: time.Now().AddDate(1, 0, 0), + Level: "info", + Title: "title", + SubTitle: "subTitle", + Description: "description", + Context: map[string]interface{}{"test": "test"}, + Groups: []int64{1, 2}, + }), false) + + expression.AssertEqual(t, notif.Equals(MockNotification{ + BaseNotification: baseNotification, + CreationDate: now, + Level: "infos", + Title: "title", + SubTitle: "subTitle", + Description: "description", + Context: map[string]interface{}{"test": "test"}, + Groups: []int64{1, 2}, + }), false) + + expression.AssertEqual(t, notif.Equals(MockNotification{ + BaseNotification: baseNotification, + CreationDate: now, + Level: "info", + Title: "titles", + SubTitle: "subTitle", + Description: "description", + Context: map[string]interface{}{"test": "test"}, + Groups: []int64{1, 2}, + }), false) + + expression.AssertEqual(t, notif.Equals(MockNotification{ + BaseNotification: baseNotification, + CreationDate: now, + Level: "info", + Title: "title", + SubTitle: "subTitles", + Description: "description", + Context: map[string]interface{}{"test": "test"}, + Groups: []int64{1, 2}, + }), false) + expression.AssertEqual(t, notif.Equals(MockNotification{ + BaseNotification: baseNotification, + CreationDate: now, + Level: "info", + Title: "title", + SubTitle: "subTitle", + Description: "descriptions", + Context: map[string]interface{}{"test": "test"}, + Groups: []int64{1, 2}, + }), false) + + expression.AssertEqual(t, notif.Equals(MockNotification{ + BaseNotification: baseNotification, + CreationDate: now, + Level: "info", + Title: "title", + SubTitle: "subTitle", + Description: "description", + Context: map[string]interface{}{"tests": "test"}, + Groups: []int64{1, 2}, + }), false) + + expression.AssertEqual(t, notif.Equals(MockNotification{ + BaseNotification: baseNotification, + CreationDate: now, + Level: "info", + Title: "title", + SubTitle: "subTitle", + Description: "description", + Context: map[string]interface{}{"test": "test"}, + Groups: []int64{1, 2, 3}, + }), false) + +} + +func TestBaseNotification_SetId(t *testing.T) { + notif, err := BaseNotification{}.NewInstance(1, []byte(`{}`), true) + if err != nil { + t.Fatalf("Error: %v", err) + } + + notif = notif.SetId(2) + baseNotification, ok := notif.(BaseNotification) + expression.AssertEqual(t, ok, true) + expression.AssertEqual(t, baseNotification.Id, int64(2)) +} + +func TestMockNotification_SetId(t *testing.T) { + notif, err := MockNotification{}.NewInstance(1, []byte(`{}`), true) + if err != nil { + t.Fatalf("Error: %v", err) + } + + notif = notif.SetId(2) + mockNotification, ok := notif.(MockNotification) + expression.AssertEqual(t, ok, true) + expression.AssertEqual(t, mockNotification.Id, int64(2)) +} diff --git a/internals/notifier/notification/postgres_repository.go b/internals/notifier/notification/postgres_repository.go index 94496668..c4761d17 100644 --- a/internals/notifier/notification/postgres_repository.go +++ b/internals/notifier/notification/postgres_repository.go @@ -1,13 +1,12 @@ package notification import ( - "encoding/json" "errors" "time" + sq "github.com/Masterminds/squirrel" "github.com/jmoiron/sqlx" "github.com/myrteametrics/myrtea-engine-api/v5/internals/dbutils" - "go.uber.org/zap" ) // PostgresRepository is a repository containing the Fact definition based on a PSQL database and @@ -26,21 +25,21 @@ func NewPostgresRepository(dbClient *sqlx.DB) Repository { } // Create creates a new Notification definition in the repository -func (r *PostgresRepository) Create(notif Notification) (int64, error) { - - data, err := json.Marshal(notif) +func (r *PostgresRepository) Create(notif Notification, userLogin string) (int64, error) { + data, err := notif.ToBytes() if err != nil { return -1, err } ts := time.Now().Truncate(1 * time.Millisecond).UTC() - query := `INSERT INTO notifications_history_v1 (id, data, created_at) VALUES (DEFAULT, :data, :created_at) RETURNING id` - params := map[string]interface{}{ - "data": data, - "created_at": ts, - } - rows, err := r.conn.NamedQuery(query, params) + insertStatement := newStatement(). + Insert("notifications_history_v1"). + Columns("id", "data", "type", "user_login", "created_at"). + Values(sq.Expr("DEFAULT"), data, getType(notif), userLogin, ts). + Suffix("RETURNING id") + + rows, err := insertStatement.RunWith(r.conn.DB).Query() if err != nil { return -1, err } @@ -55,19 +54,16 @@ func (r *PostgresRepository) Create(notif Notification) (int64, error) { return id, nil } -// Get returns a notification by it's ID -func (r *PostgresRepository) Get(id int64) *FrontNotification { +// Get returns a notification by its ID +func (r *PostgresRepository) Get(id int64, userLogin string) (Notification, error) { + getStatement := newStatement(). + Select("id", "data", "isread", "type"). + Where(sq.And{sq.Eq{"id": id}, sq.Eq{"user_login": userLogin}}). + From("notifications_history_v1") - // TODO: "ORDER BY" should be an option in dbutils.DBQueryOptionnal - query := `SELECT id, data, isread FROM notifications_history_v1 WHERE id = :id` - params := map[string]interface{}{ - "id": id, - } - - rows, err := r.conn.NamedQuery(query, params) + rows, err := getStatement.RunWith(r.conn.DB).Query() if err != nil { - zap.L().Error("", zap.Error(err)) - return nil + return nil, errors.New("couldn't retrieve any notification with this id. The query is equal to: " + err.Error()) } defer rows.Close() @@ -75,82 +71,80 @@ func (r *PostgresRepository) Get(id int64) *FrontNotification { var id int64 var data string var isRead bool + var notifType string - err := rows.Scan(&id, &data, &isRead) + err = rows.Scan(&id, &data, &isRead, ¬ifType) if err != nil { - zap.L().Error("", zap.Error(err)) - return nil + return nil, errors.New("couldn't retrieve any notification. The query is equal to: " + err.Error()) } - var notif MockNotification - err = json.Unmarshal([]byte(data), ¬if) - if err != nil { - zap.L().Error("", zap.Error(err)) - return nil + t, ok := H().notificationTypes[notifType] + if !ok { + return nil, errors.New("notification type does not exist") } - notif.ID = id - - return &FrontNotification{ - Notification: notif, - IsRead: isRead, + instance, err := t.NewInstance(id, []byte(data), isRead) + if err != nil { + return nil, errors.New("notification couldn't be instanced") } + + return instance, nil } - return nil + return nil, errors.New("no notification found with this id") } -// GetByRoles returns all notifications related to a certain list of roles -func (r *PostgresRepository) GetAll(queryOptionnal dbutils.DBQueryOptionnal) ([]FrontNotification, error) { +// GetAll returns all notifications from the repository +func (r *PostgresRepository) GetAll(queryOptionnal dbutils.DBQueryOptionnal, userLogin string) ([]Notification, error) { + getStatement := newStatement(). + Select("id", "data", "isread", "type"). + Where(sq.Eq{"user_login": userLogin}). + From("notifications_history_v1") - // TODO: "ORDER BY" should be an option in dbutils.DBQueryOptionnal - query := `SELECT id, data, isread FROM notifications_history_v1` - params := map[string]interface{}{} if queryOptionnal.MaxAge > 0 { - query += ` WHERE created_at > :created_at` - params["created_at"] = time.Now().UTC().Add(-1 * queryOptionnal.MaxAge) + getStatement = getStatement.Where(sq.Gt{"created_at": time.Now().UTC().Add(-1 * queryOptionnal.MaxAge)}) } - query += ` ORDER BY created_at DESC` + if queryOptionnal.Limit > 0 { - query += ` LIMIT :limit` - params["limit"] = queryOptionnal.Limit + getStatement = getStatement.Limit(uint64(queryOptionnal.Limit)) } + if queryOptionnal.Offset > 0 { - query += ` OFFSET :offset` - params["offset"] = queryOptionnal.Offset + getStatement = getStatement.Offset(uint64(queryOptionnal.Offset)) } - rows, err := r.conn.NamedQuery(query, params) + // TODO: "ORDER BY" should be an option in dbutils.DBQueryOptionnal + getStatement = getStatement.OrderBy("created_at DESC") + + rows, err := getStatement.RunWith(r.conn.DB).Query() if err != nil { return nil, errors.New("couldn't retrieve any notification with these roles. The query is equal to: " + err.Error()) } defer rows.Close() - notifications := make([]FrontNotification, 0) + notifications := make([]Notification, 0) for rows.Next() { var id int64 var data string - var notif MockNotification - var isRead bool + var notifType string - err := rows.Scan(&id, &data, &isRead) + err = rows.Scan(&id, &data, &isRead, ¬ifType) if err != nil { return nil, errors.New("couldn't scan the notification data:" + err.Error()) } - // Retrieve data json data - err = json.Unmarshal([]byte(data), ¬if) - if err != nil { - return nil, errors.New("couldn't convert data content:" + err.Error()) + t, ok := H().notificationTypes[notifType] + if !ok { + return nil, errors.New("notification type does not exist") } - notif.ID = id + instance, err := t.NewInstance(id, []byte(data), isRead) + if err != nil { + return nil, errors.New("notification couldn't be instanced") + } - notifications = append(notifications, FrontNotification{ - Notification: notif, - IsRead: isRead, - }) + notifications = append(notifications, instance) } if err != nil { return nil, errors.New("deformed Data " + err.Error()) @@ -159,12 +153,12 @@ func (r *PostgresRepository) GetAll(queryOptionnal dbutils.DBQueryOptionnal) ([] } // Delete deletes a notification from the repository by its id -func (r *PostgresRepository) Delete(id int64) error { - query := `DELETE FROM notifications_history_v1 WHERE id = :id` +func (r *PostgresRepository) Delete(id int64, userLogin string) error { + deleteStatement := newStatement(). + Delete("notifications_history_v1"). + Where(sq.And{sq.Eq{"id": id}, sq.Eq{"user_login": userLogin}}) - res, err := r.conn.NamedExec(query, map[string]interface{}{ - "id": id, - }) + res, err := deleteStatement.RunWith(r.conn.DB).Exec() if err != nil { return err } @@ -178,14 +172,14 @@ func (r *PostgresRepository) Delete(id int64) error { return nil } -//UpdateRead updates a notification status by changing the isRead state to true once it has been read -func (r *PostgresRepository) UpdateRead(id int64, status bool) error { - query := `UPDATE notifications_history_v1 SET isread = :status WHERE id = :id` +// UpdateRead updates a notification status by changing the isRead state to true once it has been read +func (r *PostgresRepository) UpdateRead(id int64, status bool, userLogin string) error { + update := newStatement(). + Update("notifications_history_v1"). + Set("isread", status). + Where(sq.And{sq.Eq{"id": id}, sq.Eq{"user_login": userLogin}}) - res, err := r.conn.NamedExec(query, map[string]interface{}{ - "status": status, - "id": id, - }) + res, err := update.RunWith(r.conn.DB).Exec() if err != nil { return err } @@ -198,3 +192,20 @@ func (r *PostgresRepository) UpdateRead(id int64, status bool) error { } return nil } + +// CleanExpired deletes all notifications older than the given lifetime +func (r *PostgresRepository) CleanExpired(lifetime time.Duration) (int64, error) { + deleteStatement := newStatement(). + Delete("notifications_history_v1"). + Where(sq.Lt{"created_at": time.Now().UTC().Add(-1 * lifetime)}) + + res, err := deleteStatement.RunWith(r.conn.DB).Exec() + if err != nil { + return 0, err + } + i, err := res.RowsAffected() + if err != nil { + return 0, err + } + return i, nil +} diff --git a/internals/notifier/notification/repository.go b/internals/notifier/notification/repository.go index a789ffd2..89d95d44 100644 --- a/internals/notifier/notification/repository.go +++ b/internals/notifier/notification/repository.go @@ -1,7 +1,9 @@ package notification import ( + sq "github.com/Masterminds/squirrel" "sync" + "time" "github.com/myrteametrics/myrtea-engine-api/v5/internals/dbutils" ) @@ -10,11 +12,12 @@ import ( // (in-memory map, sql database, in-memory cache, file system, ...) // It allows standard CRUD operation on facts type Repository interface { - Create(notif Notification) (int64, error) - Get(id int64) *FrontNotification - GetAll(queryOptionnal dbutils.DBQueryOptionnal) ([]FrontNotification, error) - Delete(id int64) error - UpdateRead(id int64, state bool) error + Create(notif Notification, userLogin string) (int64, error) + Get(id int64, userLogin string) (Notification, error) + GetAll(queryOptionnal dbutils.DBQueryOptionnal, userLogin string) ([]Notification, error) + Delete(id int64, userLogin string) error + UpdateRead(id int64, state bool, userLogin string) error + CleanExpired(lifetime time.Duration) (int64, error) } var ( @@ -40,3 +43,7 @@ func ReplaceGlobals(repository Repository) func() { _globalRepository = repository return func() { ReplaceGlobals(prev) } } + +func newStatement() sq.StatementBuilderType { + return sq.StatementBuilder.PlaceholderFormat(sq.Dollar) +} diff --git a/internals/notifier/notification/utils.go b/internals/notifier/notification/utils.go new file mode 100644 index 00000000..e8cc152b --- /dev/null +++ b/internals/notifier/notification/utils.go @@ -0,0 +1,11 @@ +package notification + +import "reflect" + +func getType(myvar interface{}) string { + if t := reflect.TypeOf(myvar); t.Kind() == reflect.Ptr { + return "*" + t.Elem().Name() + } else { + return t.Name() + } +} diff --git a/internals/notifier/notifier.go b/internals/notifier/notifier.go index 2846f74c..eefd5d3c 100644 --- a/internals/notifier/notifier.go +++ b/internals/notifier/notifier.go @@ -1,6 +1,7 @@ package notifier import ( + "github.com/myrteametrics/myrtea-engine-api/v5/internals/security/users" "sync" "time" @@ -45,7 +46,7 @@ func NewNotifier() *Notifier { cm := NewClientManager() return &Notifier{ clientManager: cm, - cache: make(map[string]time.Time, 0), + cache: make(map[string]time.Time), } } @@ -61,6 +62,7 @@ func (notifier *Notifier) Unregister(client Client) error { return notifier.clientManager.Unregister(client) } +// verifyCache check if a notification has already been sent func (notifier *Notifier) verifyCache(key string, timeout time.Duration) bool { if val, ok := notifier.cache[key]; ok && time.Now().UTC().Before(val) { return false @@ -69,43 +71,52 @@ func (notifier *Notifier) verifyCache(key string, timeout time.Duration) bool { return true } -// SendToRoles send a notification to every user related to the input list of roles -func (notifier *Notifier) SendToRoles(cacheKey string, timeout time.Duration, notif notification.Notification, roles []uuid.UUID) { - - zap.L().Debug("notifier.SendToRoles", zap.Any("roles", roles), zap.Any("notification", notif)) - - if cacheKey != "" && !notifier.verifyCache(cacheKey, timeout) { - zap.L().Debug("Notification send skipped") - return - } - - id, err := notification.R().Create(notif) - if err != nil { - zap.L().Error("Add notification to history", zap.Error(err)) - return - } - - notifFull := notification.R().Get(id) - if notifFull == nil { - zap.L().Error("Notification not found after creation", zap.Int64("id", id)) +func (notifier *Notifier) CleanCache() { + for key, val := range notifier.cache { + if time.Now().UTC().After(val) { + delete(notifier.cache, key) + } } - - // FIXME: This should be fully reworking after security refactoring and removal of groups - - // if roles != nil && len(roles) > 0 { - // clients := make(map[Client]bool, 0) - // for _, roleID := range roles { - // roleClients := notifier.findClientsByRoleID(roleID) - // for _, client := range roleClients { - // clients[client] = true - // } - // } - // for client := range clients { - // notifier.sendToClient(notifFull, client) - // } - // } } +// TODO: renew this +//// SendToRoles send a notification to every user related to the input list of roles +//func (notifier *Notifier) SendToRoles(cacheKey string, timeout time.Duration, notif notification.Notification, roles []uuid.UUID) { +// +// zap.L().Debug("notifier.SendToRoles", zap.Any("roles", roles), zap.Any("notification", notif)) +// +// if cacheKey != "" && !notifier.verifyCache(cacheKey, timeout) { +// zap.L().Debug("Notification send skipped") +// return +// } +// +// id, err := notification.R().Create(notif, "") +// if err != nil { +// zap.L().Error("Add notification to history", zap.Error(err)) +// return +// } +// +// notifFull, err := notification.R().Get(id) +// if notifFull == nil { +// zap.L().Error("Notification not found after creation", zap.Int64("id", id)) +// } +// +// // FIXME: This should be fully reworking after security refactoring and removal of groups +// +// // if roles != nil && len(roles) > 0 { +// // clients := make(map[Client]bool, 0) +// // for _, roleID := range roles { +// // roleClients := notifier.findClientsByRoleID(roleID) +// // for _, client := range roleClients { +// // clients[client] = true +// // } +// // } +// // for client := range clients { +// // notifier.sendToClient(notifFull, client) +// // } +// // } +//} + // sendToClient convert and send a notification to a specific client // Every multiplexing function must call this function in the end to send message func (notifier *Notifier) sendToClient(notif notification.Notification, client Client) { @@ -126,34 +137,73 @@ func (notifier *Notifier) Broadcast(notif notification.Notification) { } // SendToUsers send a notification to users corresponding the input ids -func (notifier *Notifier) SendToUsers(notif notification.Notification, users []uuid.UUID) { +func (notifier *Notifier) SendToUsers(notif notification.Notification, users []users.UserWithPermissions) error { + if users != nil && len(users) > 0 { + for _, user := range users { + err := notifier.SendToUser(notif, user) + if err != nil { + return err + } + } + } + return nil +} + +// SendToUserLogins send a notification to user logins corresponding the input ids +func (notifier *Notifier) SendToUserLogins(notif notification.Notification, users []string) error { if users != nil && len(users) > 0 { - for _, userID := range users { - clients := notifier.findClientsByUserID(userID) - for _, client := range clients { - notifier.sendToClient(notif, client) + for _, user := range users { + err := notifier.SendToUserLogin(notif, user) + if err != nil { + return err } } } + return nil +} + +// SendToUser send a notification to a specific user +func (notifier *Notifier) SendToUser(notif notification.Notification, user users.UserWithPermissions) error { + return notifier.SendToUserLogin(notif, user.Login) +} + +// SendToUserLogin send a notification to a specific user using his login +func (notifier *Notifier) SendToUserLogin(notif notification.Notification, login string) error { + if notif.IsPersistent() { // Not all notifications needs notifications to be saved to database + id, err := notification.R().Create(notif, login) + if err != nil { + zap.L().Error("Add notification to history", zap.Error(err)) + return err + } + notif = notif.SetId(id) + } + + clients := notifier.findClientsByUserLogin(login) + for _, client := range clients { + notifier.sendToClient(notif, client) + } + return nil } -// Send send a byte slices to a specific websocket client +// Send a byte slices to a specific websocket client func (notifier *Notifier) Send(message []byte, client Client) { if client != nil { client.GetSendChannel() <- message } } -func (notifier *Notifier) findClientsByUserID(id uuid.UUID) []Client { +// findClientsByUserLogin returns a list of clients corresponding to the input login +func (notifier *Notifier) findClientsByUserLogin(login string) []Client { clients := make([]Client, 0) for _, client := range notifier.clientManager.GetClients() { - if client.GetUser() != nil && client.GetUser().ID == id { + if client.GetUser() != nil && client.GetUser().Login == login { clients = append(clients, client) } } return clients } +// findClientsByRoleID returns a list of clients corresponding to the input role id func (notifier *Notifier) findClientsByRoleID(id uuid.UUID) []Client { clients := make([]Client, 0) for _, client := range notifier.clientManager.GetClients() { diff --git a/internals/notifier/websocket_client_test.go b/internals/notifier/websocket_client_test.go index c417d7b2..f9c917bc 100644 --- a/internals/notifier/websocket_client_test.go +++ b/internals/notifier/websocket_client_test.go @@ -4,7 +4,9 @@ import ( "net/http" "net/http/httptest" "strings" + "sync" "testing" + "time" "github.com/gorilla/websocket" ) @@ -14,7 +16,10 @@ func TestNewWSClient(t *testing.T) { // Server-side initialisation var client *WebsocketClient + wg := sync.WaitGroup{} + wg.Add(1) s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer wg.Done() var err error client, err = BuildWebsocketClient(w, r, nil) if err != nil { @@ -30,6 +35,20 @@ func TestNewWSClient(t *testing.T) { } defer ws.Close() + c := make(chan struct{}) + + // wait for the client to be ready + go func() { + wg.Wait() + c <- struct{}{} + }() + + select { + case <-c: + case <-time.After(time.Second): + t.Fatalf("Timed out waiting for wait group\n") + } + // Tests if client == nil { t.Fatal("Client not built") @@ -41,7 +60,10 @@ func TestWSClientRead(t *testing.T) { // Server-side initialisation var client *WebsocketClient + wg := sync.WaitGroup{} + wg.Add(1) s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer wg.Done() var err error client, err = BuildWebsocketClient(w, r, nil) if err != nil { @@ -58,6 +80,20 @@ func TestWSClientRead(t *testing.T) { } defer ws.Close() + c := make(chan struct{}) + + // wait for the client to be ready + go func() { + wg.Wait() + c <- struct{}{} + }() + + select { + case <-c: + case <-time.After(time.Second): + t.Fatalf("Timed out waiting for wait group\n") + } + // Tests for i := 0; i < 10; i++ { if err := ws.WriteMessage(websocket.TextMessage, []byte("hello")); err != nil { @@ -79,8 +115,11 @@ func TestWSClientWrite(t *testing.T) { ReplaceGlobals(NewNotifier()) // Server-side initialisation + wg := sync.WaitGroup{} + wg.Add(1) var client *WebsocketClient s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer wg.Done() var err error client, err = BuildWebsocketClient(w, r, nil) if err != nil { @@ -97,6 +136,20 @@ func TestWSClientWrite(t *testing.T) { } defer ws.Close() + c := make(chan struct{}) + + // wait for the client to be ready + go func() { + wg.Wait() + c <- struct{}{} + }() + + select { + case <-c: + case <-time.After(time.Second): + t.Fatalf("Timed out waiting for wait group\n") + } + // Tests for i := 0; i < 10; i++ { // Send message directly on the client Send channel diff --git a/internals/router/oidc/oidc_middleware.go b/internals/router/oidc/oidc_middleware.go index d54b193b..3fbc5627 100644 --- a/internals/router/oidc/oidc_middleware.go +++ b/internals/router/oidc/oidc_middleware.go @@ -3,7 +3,6 @@ package oidcAuth import ( "context" "errors" - "fmt" "net/http" "strings" "time" @@ -129,7 +128,7 @@ func ContextMiddleware(next http.Handler) http.Handler { loggerR := r.Context().Value(models.ContextKeyLoggerR) if loggerR != nil { - gorillacontext.Set(loggerR.(*http.Request), models.UserLogin, fmt.Sprintf("%s(%d)", up.User.Login, up.User.ID)) + gorillacontext.Set(loggerR.(*http.Request), models.UserLogin, up.User.Login) } ctx := context.WithValue(r.Context(), models.ContextKeyUser, up) diff --git a/internals/router/router.go b/internals/router/router.go index fdf24430..e17e5913 100644 --- a/internals/router/router.go +++ b/internals/router/router.go @@ -33,7 +33,13 @@ type Config struct { VerboseError bool AuthenticationMode string LogLevel zap.AtomicLevel - PluginCore *plugin.Core +} + +// Services is a wrapper for services instances, it is passed through router functions +type Services struct { + PluginCore *plugin.Core + ProcessorHandler *handlers.ProcessorHandler + ExportHandler *handlers.ExportHandler } // Check clean up the configuration and logs comments if required @@ -68,7 +74,7 @@ func (config *Config) Check() { // New returns a new fully configured instance of chi.Mux // It instanciates all middlewares including the security ones, all routes and route groups -func New(config Config) *chi.Mux { +func New(config Config, services Services) *chi.Mux { config.Check() @@ -76,7 +82,7 @@ func New(config Config) *chi.Mux { // Global middleware stack // TODO: Add CORS middleware if config.CORS { - cors := cors.New(cors.Options{ + corsHandler := cors.New(cors.Options{ AllowedOrigins: []string{"*"}, AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, @@ -84,7 +90,7 @@ func New(config Config) *chi.Mux { AllowCredentials: true, MaxAge: 300, // Maximum value not ignored by any of major browsers }) - r.Use(cors.Handler) + r.Use(corsHandler.Handler) } r.Use(chimiddleware.SetHeader("Strict-Transport-Security", "max-age=63072000; includeSubDomains")) @@ -105,11 +111,11 @@ func New(config Config) *chi.Mux { switch config.AuthenticationMode { case "BASIC": - routes, err = buildRoutesV3Basic(config) + routes, err = buildRoutesV3Basic(config, services) case "SAML": - routes, err = buildRoutesV3SAML(config) + routes, err = buildRoutesV3SAML(config, services) case "OIDC": - routes, err = buildRoutesV3OIDC(config) + routes, err = buildRoutesV3OIDC(config, services) default: zap.L().Panic("Authentication mode not supported", zap.String("AuthenticationMode", config.AuthenticationMode)) return nil @@ -124,7 +130,7 @@ func New(config Config) *chi.Mux { return r } -func buildRoutesV3Basic(config Config) (func(r chi.Router), error) { +func buildRoutesV3Basic(config Config, services Services) (func(r chi.Router), error) { signingKey := []byte(security.RandString(128)) securityMiddleware := security.NewMiddlewareJWT(signingKey, security.NewDatabaseAuth(postgres.DB())) @@ -164,9 +170,9 @@ func buildRoutesV3Basic(config Config) (func(r chi.Router), error) { rg.Use(chimiddleware.SetHeader("Content-Type", "application/json")) rg.HandleFunc("/log_level", config.LogLevel.ServeHTTP) - rg.Mount("/engine", engineRouter()) + rg.Mount("/engine", engineRouter(services)) - for _, plugin := range config.PluginCore.Plugins { + for _, plugin := range services.PluginCore.Plugins { rg.Mount(plugin.Plugin.HandlerPrefix(), plugin.Plugin.Handler()) rg.HandleFunc(fmt.Sprintf("/plugin%s", plugin.Plugin.HandlerPrefix()), func(w http.ResponseWriter, r *http.Request) { render.JSON(w, r, map[string]interface{}{"loaded": true}) @@ -205,12 +211,12 @@ func buildRoutesV3Basic(config Config) (func(r chi.Router), error) { // } // rg.Use(chimiddleware.SetHeader("Content-Type", "application/json")) - rg.Mount("/service", serviceRouter()) + rg.Mount("/service", serviceRouter(services)) }) }, nil } -func buildRoutesV3SAML(config Config) (func(r chi.Router), error) { +func buildRoutesV3SAML(config Config, services Services) (func(r chi.Router), error) { samlConfig := SamlSPMiddlewareConfig{ MetadataMode: viper.GetString("AUTHENTICATION_SAML_METADATA_MODE"), @@ -257,9 +263,9 @@ func buildRoutesV3SAML(config Config) (func(r chi.Router), error) { rg.Use(chimiddleware.SetHeader("Content-Type", "application/json")) rg.HandleFunc("/log_level", config.LogLevel.ServeHTTP) - rg.Mount("/engine", engineRouter()) + rg.Mount("/engine", engineRouter(services)) - for _, plugin := range config.PluginCore.Plugins { + for _, plugin := range services.PluginCore.Plugins { rg.Mount(plugin.Plugin.HandlerPrefix(), plugin.Plugin.Handler()) rg.HandleFunc(fmt.Sprintf("/plugin%s", plugin.Plugin.HandlerPrefix()), func(w http.ResponseWriter, r *http.Request) { render.JSON(w, r, map[string]interface{}{"loaded": true}) @@ -291,20 +297,12 @@ func buildRoutesV3SAML(config Config) (func(r chi.Router), error) { // } rg.Use(chimiddleware.SetHeader("Content-Type", "application/json")) - rg.Mount("/service", serviceRouter()) + rg.Mount("/service", serviceRouter(services)) }) }, nil } -// ReverseProxy act as a reverse proxy for any plugin http handlers -func ReverseProxy(plugin plugin.MyrteaPlugin) http.HandlerFunc { - url, _ := url.Parse(fmt.Sprintf("http://localhost:%d", plugin.ServicePort())) - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - httputil.NewSingleHostReverseProxy(url).ServeHTTP(w, r) - }) -} - -func buildRoutesV3OIDC(config Config) (func(r chi.Router), error) { +func buildRoutesV3OIDC(config Config, services Services) (func(r chi.Router), error) { return func(r chi.Router) { // Public routes @@ -331,9 +329,9 @@ func buildRoutesV3OIDC(config Config) (func(r chi.Router), error) { rg.Use(chimiddleware.SetHeader("Content-Type", "application/json")) rg.HandleFunc("/log_level", config.LogLevel.ServeHTTP) - rg.Mount("/engine", engineRouter()) + rg.Mount("/engine", engineRouter(services)) - for _, plugin := range config.PluginCore.Plugins { + for _, plugin := range services.PluginCore.Plugins { rg.Mount(plugin.Plugin.HandlerPrefix(), plugin.Plugin.Handler()) rg.HandleFunc(fmt.Sprintf("/plugin%s", plugin.Plugin.HandlerPrefix()), func(w http.ResponseWriter, r *http.Request) { render.JSON(w, r, map[string]interface{}{"loaded": true}) @@ -357,7 +355,15 @@ func buildRoutesV3OIDC(config Config) (func(r chi.Router), error) { r.Group(func(rg chi.Router) { rg.Use(chimiddleware.SetHeader("Content-Type", "application/json")) - rg.Mount("/service", serviceRouter()) + rg.Mount("/service", serviceRouter(services)) }) }, nil } + +// ReverseProxy act as a reverse proxy for any plugin http handlers +func ReverseProxy(plugin plugin.MyrteaPlugin) http.HandlerFunc { + pluginUrl, _ := url.Parse(fmt.Sprintf("http://localhost:%d", plugin.ServicePort())) + return func(w http.ResponseWriter, r *http.Request) { + httputil.NewSingleHostReverseProxy(pluginUrl).ServeHTTP(w, r) + } +} diff --git a/internals/router/routes.go b/internals/router/routes.go index 81e09981..8d0a7055 100644 --- a/internals/router/routes.go +++ b/internals/router/routes.go @@ -42,7 +42,7 @@ func adminRouter() http.Handler { return r } -func engineRouter() http.Handler { +func engineRouter(services Services) http.Handler { r := chi.NewRouter() r.Get("/security/myself", handlers.GetUserSelf) @@ -71,6 +71,7 @@ func engineRouter() http.Handler { r.Post("/facts/execute", handlers.ExecuteFactFromSource) // ?time=2019-05-10T12:00:00.000 debug= r.Get("/facts/{id}/hits", handlers.GetFactHits) // ?time=2019-05-10T12:00:00.000 debug= r.Get("/facts/{id}/es", handlers.FactToESQuery) + r.Post("/facts/streamedexport", handlers.ExportFactStreamed) r.Get("/situations", handlers.GetSituations) r.Get("/situations/{id}", handlers.GetSituation) @@ -174,7 +175,12 @@ func engineRouter() http.Handler { r.Get("/connector/{id}/executions/last", handlers.GetlastConnectorExecutionDateTime) - r.Get("/facts/{id}/export", handlers.ExportFact) + // exports + r.Get("/exports", services.ExportHandler.GetExports) + r.Get("/exports/{id}", services.ExportHandler.GetExport) + r.Get("/exports/{id}/download", services.ExportHandler.DownloadExport) + r.Delete("/exports/{id}", services.ExportHandler.DeleteExport) + r.Post("/exports/fact", services.ExportHandler.ExportFact) r.Get("/variablesconfig", handlers.GetVariablesConfig) r.Get("/variablesconfig/{id}", handlers.GetVariableConfig) @@ -186,12 +192,11 @@ func engineRouter() http.Handler { return r } -func serviceRouter() http.Handler { +func serviceRouter(services Services) http.Handler { r := chi.NewRouter() - processorHandler := handlers.NewProcessorHandler() r.Post("/objects", handlers.PostObjects) - r.Post("/aggregates", processorHandler.PostAggregates) + r.Post("/aggregates", services.ProcessorHandler.PostAggregates) r.Get("/externalconfigs", handlers.GetExternalConfigs) r.Get("/externalconfigs/{id}", handlers.GetExternalConfig) diff --git a/internals/router/saml_middleware.go b/internals/router/saml_middleware.go index 5b02a707..ec94d776 100644 --- a/internals/router/saml_middleware.go +++ b/internals/router/saml_middleware.go @@ -6,7 +6,6 @@ import ( "crypto/tls" "crypto/x509" "errors" - "fmt" "net/http" "net/url" @@ -217,7 +216,7 @@ func (m *SamlSPMiddleware) ContextMiddleware(next http.Handler) http.Handler { loggerR := r.Context().Value(models.ContextKeyLoggerR) if loggerR != nil { - gorillacontext.Set(loggerR.(*http.Request), models.UserLogin, fmt.Sprintf("%s(%d)", up.User.Login, up.User.ID)) + gorillacontext.Set(loggerR.(*http.Request), models.UserLogin, up.User.Login) } ctx := context.WithValue(r.Context(), models.ContextKeyUser, up) diff --git a/internals/security/permissions/permission.go b/internals/security/permissions/permission.go index c2b1b85f..06574794 100644 --- a/internals/security/permissions/permission.go +++ b/internals/security/permissions/permission.go @@ -26,6 +26,7 @@ const ( TypeCalendar = "calendar" TypeModel = "model" TypeFrontend = "frontend" + TypeExport = "export" ) type Permission struct { @@ -35,6 +36,7 @@ type Permission struct { Action string `json:"action"` } +// New returns a new Permission func New(resourceType string, resourceID string, action string) Permission { return Permission{ ResourceType: resourceType, @@ -43,6 +45,7 @@ func New(resourceType string, resourceID string, action string) Permission { } } +// ListMatchingPermissions returns a list of permissions matching the given permission func ListMatchingPermissions(permissions []Permission, match Permission) []Permission { lst := make([]Permission, 0) for _, permission := range permissions { @@ -60,7 +63,8 @@ func ListMatchingPermissions(permissions []Permission, match Permission) []Permi return lst } -func GetRessourceIDs(permissions []Permission) []string { +// GetResourceIDs returns a list of resource IDs from a list of permissions +func GetResourceIDs(permissions []Permission) []string { resourceIDs := make([]string, 0) for _, permission := range permissions { resourceIDs = append(resourceIDs, permission.ResourceID) @@ -68,6 +72,7 @@ func GetRessourceIDs(permissions []Permission) []string { return resourceIDs } +// HasPermission checks if the user has the required permission func matchPermission(permission string, required string) bool { if permission == All { return true @@ -81,6 +86,7 @@ func matchPermission(permission string, required string) bool { return false } +// HasPermission checks strictly if the user has the required permission func matchPermissionStrict(permission string, required string) bool { if permission == All { return true @@ -91,6 +97,7 @@ func matchPermissionStrict(permission string, required string) bool { return false } +// HasPermission checks if the user has the required permission func HasPermission(permissions []Permission, required Permission) bool { for _, permission := range permissions { if !matchPermissionStrict(permission.ResourceType, required.ResourceType) { @@ -107,6 +114,7 @@ func HasPermission(permissions []Permission, required Permission) bool { return false } +// HasPermissionAtLeastOne checks if the user has at least one of the required permissions func HasPermissionAtLeastOne(permissions []Permission, requiredAtLeastOne []Permission) bool { for _, required := range requiredAtLeastOne { if HasPermission(permissions, required) { @@ -116,6 +124,7 @@ func HasPermissionAtLeastOne(permissions []Permission, requiredAtLeastOne []Perm return false } +// HasPermissionAll checks if the user has all the required permissions func HasPermissionAll(permissions []Permission, requiredAll []Permission) bool { for _, required := range requiredAll { if !HasPermission(permissions, required) { diff --git a/internals/security/permissions/permission_test.go b/internals/security/permissions/permission_test.go index 0f51998a..62bb048a 100644 --- a/internals/security/permissions/permission_test.go +++ b/internals/security/permissions/permission_test.go @@ -117,7 +117,7 @@ func TestGetResourceIDs(t *testing.T) { New("fact", "5", "*"), } - resourceIDs := GetRessourceIDs(ListMatchingPermissions(permissions, New("situation", "*", "create"))) + resourceIDs := GetResourceIDs(ListMatchingPermissions(permissions, New("situation", "*", "create"))) if len(resourceIDs) != 3 { t.Error("invalid resourceIDs") } diff --git a/internals/security/permissions/postgres_repository.go b/internals/security/permissions/postgres_repository.go index bb49b965..58cc42fa 100644 --- a/internals/security/permissions/postgres_repository.go +++ b/internals/security/permissions/postgres_repository.go @@ -5,7 +5,7 @@ import ( "errors" sq "github.com/Masterminds/squirrel" - uuid "github.com/google/uuid" + "github.com/google/uuid" "github.com/jmoiron/sqlx" "go.uber.org/zap" ) @@ -29,7 +29,7 @@ func NewPostgresRepository(dbClient *sqlx.DB) Repository { return ifm } -//Get search and returns an User Permission from the repository by its id +// Get search and returns an User Permission from the repository by its id func (r *PostgresRepository) Get(permissionUUID uuid.UUID) (Permission, bool, error) { rows, err := r.newStatement(). Select(fields...). diff --git a/internals/security/users/user.go b/internals/security/users/user.go index 2197da7d..542dca8b 100644 --- a/internals/security/users/user.go +++ b/internals/security/users/user.go @@ -13,7 +13,7 @@ import ( // User is used as the main user struct type User struct { ID uuid.UUID `json:"id"` - Login string `json:"login"` + Login string `json:"login"` // is the unique identifier of the user, through the different connection modes Created time.Time `json:"created"` LastName string `json:"lastName"` FirstName string `json:"firstName"` @@ -102,12 +102,12 @@ func (u UserWithPermissions) ListMatchingPermissions(match permissions.Permissio } func (u UserWithPermissions) GetMatchingResourceIDs(match permissions.Permission) []string { - return permissions.GetRessourceIDs(permissions.ListMatchingPermissions(u.Permissions, match)) + return permissions.GetResourceIDs(permissions.ListMatchingPermissions(u.Permissions, match)) } func (u UserWithPermissions) GetMatchingResourceIDsInt64(match permissions.Permission) []int64 { ids := make([]int64, 0) - for _, resourceID := range permissions.GetRessourceIDs(permissions.ListMatchingPermissions(u.Permissions, match)) { + for _, resourceID := range permissions.GetResourceIDs(permissions.ListMatchingPermissions(u.Permissions, match)) { if resourceID == permissions.All { continue } diff --git a/internals/tasker/situation_reporting.go b/internals/tasker/situation_reporting.go index 113291a3..e0f7ebb0 100644 --- a/internals/tasker/situation_reporting.go +++ b/internals/tasker/situation_reporting.go @@ -30,18 +30,16 @@ func verifyCache(key string, timeout time.Duration) bool { // SituationReportingTask struct for close issues created in the current day from the BRMS type SituationReportingTask struct { - ID string `json:"id"` - IssueID string `json:"issueId"` - Subject string `json:"subject"` - BodyTemplate string `json:"bodyTemplate"` - To []string `json:"to"` - AttachmentFileNames []string `json:"attachmentFileNames"` - AttachmentFactIDs []int64 `json:"attachmentFactIds"` - Columns []string `json:"columns"` - FormatColumnsData map[string]string `json:"formateColumns"` - ColumnsLabel []string `json:"columnsLabel"` - Separator rune `json:"separator"` - Timeout string `json:"timeout"` + ID string `json:"id"` + IssueID string `json:"issueId"` + Subject string `json:"subject"` + BodyTemplate string `json:"bodyTemplate"` + To []string `json:"to"` + AttachmentFileNames []string `json:"attachmentFileNames"` + AttachmentFactIDs []int64 `json:"attachmentFactIds"` + Columns []export.Column `json:"columns"` + Separator rune `json:"separator"` + Timeout string `json:"timeout"` } func buildSituationReportingTask(parameters map[string]interface{}) (SituationReportingTask, error) { @@ -100,25 +98,43 @@ func buildSituationReportingTask(parameters map[string]interface{}) (SituationRe } if val, ok := parameters["columns"].(string); ok && val != "" { - task.Columns = strings.Split(val, ",") - } + columns := strings.Split(val, ",") + var columnsLabel []string + + if val, ok = parameters["columnsLabel"].(string); ok && val != "" { + columnsLabel = strings.Split(val, ",") + } - if val, ok := parameters["formateColumns"].(string); ok && val != "" { - formatColumnsData := strings.Split(val, ",") - task.FormatColumnsData = make(map[string]string) - for _, formatData := range formatColumnsData { - parts := strings.Split(formatData, ";") - if len(parts) != 2 { - continue + if len(columns) != len(columnsLabel) { + return task, errors.New("parameters 'columns' and 'columns label' have different length") + } + + formatColumnsDataMap := make(map[string]string) + + if val, ok = parameters["formateColumns"].(string); ok && val != "" { + formatColumnsData := strings.Split(val, ",") + for _, formatData := range formatColumnsData { + parts := strings.Split(formatData, ";") + if len(parts) != 2 { + continue + } + key := strings.TrimSpace(parts[0]) + formatColumnsDataMap[key] = parts[1] } - key := strings.TrimSpace(parts[0]) - task.FormatColumnsData[key] = parts[1] } - } + for i, column := range columns { + exportColumn := export.Column{ + Name: column, + Label: columnsLabel[i], + } + + if format, ok := formatColumnsDataMap[column]; ok { + exportColumn.Format = format + } - if val, ok := parameters["columnsLabel"].(string); ok && val != "" { - task.ColumnsLabel = strings.Split(val, ",") + task.Columns = append(task.Columns, exportColumn) + } } if val, ok := parameters["separator"].(string); ok && val != "" { @@ -127,10 +143,6 @@ func buildSituationReportingTask(parameters map[string]interface{}) (SituationRe task.Separator = ',' } - if len(task.Columns) != len(task.ColumnsLabel) { - return task, errors.New("parameters 'columns' and 'columns label' have different length") - } - if val, ok := parameters["timeout"].(string); ok && val != "" { task.Timeout = val } else { @@ -202,7 +214,7 @@ func (task SituationReportingTask) Perform(key string, context ContextData) erro return err } - csvAttachment, err := export.ConvertHitsToCSV(fullHits, task.Columns, task.ColumnsLabel, task.FormatColumnsData, task.Separator) + csvAttachment, err := export.ConvertHitsToCSV(fullHits, export.CSVParameters{Columns: task.Columns, Separator: string(task.Separator)}, true) if err != nil { return err } diff --git a/internals/utils/utils.go b/internals/utils/utils.go index dae701c3..4cbfd750 100644 --- a/internals/utils/utils.go +++ b/internals/utils/utils.go @@ -1,12 +1,13 @@ package utils -func RemoveDuplicates(stringSlice []string) []string { - keys := make(map[string]bool) - list := []string{} - for _, entry := range stringSlice { - if _, value := keys[entry]; !value { - keys[entry] = true - list = append(list, entry) +// RemoveDuplicates remove duplicate values from a slice +func RemoveDuplicates[T string | int | int64](sliceList []T) []T { + allKeys := make(map[T]bool) + var list []T + for _, item := range sliceList { + if _, value := allKeys[item]; !value { + allKeys[item] = true + list = append(list, item) } } return list diff --git a/internals/utils/utils_test.go b/internals/utils/utils_test.go new file mode 100644 index 00000000..ab230f88 --- /dev/null +++ b/internals/utils/utils_test.go @@ -0,0 +1,51 @@ +package utils + +import "testing" + +func TestRemoveDuplicates_Int64(t *testing.T) { + sample := []int64{1, 1, 1, 2, 2, 3, 4} + expectedResult := []int64{1, 2, 3, 4} + result := RemoveDuplicates(sample) + + if len(result) != len(expectedResult) { + t.FailNow() + } + + for i := 0; i < len(expectedResult); i++ { + if expectedResult[i] != result[i] { + t.FailNow() + } + } +} + +func TestRemoveDuplicates_Int(t *testing.T) { + sample := []int{1, 1, 1, 2, 2, 3, 4} + expectedResult := []int{1, 2, 3, 4} + result := RemoveDuplicates(sample) + + if len(result) != len(expectedResult) { + t.FailNow() + } + + for i := 0; i < len(expectedResult); i++ { + if expectedResult[i] != result[i] { + t.FailNow() + } + } +} + +func TestRemoveDuplicates_String(t *testing.T) { + sample := []string{"a", "a", "a", "b", "b", "c", "d"} + expectedResult := []string{"a", "b", "c", "d"} + result := RemoveDuplicates(sample) + + if len(result) != len(expectedResult) { + t.FailNow() + } + + for i := 0; i < len(expectedResult); i++ { + if expectedResult[i] != result[i] { + t.FailNow() + } + } +} diff --git a/internals/variablesconfig/postgres_repository.go b/internals/variablesconfig/postgres_repository.go index ec68e6f5..4e849c8f 100644 --- a/internals/variablesconfig/postgres_repository.go +++ b/internals/variablesconfig/postgres_repository.go @@ -147,8 +147,7 @@ func (r *PostgresRepository) Delete(id int64) error { // GetAll method used to get all Variables Config func (r *PostgresRepository) GetAll() ([]models.VariablesConfig, error) { - - var variablesConfig []models.VariablesConfig + variablesConfig := make([]models.VariablesConfig, 0) rows, err := r.newStatement(). Select("id", "key", "value"). diff --git a/main.go b/main.go index 12fcf053..240cf2c3 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,9 @@ package main import ( "context" + "errors" + "github.com/myrteametrics/myrtea-engine-api/v5/internals/export" + "github.com/myrteametrics/myrtea-engine-api/v5/internals/handlers" "github.com/myrteametrics/myrtea-engine-api/v5/internals/metrics" "net/http" "os" @@ -39,7 +42,6 @@ var ( // @name Authorization func main() { - hostname, _ := os.Hostname() metrics.InitMetricLabels(hostname) @@ -68,10 +70,27 @@ func main() { GatewayMode: viper.GetBool("HTTP_SERVER_API_ENABLE_GATEWAY_MODE"), AuthenticationMode: viper.GetString("AUTHENTICATION_MODE"), LogLevel: zapConfig.Level, - PluginCore: core, } - router := router.New(routerConfig) + // Exports + directDownload := viper.GetBool("EXPORT_DIRECT_DOWNLOAD") + indirectDownloadUrl := viper.GetString("EXPORT_INDIRECT_DOWNLOAD_URL") + + exportWrapper := export.NewWrapper( + viper.GetString("EXPORT_BASE_PATH"), // basePath + viper.GetInt("EXPORT_WORKERS_COUNT"), // workersCount + viper.GetInt("EXPORT_DISK_RETENTION_DAYS"), // diskRetentionDays + viper.GetInt("EXPORT_QUEUE_MAX_SIZE"), // queueMaxSize + ) + exportWrapper.Init(context.Background()) + + routerServices := router.Services{ + PluginCore: core, + ProcessorHandler: handlers.NewProcessorHandler(), + ExportHandler: handlers.NewExportHandler(exportWrapper, directDownload, indirectDownloadUrl), + } + + router := router.New(routerConfig, routerServices) var srv *http.Server if serverEnableTLS { srv = server.NewSecuredServer(serverPort, serverTLSCert, serverTLSKey, router) @@ -89,7 +108,7 @@ func main() { } else { err = srv.ListenAndServe() } - if err != nil && err != http.ErrServerClosed { + if err != nil && !errors.Is(err, http.ErrServerClosed) { zap.L().Fatal("Server listen", zap.Error(err)) } }() diff --git a/plugins/standalone/plugin.go b/plugins/standalone/plugin.go index b2cb4bb3..c0c91b45 100644 --- a/plugins/standalone/plugin.go +++ b/plugins/standalone/plugin.go @@ -7,6 +7,7 @@ import ( "os" "os/exec" "runtime" + "time" "github.com/go-chi/chi/v5" "github.com/hashicorp/go-plugin" @@ -60,7 +61,9 @@ func NewPlugin(config pluginutils.PluginConfig) *Plugin { HandshakeConfig: Handshake, Plugins: pluginMap, Cmd: cmd, + StartTimeout: 2 * time.Minute, AllowedProtocols: []plugin.Protocol{plugin.ProtocolNetRPC}, + SkipHostEnv: false, }, } }