Skip to content

Commit

Permalink
make sure migrations uses primary instead of replica, improve tests
Browse files Browse the repository at this point in the history
  • Loading branch information
John Tordoff committed Nov 4, 2024
1 parent 56c466c commit b8211d0
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 36 deletions.
69 changes: 43 additions & 26 deletions osf/management/commands/migrate_preprint_affiliation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

from django.core.management.base import BaseCommand
from django.db import transaction
from osf.models import PreprintContributor
from django.db.models import F, Count
from django.db.models import F, Exists, OuterRef

from osf.models import PreprintContributor, InstitutionAffiliation

logger = logging.getLogger(__name__)

Expand All @@ -27,70 +28,86 @@ def add_arguments(self, parser):
dest='dry_run',
help='If true, performs a dry run without making changes'
)
parser.add_argument(
'--batch-size',
type=int,
default=1000,
dest='batch_size',
help='Number of contributors to process in each batch'
)

def handle(self, *args, **options):
start_time = datetime.datetime.now()
logger.info(f'Script started at: {start_time}')

exclude_guids = set(options.get('exclude_guids') or [])
dry_run = options.get('dry_run', False)
batch_size = options.get('batch_size', 1000)

if dry_run:
logger.info('Dry run mode activated.')

processed_count, updated_count, skipped_count = assign_affiliations_to_preprints(
processed_count, updated_count = assign_affiliations_to_preprints(
exclude_guids=exclude_guids,
dry_run=dry_run
dry_run=dry_run,
batch_size=batch_size
)

finish_time = datetime.datetime.now()
logger.info(f'Script finished at: {finish_time}')
logger.info(f'Total processed: {processed_count}, Updated: {updated_count}, Skipped: {skipped_count}')
logger.info(f'Total processed: {processed_count}, Updated: {updated_count}')
logger.info(f'Total run time: {finish_time - start_time}')


def assign_affiliations_to_preprints(exclude_guids=None, dry_run=True):
def assign_affiliations_to_preprints(exclude_guids=None, dry_run=True, batch_size=1000):
exclude_guids = exclude_guids or set()
contributors = PreprintContributor.objects.filter(
processed_count = updated_count = 0

# Subquery to check if the user has any affiliated institutions
user_has_affiliations = Exists(
InstitutionAffiliation.objects.filter(
user=OuterRef('user')
)
)

contributors_qs = PreprintContributor.objects.filter(
preprint__preprintgroupobjectpermission__permission__codename__in=['write_preprint', 'admin_preprint'],
preprint__preprintgroupobjectpermission__group__user=F('user'),
).annotate(
num_affiliations=Count('user__institutionaffiliation')
).filter(
num_affiliations__gt=0 # Exclude users with no affiliations
user_has_affiliations
).select_related(
'user',
'preprint'
).exclude(
user__guids___id__in=exclude_guids
).distinct()
).order_by('pk') # Ensure consistent ordering for batching

processed_count = updated_count = skipped_count = 0
total_contributors = contributors_qs.count()
logger.info(f'Total contributors to process: {total_contributors}')

# Process contributors in batches
with transaction.atomic():
for contributor in contributors:
user = contributor.user
preprint = contributor.preprint
for offset in range(0, total_contributors, batch_size):
# Use select_for_update() to ensure query hits the primary database
batch_contributors = contributors_qs[offset:offset + batch_size].select_for_update()

user_institutions = set(user.get_affiliated_institutions())
preprint_institutions = set(preprint.affiliated_institutions.all())
logger.info(f'Processing contributors {offset + 1} to {min(offset + batch_size, total_contributors)}')

new_institutions = user_institutions - preprint_institutions
for contributor in batch_contributors:
user = contributor.user
preprint = contributor.preprint

if new_institutions:
user_institutions = user.get_affiliated_institutions()
processed_count += 1
if not dry_run:
for institution in new_institutions:
preprint.affiliated_institutions.add(institution)
preprint.affiliated_institutions.add(*user_institutions)
updated_count += 1
logger.info(
f'Assigned {len(new_institutions)} affiliations from user <{user._id}> to preprint <{preprint._id}>.'
f'Assigned {len(user_institutions)} affiliations from user <{user._id}> to preprint <{preprint._id}>.'
)
else:
logger.info(
f'Dry run: Would assign {len(new_institutions)} affiliations from user <{user._id}> to preprint <{preprint._id}>.'
f'Dry run: Would assign {len(user_institutions)} affiliations from user <{user._id}> to preprint <{preprint._id}>.'
)
else:
skipped_count += 1

return processed_count, updated_count, skipped_count
return processed_count, updated_count
28 changes: 18 additions & 10 deletions osf_tests/management_commands/test_migrate_preprint_affiliations.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,13 @@ def test_no_affiliations_for_non_affiliated_contributor(self, preprint_with_non_
assert not preprint.affiliated_institutions.exists()

@pytest.mark.parametrize('dry_run', [True, False])
def test_exclude_contributor_by_guid(self, preprint_with_affiliated_contributor, institution, dry_run):
def test_exclude_contributor_by_guid(self, preprint_with_affiliated_contributor, user_with_affiliation, institution, dry_run):
preprint = preprint_with_affiliated_contributor
preprint.affiliated_institutions.clear()
preprint.save()

assert preprint.contributors.last().get_affiliated_institutions()
assert user_with_affiliation.get_affiliated_institutions()
assert user_with_affiliation in preprint.contributors.all()
exclude_guids = {user._id for user in preprint.contributors.all()}

assign_affiliations_to_preprints(exclude_guids=exclude_guids, dry_run=dry_run)
Expand All @@ -83,19 +84,25 @@ def test_exclude_contributor_by_guid(self, preprint_with_affiliated_contributor,

@pytest.mark.parametrize('dry_run', [True, False])
def test_affiliations_from_multiple_contributors(self, institution, dry_run):
user1 = AuthUserFactory()
user1.add_or_update_affiliated_institution(institution)
user1.save()
institution_not_include = InstitutionFactory()
read_contrib = AuthUserFactory()
read_contrib.add_or_update_affiliated_institution(institution_not_include)
read_contrib.save()

user2 = AuthUserFactory()
write_contrib = AuthUserFactory()
write_contrib.add_or_update_affiliated_institution(institution)
write_contrib.save()

admin_contrib = AuthUserFactory()
institution2 = InstitutionFactory()
user2.add_or_update_affiliated_institution(institution2)
user2.save()
admin_contrib.add_or_update_affiliated_institution(institution2)
admin_contrib.save()

preprint = PreprintFactory()
preprint.affiliated_institutions.clear()
preprint.add_contributor(user1, permissions='write', visible=True)
preprint.add_contributor(user2, permissions='admin', visible=True)
preprint.add_contributor(read_contrib, permissions='read', visible=True)
preprint.add_contributor(write_contrib, permissions='write', visible=True)
preprint.add_contributor(admin_contrib, permissions='admin', visible=True)
preprint.save()

assign_affiliations_to_preprints(dry_run=dry_run)
Expand All @@ -105,3 +112,4 @@ def test_affiliations_from_multiple_contributors(self, institution, dry_run):
else:
affiliations = set(preprint.affiliated_institutions.all())
assert affiliations == {institution, institution2}
assert institution_not_include not in affiliations

0 comments on commit b8211d0

Please sign in to comment.