Skip to content

Commit

Permalink
make sure migrations uses primary instead of replica
Browse files Browse the repository at this point in the history
  • Loading branch information
John Tordoff committed Nov 1, 2024
1 parent 56c466c commit e763563
Showing 1 changed file with 58 additions and 34 deletions.
92 changes: 58 additions & 34 deletions osf/management/commands/migrate_preprint_affiliation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import logging

from django.core.management.base import BaseCommand
from django.db import transaction
from osf.models import PreprintContributor
from django.db.models import F, Count
from django.db import transaction, router
from django.db.models import F, Exists, OuterRef

from osf.models import PreprintContributor, InstitutionAffiliation

logger = logging.getLogger(__name__)

Expand All @@ -27,20 +28,29 @@ def add_arguments(self, parser):
dest='dry_run',
help='If true, performs a dry run without making changes'
)
parser.add_argument(
'--batch-size',
type=int,
default=1000,
dest='batch_size',
help='Number of contributors to process in each batch'
)

def handle(self, *args, **options):
start_time = datetime.datetime.now()
logger.info(f'Script started at: {start_time}')

exclude_guids = set(options.get('exclude_guids') or [])
dry_run = options.get('dry_run', False)
batch_size = options.get('batch_size', 1000)

if dry_run:
logger.info('Dry run mode activated.')

processed_count, updated_count, skipped_count = assign_affiliations_to_preprints(
exclude_guids=exclude_guids,
dry_run=dry_run
dry_run=dry_run,
batch_size=batch_size
)

finish_time = datetime.datetime.now()
Expand All @@ -49,48 +59,62 @@ def handle(self, *args, **options):
logger.info(f'Total run time: {finish_time - start_time}')


def assign_affiliations_to_preprints(exclude_guids=None, dry_run=True):
def assign_affiliations_to_preprints(exclude_guids=None, dry_run=True, batch_size=1000):
exclude_guids = exclude_guids or set()
contributors = PreprintContributor.objects.filter(
processed_count = updated_count = skipped_count = 0

# Subquery to check if the user has any affiliated institutions
user_has_affiliations = Exists(
InstitutionAffiliation.objects.filter(
user=OuterRef('user')
)
)

contributors_qs = PreprintContributor.objects.filter(
preprint__preprintgroupobjectpermission__permission__codename__in=['write_preprint', 'admin_preprint'],
preprint__preprintgroupobjectpermission__group__user=F('user'),
).annotate(
num_affiliations=Count('user__institutionaffiliation')
).filter(
num_affiliations__gt=0 # Exclude users with no affiliations
user_has_affiliations
).select_related(
'user',
'preprint'
).exclude(
user__guids___id__in=exclude_guids
).distinct()
).order_by('pk') # Ensure consistent ordering for batching

processed_count = updated_count = skipped_count = 0
total_contributors = contributors_qs.count()
logger.info(f'Total contributors to process: {total_contributors}')

# Process contributors in batches
with transaction.atomic():
for contributor in contributors:
user = contributor.user
preprint = contributor.preprint

user_institutions = set(user.get_affiliated_institutions())
preprint_institutions = set(preprint.affiliated_institutions.all())

new_institutions = user_institutions - preprint_institutions

if new_institutions:
processed_count += 1
if not dry_run:
for institution in new_institutions:
preprint.affiliated_institutions.add(institution)
updated_count += 1
logger.info(
f'Assigned {len(new_institutions)} affiliations from user <{user._id}> to preprint <{preprint._id}>.'
)
for offset in range(0, total_contributors, batch_size):
# Use select_for_update() to ensure query hits the primary database
batch_contributors = contributors_qs[offset:offset + batch_size].select_for_update()

logger.info(f'Processing contributors {offset + 1} to {min(offset + batch_size, total_contributors)}')

for contributor in batch_contributors:
user = contributor.user
preprint = contributor.preprint

user_institutions = set(user.get_affiliated_institutions())
preprint_institutions = set(preprint.affiliated_institutions.all())

new_institutions = user_institutions - preprint_institutions

if new_institutions:
processed_count += 1
if not dry_run:
preprint.affiliated_institutions.add(*new_institutions)
updated_count += 1
logger.info(
f'Assigned {len(new_institutions)} affiliations from user <{user._id}> to preprint <{preprint._id}>.'
)
else:
logger.info(
f'Dry run: Would assign {len(new_institutions)} affiliations from user <{user._id}> to preprint <{preprint._id}>.'
)
else:
logger.info(
f'Dry run: Would assign {len(new_institutions)} affiliations from user <{user._id}> to preprint <{preprint._id}>.'
)
else:
skipped_count += 1
skipped_count += 1

return processed_count, updated_count, skipped_count

0 comments on commit e763563

Please sign in to comment.