Skip to content

Commit

Permalink
MG-2287 - Improve search for Things (#2305)
Browse files Browse the repository at this point in the history
Signed-off-by: nyagamunene <[email protected]>
  • Loading branch information
nyagamunene authored Jul 4, 2024
1 parent e86eda5 commit 1302441
Show file tree
Hide file tree
Showing 21 changed files with 67 additions and 280 deletions.
4 changes: 2 additions & 2 deletions pkg/clients/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ type Repository interface {
// RetrieveAll retrieves all clients.
RetrieveAll(ctx context.Context, pm Page) (ClientsPage, error)

// SearchBasicInfo list all clients only with basic information.
SearchBasicInfo(ctx context.Context, pm Page) (ClientsPage, error)
// SearchClients retrieves clients based on search criteria.
SearchClients(ctx context.Context, pm Page) (ClientsPage, error)

// RetrieveAllByIDs retrieves for given client IDs .
RetrieveAllByIDs(ctx context.Context, pm Page) (ClientsPage, error)
Expand Down
86 changes: 31 additions & 55 deletions pkg/clients/postgres/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ func (repo *Repository) RetrieveAll(ctx context.Context, pm clients.Page) (clien
if err != nil {
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
query = applyOrdering(query, pm)

q := fmt.Sprintf(`SELECT c.id, c.name, c.tags, c.identity, c.metadata, COALESCE(c.domain_id, '') AS domain_id, c.status,
c.created_at, c.updated_at, COALESCE(c.updated_by, '') AS updated_by FROM clients c %s ORDER BY c.created_at LIMIT :limit OFFSET :offset;`, query)
Expand Down Expand Up @@ -190,10 +191,16 @@ func (repo *Repository) RetrieveAll(ctx context.Context, pm clients.Page) (clien
return page, nil
}

func (repo *Repository) SearchBasicInfo(ctx context.Context, pm clients.Page) (clients.ClientsPage, error) {
sq, tq := ConstructSearchQuery(pm)
func (repo *Repository) SearchClients(ctx context.Context, pm clients.Page) (clients.ClientsPage, error) {
query, err := PageQuery(pm)
if err != nil {
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}

q := fmt.Sprintf(`SELECT c.id, c.name, c.created_at, c.updated_at FROM clients c %s LIMIT :limit OFFSET :offset;`, sq)
tq := query
query = applyOrdering(query, pm)

q := fmt.Sprintf(`SELECT c.id, c.name, c.created_at, c.updated_at FROM clients c %s LIMIT :limit OFFSET :offset;`, query)

dbPage, err := ToDBClientsPage(pm)
if err != nil {
Expand Down Expand Up @@ -249,6 +256,7 @@ func (repo *Repository) RetrieveAllByIDs(ctx context.Context, pm clients.Page) (
if err != nil {
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
query = applyOrdering(query, pm)

q := fmt.Sprintf(`SELECT c.id, c.name, c.tags, c.identity, c.metadata, COALESCE(c.domain_id, '') AS domain_id, c.status,
c.created_at, c.updated_at, COALESCE(c.updated_by, '') AS updated_by FROM clients c %s ORDER BY c.created_at LIMIT :limit OFFSET :offset;`, query)
Expand Down Expand Up @@ -334,24 +342,6 @@ func (repo *Repository) Delete(ctx context.Context, id string) error {
return nil
}

func (repo *Repository) CheckSuperAdmin(ctx context.Context, adminID string) error {
q := "SELECT 1 FROM clients WHERE id = $1 AND role = $2"
rows, err := repo.DB.QueryContext(ctx, q, adminID, clients.AdminRole)
if err != nil {
return postgres.HandleError(repoerr.ErrViewEntity, err)
}
defer rows.Close()

if rows.Next() {
if err := rows.Err(); err != nil {
return postgres.HandleError(repoerr.ErrViewEntity, err)
}
return nil
}

return repoerr.ErrNotFound
}

type DBClient struct {
ID string `db:"id"`
Name string `db:"name,omitempty"`
Expand Down Expand Up @@ -487,14 +477,8 @@ func PageQuery(pm clients.Page) (string, error) {
if err != nil {
return "", errors.Wrap(errors.ErrMalformedEntity, err)
}

var query []string
var emq string
if mq != "" {
query = append(query, mq)
}
if len(pm.IDs) != 0 {
query = append(query, fmt.Sprintf("id IN ('%s')", strings.Join(pm.IDs, "','")))
}
if pm.Name != "" {
query = append(query, "name ILIKE '%' || :name || '%'")
}
Expand All @@ -505,52 +489,44 @@ func PageQuery(pm clients.Page) (string, error) {
query = append(query, "id ILIKE '%' || :id || '%'")
}
if pm.Tag != "" {
query = append(query, ":tag = ANY(c.tags)")
}
if pm.Status != clients.AllStatus {
query = append(query, "c.status = :status")
}
if pm.Domain != "" {
query = append(query, "c.domain_id = :domain_id")
query = append(query, "EXISTS (SELECT 1 FROM unnest(tags) AS tag WHERE tag ILIKE '%' || :tag || '%')")
}

if pm.Role != clients.AllRole {
query = append(query, "c.role = :role")
}
if len(query) > 0 {
emq = fmt.Sprintf("WHERE %s", strings.Join(query, " AND "))
// If there are search params presents, use search and ignore other options.
// Always combine role with search params, so len(query) > 1.
if len(query) > 1 {
return fmt.Sprintf("WHERE %s", strings.Join(query, " AND ")), nil
}
return emq, nil
}

func ConstructSearchQuery(pm clients.Page) (string, string) {
var query []string
var emq string
var tq string
if mq != "" {
query = append(query, mq)
}

if pm.Name != "" {
query = append(query, "name ILIKE '%' || :name || '%'")
if len(pm.IDs) != 0 {
query = append(query, fmt.Sprintf("id IN ('%s')", strings.Join(pm.IDs, "','")))
}
if pm.Identity != "" {
query = append(query, "identity ILIKE '%' || :identity || '%'")
if pm.Status != clients.AllStatus {
query = append(query, "c.status = :status")
}
if pm.Id != "" {
query = append(query, "id ILIKE '%' || :id || '%'")
if pm.Domain != "" {
query = append(query, "c.domain_id = :domain_id")
}

var emq string
if len(query) > 0 {
emq = fmt.Sprintf("WHERE %s", strings.Join(query, " AND "))
}
return emq, nil
}

tq = emq

func applyOrdering(emq string, pm clients.Page) string {
switch pm.Order {
case "name", "identity", "created_at", "updated_at":
emq = fmt.Sprintf("%s ORDER BY %s", emq, pm.Order)
if pm.Dir == api.AscDir || pm.Dir == api.DescDir {
emq = fmt.Sprintf("%s %s", emq, pm.Dir)
}
}

return emq, tq
return emq
}
4 changes: 2 additions & 2 deletions pkg/clients/postgres/clients_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ func TestRetrieveByIDs(t *testing.T) {
}
}

func TestSearchBasicInfo(t *testing.T) {
func TestSearchClients(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM clients")
require.Nil(t, err, fmt.Sprintf("clean clients unexpected error: %s", err))
Expand Down Expand Up @@ -1289,7 +1289,7 @@ func TestSearchBasicInfo(t *testing.T) {
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
switch response, err := repo.SearchBasicInfo(context.Background(), c.page); {
switch response, err := repo.SearchClients(context.Background(), c.page); {
case err == nil:
if c.page.Order != "" && c.page.Dir != "" {
c.response = response
Expand Down
6 changes: 3 additions & 3 deletions pkg/sdk/go/things_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ func TestListThings(t *testing.T) {
repoCall1 = auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: false}, svcerr.ErrAuthorization)
repoCall2 = auth.On("ListAllObjects", mock.Anything, mock.Anything).Return(&magistrala.ListObjectsRes{}, svcerr.ErrAuthorization)
}
repoCall3 := cRepo.On("RetrieveAllByIDs", mock.Anything, mock.Anything).Return(mgclients.ClientsPage{Page: convertClientPage(pm), Clients: convertThings(tc.response...)}, tc.err)
repoCall3 := cRepo.On("SearchClients", mock.Anything, mock.Anything).Return(mgclients.ClientsPage{Page: convertClientPage(pm), Clients: convertThings(tc.response...)}, tc.err)
page, err := mgsdk.Things(pm, validToken)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err))
assert.Equal(t, tc.response, page.Things, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page))
Expand Down Expand Up @@ -1042,7 +1042,7 @@ func TestEnableThing(t *testing.T) {
repoCall1 = auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: false}, svcerr.ErrAuthorization)
}
repoCall2 := auth.On("ListAllObjects", mock.Anything, mock.Anything).Return(&magistrala.ListObjectsRes{Policies: toIDs(tc.response.Things)}, nil)
repoCall3 := cRepo.On("RetrieveAllByIDs", mock.Anything, mock.Anything).Return(convertThingsPage(tc.response), nil)
repoCall3 := cRepo.On("SearchClients", mock.Anything, mock.Anything).Return(convertThingsPage(tc.response), nil)
clientsPage, err := mgsdk.Things(pm, validToken)
assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
size := uint64(len(clientsPage.Things))
Expand Down Expand Up @@ -1179,7 +1179,7 @@ func TestDisableThing(t *testing.T) {
repoCall1 = auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: false}, svcerr.ErrAuthorization)
}
repoCall2 := auth.On("ListAllObjects", mock.Anything, mock.Anything).Return(&magistrala.ListObjectsRes{Policies: toIDs(tc.response.Things)}, nil)
repoCall3 := cRepo.On("RetrieveAllByIDs", mock.Anything, mock.Anything).Return(convertThingsPage(tc.response), nil)
repoCall3 := cRepo.On("SearchClients", mock.Anything, mock.Anything).Return(convertThingsPage(tc.response), nil)
page, err := mgsdk.Things(pm, validToken)
assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
size := uint64(len(page.Things))
Expand Down
5 changes: 5 additions & 0 deletions things/api/http/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ func decodeListClients(_ context.Context, r *http.Request) (interface{}, error)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
id, err := apiutil.ReadStringQuery(r, api.IDOrder, "")
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
p, err := apiutil.ReadStringQuery(r, api.PermissionKey, api.DefPermission)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
Expand All @@ -205,6 +209,7 @@ func decodeListClients(_ context.Context, r *http.Request) (interface{}, error)
permission: p,
listPerms: lp,
userID: chi.URLParam(r, "userID"),
id: id,
}
return req, nil
}
Expand Down
1 change: 1 addition & 0 deletions things/api/http/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ func listClientsEndpoint(svc things.Service) endpoint.Endpoint {
Metadata: req.metadata,
ListPerms: req.listPerms,
Role: mgclients.AllRole, // retrieve all things since things don't have roles
Id: req.id,
}
page, err := svc.ListClients(ctx, req.token, req.userID, pm)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions things/api/http/requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ type listClientsReq struct {
userID string
listPerms bool
metadata mgclients.Metadata
id string
}

func (req listClientsReq) validate() error {
Expand Down
1 change: 1 addition & 0 deletions things/api/http/requests_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
const (
valid = "valid"
invalid = "invalid"
name = "client"
)

var validID = testsutil.GenerateUUID(&testing.T{})
Expand Down
6 changes: 3 additions & 3 deletions things/mocks/repository.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion things/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func (svc service) ListClients(ctx context.Context, token, reqUserID string, pm

pm.IDs = ids

tp, err := svc.clients.RetrieveAllByIDs(ctx, pm)
tp, err := svc.clients.SearchClients(ctx, pm)
if err != nil {
return mgclients.ClientsPage{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
Expand Down
18 changes: 11 additions & 7 deletions things/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
var (
secret = "strongsecret"
validCMetadata = mgclients.Metadata{"role": "client"}
ID = testsutil.GenerateUUID(&testing.T{})
ID = "6e5e10b3-d4df-4758-b426-4929d55ad740"
client = mgclients.Client{
ID: ID,
Name: "clientname",
Expand Down Expand Up @@ -428,6 +428,7 @@ func TestListClients(t *testing.T) {
identifyResponse *magistrala.IdentityRes
authorizeResponse *magistrala.AuthorizeRes
authorizeResponse1 *magistrala.AuthorizeRes
authorizeResponse2 *magistrala.AuthorizeRes
listObjectsResponse *magistrala.ListObjectsRes
listObjectsResponse1 *magistrala.ListObjectsRes
retrieveAllResponse mgclients.ClientsPage
Expand All @@ -438,6 +439,7 @@ func TestListClients(t *testing.T) {
identifyErr error
authorizeErr error
authorizeErr1 error
authorizeErr2 error
listObjectsErr error
retrieveAllErr error
listPermissionsErr error
Expand All @@ -455,6 +457,7 @@ func TestListClients(t *testing.T) {
},
identifyResponse: &magistrala.IdentityRes{Id: nonAdminID, UserId: nonAdminID, DomainId: domainID},
authorizeResponse: &magistrala.AuthorizeRes{Authorized: true},
authorizeResponse2: &magistrala.AuthorizeRes{Authorized: true},
listObjectsResponse: &magistrala.ListObjectsRes{},
retrieveAllResponse: mgclients.ClientsPage{
Page: mgclients.Page{
Expand Down Expand Up @@ -531,8 +534,9 @@ func TestListClients(t *testing.T) {
Limit: 100,
ListPerms: true,
},
identifyResponse: &magistrala.IdentityRes{Id: nonAdminID, UserId: nonAdminID, DomainId: domainID},
authorizeResponse: &magistrala.AuthorizeRes{Authorized: true},
identifyResponse: &magistrala.IdentityRes{Id: nonAdminID, UserId: nonAdminID, DomainId: domainID},
authorizeResponse: &magistrala.AuthorizeRes{Authorized: true},
authorizeResponse2: &magistrala.AuthorizeRes{Authorized: true},
retrieveAllResponse: mgclients.ClientsPage{
Page: mgclients.Page{
Total: 2,
Expand Down Expand Up @@ -619,7 +623,7 @@ func TestListClients(t *testing.T) {
Object: tc.identifyResponse.DomainId,
}).Return(tc.authorizeResponse1, tc.authorizeErr1)
listAllObjectsCall := auth.On("ListAllObjects", mock.Anything, mock.Anything).Return(tc.listObjectsResponse, tc.listObjectsErr)
retrieveAllCall := cRepo.On("RetrieveAllByIDs", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
retrieveAllCall := cRepo.On("SearchClients", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
listPermissionsCall := auth.On("ListPermissions", mock.Anything, mock.Anything).Return(tc.listPermissionsResponse, tc.listPermissionsErr)

page, err := svc.ListClients(context.Background(), tc.token, tc.id, tc.page)
Expand Down Expand Up @@ -799,7 +803,7 @@ func TestListClients(t *testing.T) {
Permission: "",
ObjectType: authsvc.ThingType,
}).Return(tc.listObjectsResponse1, tc.listObjectsErr1)
retrieveAllCall := cRepo.On("RetrieveAllByIDs", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
retrieveAllCall := cRepo.On("SearchClients", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
listPermissionsCall := auth.On("ListPermissions", mock.Anything, mock.Anything).Return(tc.listPermissionsResponse, tc.listPermissionsErr)

page, err := svc.ListClients(context.Background(), tc.token, tc.id, tc.page)
Expand Down Expand Up @@ -1193,7 +1197,7 @@ func TestEnableClient(t *testing.T) {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: validToken}).Return(&magistrala.IdentityRes{Id: validID, DomainId: testsutil.GenerateUUID(t)}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
repoCall2 := auth.On("ListAllObjects", mock.Anything, mock.Anything).Return(&magistrala.ListObjectsRes{Policies: getIDs(tc.response.Clients)}, nil)
repoCall3 := cRepo.On("RetrieveAllByIDs", context.Background(), mock.Anything).Return(tc.response, nil)
repoCall3 := cRepo.On("SearchClients", context.Background(), mock.Anything).Return(tc.response, nil)
page, err := svc.ListClients(context.Background(), validToken, "", pm)
require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
size := uint64(len(page.Clients))
Expand Down Expand Up @@ -1363,7 +1367,7 @@ func TestDisableClient(t *testing.T) {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: validToken}).Return(&magistrala.IdentityRes{Id: validID, DomainId: testsutil.GenerateUUID(t)}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
repoCall2 := auth.On("ListAllObjects", mock.Anything, mock.Anything).Return(&magistrala.ListObjectsRes{Policies: getIDs(tc.response.Clients)}, nil)
repoCall3 := cRepo.On("RetrieveAllByIDs", context.Background(), mock.Anything).Return(tc.response, nil)
repoCall3 := cRepo.On("SearchClients", context.Background(), mock.Anything).Return(tc.response, nil)
page, err := svc.ListClients(context.Background(), validToken, "", pm)
require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
size := uint64(len(page.Clients))
Expand Down
21 changes: 0 additions & 21 deletions users/api/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,27 +381,6 @@ func (lm *loggingMiddleware) ListMembers(ctx context.Context, token, objectKind,
return lm.svc.ListMembers(ctx, token, objectKind, objectID, cp)
}

// SearchClients logs the search_clients request. It logs the page metadata and the time it took to complete the request.
func (lm *loggingMiddleware) SearchUsers(ctx context.Context, token string, cp mgclients.Page) (mp mgclients.ClientsPage, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.Group("page",
slog.Uint64("limit", cp.Limit),
slog.Uint64("offset", cp.Offset),
slog.Uint64("total", mp.Total),
),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Search clients failed to complete successfully", args...)
return
}
lm.logger.Info("Search clients completed successfully", args...)
}(time.Now())
return lm.svc.SearchUsers(ctx, token, cp)
}

// Identify logs the identify request. It logs the time it took to complete the request.
func (lm *loggingMiddleware) Identify(ctx context.Context, token string) (id string, err error) {
defer func(begin time.Time) {
Expand Down
9 changes: 0 additions & 9 deletions users/api/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,15 +183,6 @@ func (ms *metricsMiddleware) ListMembers(ctx context.Context, token, objectKind,
return ms.svc.ListMembers(ctx, token, objectKind, objectID, pm)
}

// SearchClients instruments SearchClients method with metrics.
func (ms *metricsMiddleware) SearchUsers(ctx context.Context, token string, pm mgclients.Page) (mp mgclients.ClientsPage, err error) {
defer func(begin time.Time) {
ms.counter.With("method", "search_clients").Add(1)
ms.latency.With("method", "search_clients").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.SearchUsers(ctx, token, pm)
}

// Identify instruments Identify method with metrics.
func (ms *metricsMiddleware) Identify(ctx context.Context, token string) (string, error) {
defer func(begin time.Time) {
Expand Down
Loading

0 comments on commit 1302441

Please sign in to comment.