Skip to content

Commit

Permalink
collections: added task to compute num of records
Browse files Browse the repository at this point in the history
* added task to compute number of records for all the collections
* added "collection_id" parameter to record search
* added service methods to read collections (many, all)
* added tests
* collections: refactor 'resolve' to 'read'
* collections: rename 'search_records' method
* collections: update read method signature
  • Loading branch information
alejandromumo committed Oct 25, 2024
1 parent 970aa7b commit 9e56947
Show file tree
Hide file tree
Showing 10 changed files with 304 additions and 29 deletions.
30 changes: 22 additions & 8 deletions invenio_rdm_records/collections/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ def create(cls, slug, title, query, ctree=None, parent=None, order=None, depth=2
)

@classmethod
def resolve(cls, *, id_=None, slug=None, ctree_id=None, depth=2):
"""Resolve a collection by ID or slug.
def read(cls, *, id_=None, slug=None, ctree_id=None, depth=2):
"""Read a collection by ID or slug.
To resolve by slug, the collection tree ID must be provided.
To read by slug, the collection tree ID must be provided.
"""
res = None
if id_:
Expand All @@ -89,10 +89,21 @@ def resolve(cls, *, id_=None, slug=None, ctree_id=None, depth=2):
return res

@classmethod
def resolve_many(cls, ids_=None, depth=2):
"""Resolve many collections by ID."""
_ids = ids_ or []
return [cls(c, depth) for c in cls.model_cls.read_many(_ids)]
def read_many(cls, ids_, depth=2):
"""Read many collections by ID."""
return [cls(c, depth) for c in cls.model_cls.read_many(ids_)]

@classmethod
def read_all(cls, depth=2):
"""Read all collections."""
return [cls(c, depth) for c in cls.model_cls.read_all()]

def update(self, **kwargs):
"""Update the collection."""
if "search_query" in kwargs:
Collection.validate_query(kwargs["search_query"])
self.model.update(**kwargs)
return self

def add(self, slug, title, query, order=None, depth=2):
"""Add a subcollection to the collection."""
Expand Down Expand Up @@ -128,7 +139,10 @@ def query(self):
@cached_property
def ancestors(self):
"""Get the collection ancestors."""
return Collection.resolve_many(self.split_path_to_ids())
ids_ = self.split_path_to_ids()
if not ids_:
return []
return Collection.read_many(ids_)

@cached_property
def subcollections(self):
Expand Down
20 changes: 18 additions & 2 deletions invenio_rdm_records/collections/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,25 @@ def get_by_slug(cls, slug, tree_id):
return cls.query.filter(cls.slug == slug, cls.tree_id == tree_id).one_or_none()

@classmethod
def read_many(cls, ids):
def read_many(cls, ids_):
"""Get many collections by ID."""
return cls.query.filter(cls.id.in_(ids)).order_by(cls.path, cls.order)
return cls.query.filter(cls.id.in_(ids_)).order_by(cls.path, cls.order)

@classmethod
def read_all(cls):
"""Get all collections.
The collections are ordered by ``path`` and ``order``, which means:
- By path: the collections are ordered in a breadth-first manner (first come the root collection, then the next level, and so on)
- By order: between the same level collections, they are ordered by the specified order field.
"""
return cls.query.order_by(cls.path, cls.order)

def update(self, **kwargs):
"""Update a collection."""
for key, value in kwargs.items():
setattr(self, key, value)

@classmethod
def get_children(cls, model):
Expand Down
22 changes: 16 additions & 6 deletions invenio_rdm_records/collections/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def to_dict(self):
res = {
"root": self._collection.id,
self._collection.id: {
**self._schema.dump(self._collection),
**self._schema.dump(
self._collection, context={"identity": self._identity}
),
"children": list(),
"links": self._links_tpl.expand(self._identity, self._collection),
},
Expand All @@ -76,7 +78,7 @@ def to_dict(self):
if _c.id not in res:
# Add the subcollection to the dictionary
res[_c.id] = {
**self._schema.dump(_c),
**self._schema.dump(_c, context={"identity": self._identity}),
"children": list(),
"links": self._links_tpl.expand(self._identity, _c),
}
Expand Down Expand Up @@ -121,22 +123,30 @@ def query(self):
class CollectionList(ServiceListResult):
"""Collection list item."""

def __init__(self, collections):
def __init__(self, identity, collections, schema, links_tpl, links_item_tpl):
"""Instantiate a Collection list item."""
self._identity = identity
self._collections = collections
self._schema = schema
self._links_tpl = links_tpl
self._links_item_tpl = links_item_tpl

def to_dict(self):
"""Serialize the collection list to a dictionary."""
res = []
for collection in self._collections:
_r = collection.to_dict()
_r["links"] = CollectionItem(collection).links
_r = CollectionItem(
self._identity, collection, self._schema, self._links_item_tpl
).to_dict()
res.append(_r)
return res

def __iter__(self):
"""Iterate over the collections."""
return iter(self._collections)
return (
CollectionItem(self._identity, x, self._schema, self._links_item_tpl)
for x in self._collections
)


class CollectionTreeItem:
Expand Down
5 changes: 3 additions & 2 deletions invenio_rdm_records/collections/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ class CollectionSchema(Schema):

slug = fields.Str()
title = fields.Str()
depth = fields.Int()
depth = fields.Int(dump_only=True)
order = fields.Int()
id = fields.Int()
id = fields.Int(dump_only=True)
num_records = fields.Int()
search_query = fields.Str(load_only=True)
93 changes: 86 additions & 7 deletions invenio_rdm_records/collections/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,19 @@
import os

from flask import current_app, url_for
from invenio_records_resources.services import ServiceSchemaWrapper
from invenio_records_resources.services.base import Service
from invenio_records_resources.services.uow import ModelCommitOp, unit_of_work

from invenio_rdm_records.proxies import (
current_community_records_service,
current_rdm_records_service,
)

from .api import Collection, CollectionTree
from .errors import LogoNotFoundError
from .links import CollectionLinkstemplate
from .results import CollectionItem, CollectionTreeList
from .results import CollectionItem, CollectionList, CollectionTreeList


class CollectionsService(Service):
Expand All @@ -30,7 +36,7 @@ def __init__(self, config):
@property
def collection_schema(self):
"""Get the collection schema."""
return self.config.schema()
return ServiceSchemaWrapper(self, schema=self.config.schema)

@property
def links_item_tpl(self):
Expand All @@ -46,7 +52,11 @@ def links_item_tpl(self):
def create(
self, identity, community_id, tree_slug, slug, title, query, uow=None, **kwargs
):
"""Create a new collection."""
"""Create a new collection.
The created collection will be added to the collection tree as a root collection (no parent).
If a parent is needed, use the ``add`` method.
"""
self.require_permission(identity, "update", community_id=community_id)
ctree = CollectionTree.resolve(slug=tree_slug, community_id=community_id)
collection = self.collection_cls.create(
Expand All @@ -59,9 +69,8 @@ def create(

def read(
self,
/,
*,
identity=None,
*,
id_=None,
slug=None,
community_id=None,
Expand All @@ -74,10 +83,10 @@ def read(
To resolve by slug, the collection tree ID and community ID must be provided.
"""
if id_:
collection = self.collection_cls.resolve(id_=id_, depth=depth)
collection = self.collection_cls.read(id_=id_, depth=depth)
elif slug and tree_slug and community_id:
ctree = CollectionTree.resolve(slug=tree_slug, community_id=community_id)
collection = self.collection_cls.resolve(
collection = self.collection_cls.read(
slug=slug, ctree_id=ctree.id, depth=depth
)
else:
Expand Down Expand Up @@ -121,6 +130,31 @@ def add(self, identity, collection, slug, title, query, uow=None, **kwargs):
identity, new_collection, self.collection_schema, self.links_item_tpl
)

@unit_of_work()
def update(self, identity, collection_or_id, data=None, uow=None):
"""Update a collection."""
if isinstance(collection_or_id, int):
collection = self.collection_cls.read(id_=collection_or_id)
else:
collection = collection_or_id
self.require_permission(
identity, "update", community_id=collection.community.id
)

data = data or {}

valid_data, errors = self.collection_schema.load(
data, context={"identity": identity}, raise_errors=True
)

res = collection.update(**valid_data)

uow.register(ModelCommitOp(res.model))

return CollectionItem(
identity, collection, self.collection_schema, self.links_item_tpl
)

def read_logo(self, identity, slug):
"""Read a collection logo.
Expand All @@ -131,3 +165,48 @@ def read_logo(self, identity, slug):
if _exists:
return url_for("static", filename=logo_path)
raise LogoNotFoundError()

def read_many(self, identity, ids_, depth=2):
"""Get many collections."""
self.require_permission(identity, "read")

if ids_ is None:
raise ValueError("IDs must be provided.")

if ids_ == []:
raise ValueError("Use read_all to get all collections.")

res = self.collection_cls.read_many(ids_, depth=depth)
return CollectionList(
identity, res, self.collection_schema, None, self.links_item_tpl
)

def read_all(self, identity, depth=2):
"""Get all collections."""
self.require_permission(identity, "read")
res = self.collection_cls.read_all(depth=depth)
return CollectionList(
identity, res, self.collection_schema, None, self.links_item_tpl
)

def search_collection_records(self, identity, collection_or_id, params=None):
"""Search records in a collection."""
params = params or {}

if isinstance(collection_or_id, int):
collection = self.collection_cls.read(id_=collection_or_id)
else:
collection = collection_or_id

params.update({"collection_id": collection.id})
if collection.community:
res = current_community_records_service.search(
identity,
community_id=collection.community.id,
params=params,
)
else:
raise NotImplementedError(
"Search for collections without community not supported."
)
return res
31 changes: 31 additions & 0 deletions invenio_rdm_records/collections/tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2024 CERN.
#
# Invenio-RDM-Records is free software; you can redistribute it and/or modify
# it under the terms of the MIT License; see LICENSE file for more details.
"""Collections celery tasks."""

from celery import shared_task
from flask import current_app
from invenio_access.permissions import system_identity

from invenio_rdm_records.proxies import current_rdm_records


@shared_task(ignore_result=True)
def update_collections_size():
"""Calculate and update the size of all the collections."""
collections_service = current_rdm_records.collections_service
res = collections_service.read_all(system_identity, depth=0)
for citem in res:
try:
collection = citem._collection
res = collections_service.search_collection_records(
system_identity, collection
)
collections_service.update(
system_identity, collection, data={"num_records": res.total}
)
except Exception as e:
current_app.logger.exception(str(e))
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ invenio_celery.tasks =
invenio_rdm_records_access_requests = invenio_rdm_records.requests.access.tasks
invenio_rdm_records_iiif = invenio_rdm_records.services.iiif.tasks
invenio_rdm_records_user_moderation = invenio_rdm_records.requests.user_moderation.tasks
invenio_rdm_records_collections = invenio_rdm_records.collections.tasks
invenio_db.models =
invenio_rdm_records = invenio_rdm_records.records.models
invenio_rdm_records_collections = invenio_rdm_records.collections.models
Expand Down
8 changes: 4 additions & 4 deletions tests/collections/test_collections_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_create(running_app, db, community, community_owner):
ctree=tree,
)

read_c = Collection.resolve(id_=collection.id)
read_c = Collection.read(id_=collection.id)
assert read_c.id == collection.id
assert read_c.title == "My Collection"
assert read_c.collection_tree.id == tree.id
Expand All @@ -40,7 +40,7 @@ def test_create(running_app, db, community, community_owner):
ctree=tree.id,
)

read_c = Collection.resolve(id_=collection.id)
read_c = Collection.read(id_=collection.id)
assert read_c.id == collection.id
assert collection.title == "My Collection 2"
assert collection.collection_tree.id == tree.id
Expand All @@ -63,11 +63,11 @@ def test_resolve(running_app, db, community):
)

# Read by ID
read_by_id = Collection.resolve(id_=collection.id)
read_by_id = Collection.read(id_=collection.id)
assert read_by_id.id == collection.id

# Read by slug
read_by_slug = Collection.resolve(slug="my-collection", ctree_id=tree.id)
read_by_slug = Collection.read(slug="my-collection", ctree_id=tree.id)
assert read_by_slug.id == read_by_id.id == collection.id


Expand Down
Loading

0 comments on commit 9e56947

Please sign in to comment.