diff --git a/django_bulk_update/helper.py b/django_bulk_update/helper.py index ef45d90..a2e1548 100644 --- a/django_bulk_update/helper.py +++ b/django_bulk_update/helper.py @@ -114,7 +114,8 @@ def get_fields(update_fields, exclude_fields, meta, obj=None): def bulk_update(objs, meta=None, update_fields=None, exclude_fields=None, - using='default', batch_size=None, pk_field='pk'): + using='default', batch_size=None, pk_field='pk', + ordered=False): assert batch_size is None or batch_size > 0 # force to retrieve objs from the DB at the beginning, @@ -154,11 +155,16 @@ def bulk_update(objs, meta=None, update_fields=None, exclude_fields=None, # Sqlite also gives some trouble with cast, at least for datetime, # but is also permissive for uncast values vendor = connection.vendor + use_cast = 'mysql' not in vendor and 'sqlite' not in vendor + if ordered: + template = 'CASE cte."{pk_column}" {cases}ELSE cte."{column}" END' + else: + template = 'CASE "{pk_column}" {cases}ELSE "{column}" END' if use_cast: - template = '"{column}" = CAST(CASE "{pk_column}" {cases}ELSE "{column}" END AS {type})' + template = '"{column}" = CAST(' + template + ' AS {type})' else: - template = '"{column}" = (CASE "{pk_column}" {cases}ELSE "{column}" END)' + template = '"{column}" = (' + template + ')' case_template = "WHEN %s THEN {} " @@ -191,11 +197,7 @@ def bulk_update(objs, meta=None, update_fields=None, exclude_fields=None, for field in parameters.keys() ) - parameters = flatten(parameters.values(), types=list) - parameters.extend(pks) - n_pks = len(pks) - del pks dbtable = '"{}"'.format(meta.db_table) @@ -204,11 +206,37 @@ def bulk_update(objs, meta=None, update_fields=None, exclude_fields=None, pks=', '.join(itertools.repeat('%s', n_pks)), ) - sql = 'UPDATE {dbtable} SET {values} WHERE {in_clause}'.format( - dbtable=dbtable, - values=values, - in_clause=in_clause, - ) + if ordered: + columns = ', '.join('"{column}"'.format(column=field.column) + for field in parameters.keys()) + parameters = list(pks) + flatten(parameters.values(), types=list) + + sql = '''with cte as + ( + select "{pk_column}", {columns} + from {dbtable} + where {in_clause} order by "{pk_column}" asc + ) + UPDATE {dbtable} + SET {values} + from cte + where cte."{pk_column}" = {dbtable}."{pk_column}" + ;'''.format( + dbtable=dbtable, + values=values, + in_clause=in_clause, + columns=columns, + pk_column=pk_field.column, + ) + else: + parameters = flatten(parameters.values(), types=list) + parameters.extend(pks) + sql = 'UPDATE {dbtable} SET {values} WHERE {in_clause}'.format( + dbtable=dbtable, + values=values, + in_clause=in_clause, + ) + del pks del values # String escaping in ANSI SQL is done by using double quotes (").