postgres.CopyToTable: Robustize resources cleanup #3072

Open · wants to merge 9 commits into base: master
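
In short, the patch guards CopyToTable.run() so that the database connection and the temporary COPY file are released even when something goes wrong: the body moves into a try/finally, and the connection and cursor are used as context managers. A minimal sketch of that pattern outside Luigi (the helper name run_with_cleanup and its callable arguments are illustrative only, not part of the Luigi API):

```python
import tempfile


def run_with_cleanup(connect, write_rows, do_copy):
    """Illustrative only: release the connection and temp file even on error."""
    connection = None
    tmp_file = None
    try:
        connection = connect()
        tmp_file = tempfile.TemporaryFile()
        write_rows(tmp_file)                      # stage the rows for COPY
        tmp_file.seek(0)
        with connection:                          # commit on success, roll back on error
            with connection.cursor() as cursor:   # cursor closed on exit
                do_copy(cursor, tmp_file)         # e.g. cursor.copy_from(...)
    finally:
        if connection is not None:
            connection.close()
        if tmp_file is not None:
            tmp_file.close()
```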
96 changes: 49 additions & 47 deletions luigi/contrib/postgres.py
@@ -285,7 +285,7 @@ def copy(self, cursor, file):
elif len(self.columns[0]) == 2:
column_names = [c[0] for c in self.columns]
else:
raise Exception('columns must consist of column strings or (column string, type string) tuples (was %r ...)' % (self.columns[0],))
raise ValueError('columns must consist of column strings or (column string, type string) tuples (was %r ...)' % (self.columns[0],))
cursor.copy_from(file, self.table, null=r'\\N', sep=self.column_separator, columns=column_names)

def run(self):
@@ -299,52 +299,54 @@ def run(self):
if not (self.table and self.columns):
raise Exception("table and columns need to be specified")

connection = self.output().connect()
# transform all data generated by rows() using map_column and write data
# to a temporary file for import using postgres COPY
tmp_dir = luigi.configuration.get_config().get('postgres', 'local-tmp-dir', None)
tmp_file = tempfile.TemporaryFile(dir=tmp_dir)
n = 0
for row in self.rows():
n += 1
if n % 100000 == 0:
logger.info("Wrote %d lines", n)
rowstr = self.column_separator.join(self.map_column(val) for val in row)
rowstr += "\n"
tmp_file.write(rowstr.encode('utf-8'))

logger.info("Done writing, importing at %s", datetime.datetime.now())
tmp_file.seek(0)

# attempt to copy the data into postgres
# if it fails because the target table doesn't exist
# try to create it by running self.create_table
for attempt in range(2):
try:
cursor = connection.cursor()
self.init_copy(connection)
self.copy(cursor, tmp_file)
self.post_copy(connection)
if self.enable_metadata_columns:
self.post_copy_metacolumns(cursor)
except psycopg2.ProgrammingError as e:
if e.pgcode == psycopg2.errorcodes.UNDEFINED_TABLE and attempt == 0:
# if first attempt fails with "relation not found", try creating table
logger.info("Creating table %s", self.table)
connection.reset()
self.create_table(connection)
else:
raise
else:
break

# mark as complete in same transaction
self.output().touch(connection)

# commit and clean up
connection.commit()
connection.close()
tmp_file.close()
try:
connection = self.output().connect()
# transform all data generated by rows() using map_column and write data
# to a temporary file for import using postgres COPY
tmp_dir = luigi.configuration.get_config().get('postgres', 'local-tmp-dir', None)
tmp_file = tempfile.TemporaryFile(dir=tmp_dir)
n = 0
for row in self.rows():
n += 1
if n % 100000 == 0:
logger.info("Wrote %d lines", n)
rowstr = self.column_separator.join(self.map_column(val) for val in row)
rowstr += "\n"
tmp_file.write(rowstr.encode('utf-8'))

logger.info("Done writing, importing at %s", datetime.datetime.now())
tmp_file.seek(0)

with connection:
# attempt to copy the data into postgres
# if it fails because the target table doesn't exist
# try to create it by running self.create_table
for attempt in range(2):
try:
with connection.cursor() as cursor:
self.init_copy(connection)
self.copy(cursor, tmp_file)
self.post_copy(connection)
if self.enable_metadata_columns:
self.post_copy_metacolumns(cursor)
except psycopg2.ProgrammingError as e:
if e.pgcode == psycopg2.errorcodes.UNDEFINED_TABLE and attempt == 0:
# if first attempt fails with "relation not found", try creating table
logger.info("Creating table %s", self.table)
connection.reset()
self.create_table(connection)
else:
raise
else:
break

# mark as complete in same transaction
self.output().touch(connection)
finally:
if connection:
connection.close()
if tmp_file:
tmp_file.close()


class PostgresQuery(rdbms.Query):
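For context, psycopg2's context managers do not behave like file objects: `with connection:` opens a transaction block that is committed on success and rolled back when an exception escapes, but it does not close the connection, whereas `with connection.cursor() as cur:` does close the cursor on exit. That is why the explicit `connection.close()` in the `finally` block is still needed. A rough illustration (assumes psycopg2 is installed and a database is reachable; the DSN is a placeholder):

```python
import psycopg2

conn = psycopg2.connect("dbname=test")       # placeholder DSN
try:
    with conn:                                # transaction scope only
        with conn.cursor() as cur:            # cursor is closed on exit
            cur.execute("SELECT 1")
            print(cur.fetchone())
    # conn is still open here: the with-block committed, it did not close
finally:
    conn.close()                              # closing remains the caller's job
```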
79 changes: 76 additions & 3 deletions test/contrib/postgres_test.py
@@ -26,12 +26,47 @@ def datetime_to_epoch(dt):
return td.days * 86400 + td.seconds + td.microseconds / 1E6


class MockPostgresCursor(mock.Mock):
class MockContextManager(mock.Mock):

def __init__(self, *args, **kwargs):
super(MockContextManager, self).__init__(*args, **kwargs)
self.context_counter = 0
self.all_context_counter = 0

def __enter__(self):
self.context_counter += 1
self.all_context_counter += 1

def __exit__(self, exc_type, exc_value, exc_traceback):
self.context_counter -= 1

def _get_child_mock(self, **kwargs):
"""Child mocks will be instances of super."""
return mock.Mock(**kwargs)


class MockPostgresConnection(MockContextManager):
def __init__(self, existing_update_ids, *args, **kwargs):
super(MockPostgresConnection, self).__init__(*args, **kwargs)
self.existing = existing_update_ids
self.is_open = False
self.was_open = 0

def cursor(self):
self.is_open = True
self.was_open = True
return MockPostgresCursor(existing_update_ids=self.existing)

def close(self):
self.is_open = False


class MockPostgresCursor(MockContextManager):
"""
Keeps state to simulate executing SELECT queries and fetching results.
"""
def __init__(self, existing_update_ids):
super(MockPostgresCursor, self).__init__()
def __init__(self, existing_update_ids, *args, **kwargs):
super(MockPostgresCursor, self).__init__(*args, **kwargs)
self.existing = existing_update_ids

def execute(self, query, params):
@@ -82,6 +117,44 @@ def test_bulk_complete(self, mock_connect):
]))
self.assertFalse(task.complete())

@mock.patch('psycopg2.connect')
@mock.patch("luigi.contrib.postgres.CopyToTable.rows", return_value=['row1', 'row2'])
def test_cleanup_on_error(self, mock_rows, mock_connect):
"""
Test cleanup behavior of CopyToTable in case of an error.

When an error occurs while the connection is open, the connection should
be closed again so that subsequent tasks do not fail because of an
unclosed connection.
"""
task = DummyPostgresImporter(date=datetime.datetime(2021, 4, 15))

mock_connection = MockPostgresConnection([task.task_id])
mock_connect.return_value = mock_connection
mock_cursor = MockPostgresCursor([task.task_id])

original_cursor = mock_connection.cursor

def get_mock_cursor():
original_cursor()
return mock_cursor

mock_connection.cursor = mock.MagicMock(side_effect=get_mock_cursor)

task = DummyPostgresImporter(date=datetime.datetime(2021, 4, 15))
task.columns = [(42,)] # inject defect

with self.assertRaisesRegex(ValueError, "columns"):
task.run()

self.assertEqual(mock_connection.context_counter, 0)
self.assertTrue(mock_connection.all_context_counter)
self.assertFalse(mock_connection.is_open)
self.assertTrue(mock_connection.was_open)

self.assertEqual(mock_cursor.context_counter, 0)
self.assertTrue(mock_cursor.all_context_counter)


class DummyPostgresQuery(luigi.contrib.postgres.PostgresQuery):
date = luigi.DateParameter()
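
The counting mocks above make the cleanup assertions mechanical: `context_counter` must be back at zero after `run()` (every `__enter__` matched by an `__exit__`), and a non-zero `all_context_counter` proves the code actually went through the context managers. A standalone sketch of the same idea (`CountingCM` is a hypothetical stand-in for the test's `MockContextManager`):

```python
from unittest import mock


class CountingCM(mock.Mock):
    """Counts balanced __enter__/__exit__ calls, like MockContextManager."""

    def __init__(self, *args, **kwargs):
        super(CountingCM, self).__init__(*args, **kwargs)
        self.context_counter = 0        # with-blocks currently open
        self.all_context_counter = 0    # with-blocks ever entered

    def __enter__(self):
        self.context_counter += 1
        self.all_context_counter += 1
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.context_counter -= 1
        return False                    # do not swallow exceptions


cm = CountingCM()
try:
    with cm:
        raise ValueError("boom")
except ValueError:
    pass

assert cm.context_counter == 0          # exited cleanly despite the error
assert cm.all_context_counter == 1      # the context manager was really used
```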