postgres.CopyToTable: Robustize resources cleanup #3072

Open
wants to merge 9 commits into master
168 changes: 83 additions & 85 deletions luigi/contrib/postgres.py
@@ -19,6 +19,7 @@
Also provides a helper task to copy data into a Postgres table.
"""

import contextlib
import datetime
import logging
import re
@@ -167,20 +168,20 @@ def exists(self, connection=None):
if connection is None:
connection = self.connect()
connection.autocommit = True
cursor = connection.cursor()
try:
cursor.execute("""SELECT 1 FROM {marker_table}
WHERE update_id = %s
LIMIT 1""".format(marker_table=self.marker_table),
(self.update_id,)
)
row = cursor.fetchone()
except psycopg2.ProgrammingError as e:
if e.pgcode == psycopg2.errorcodes.UNDEFINED_TABLE:
row = None
else:
raise
return row is not None
with connection.cursor() as cursor:
try:
cursor.execute(
"""SELECT 1 FROM {marker_table}
WHERE update_id = %s
LIMIT 1""".format(marker_table=self.marker_table),
(self.update_id,))
row = cursor.fetchone()
except psycopg2.ProgrammingError as e:
if e.pgcode == psycopg2.errorcodes.UNDEFINED_TABLE:
row = None
else:
raise
return row is not None
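Note: psycopg2 cursors are context managers, so the new `with connection.cursor() as cursor:` form guarantees the cursor is closed when the block exits, even if execute() raises; the previous version never closed this cursor explicitly. A minimal standalone sketch of the pattern (connection parameters and the query are illustrative, not part of this diff):

import psycopg2

# Placeholder connection parameters, for illustration only.
connection = psycopg2.connect(host="localhost", dbname="example", user="example")
connection.autocommit = True

# Exiting the with-block closes the cursor, whether or not execute() raised.
with connection.cursor() as cursor:
    cursor.execute(
        "SELECT 1 FROM table_updates WHERE update_id = %s LIMIT 1",
        ("some-update-id",))
    row = cursor.fetchone()

connection.close()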

def connect(self):
"""
@@ -201,30 +202,29 @@ def create_marker_table(self):

Using a separate connection since the transaction might have to be reset.
"""
connection = self.connect()
connection.autocommit = True
cursor = connection.cursor()
if self.use_db_timestamps:
sql = """ CREATE TABLE {marker_table} (
update_id TEXT PRIMARY KEY,
target_table TEXT,
inserted TIMESTAMP DEFAULT NOW())
""".format(marker_table=self.marker_table)
else:
sql = """ CREATE TABLE {marker_table} (
update_id TEXT PRIMARY KEY,
target_table TEXT,
inserted TIMESTAMP);
""".format(marker_table=self.marker_table)

try:
cursor.execute(sql)
except psycopg2.ProgrammingError as e:
if e.pgcode == psycopg2.errorcodes.DUPLICATE_TABLE:
pass
else:
raise
connection.close()
with contextlib.closing(self.connect()) as connection:
connection.autocommit = True
with connection.cursor() as cursor:
if self.use_db_timestamps:
sql = """ CREATE TABLE {marker_table} (
update_id TEXT PRIMARY KEY,
target_table TEXT,
inserted TIMESTAMP DEFAULT NOW())
""".format(marker_table=self.marker_table)
else:
sql = """ CREATE TABLE {marker_table} (
update_id TEXT PRIMARY KEY,
target_table TEXT,
inserted TIMESTAMP);
""".format(marker_table=self.marker_table)

try:
cursor.execute(sql)
except psycopg2.ProgrammingError as e:
if e.pgcode == psycopg2.errorcodes.DUPLICATE_TABLE:
pass
else:
raise
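Note: contextlib.closing() turns any object with a close() method into a context manager whose close() is invoked when the block exits, including when the body raises; previously connection.close() was skipped whenever cursor.execute() raised anything other than DUPLICATE_TABLE. A minimal sketch of the helper (the FakeConnection class is purely illustrative):

import contextlib


class FakeConnection:
    """Stand-in for a DB-API connection; only close() matters here."""

    def close(self):
        print("connection closed")


# closing() guarantees that close() runs when the block exits,
# even if the body raises.
try:
    with contextlib.closing(FakeConnection()) as connection:
        raise RuntimeError("simulated failure")  # close() still runs
except RuntimeError:
    pass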

def open(self, mode):
raise NotImplementedError("Cannot open() PostgresTarget")
@@ -285,7 +285,7 @@ def copy(self, cursor, file):
elif len(self.columns[0]) == 2:
column_names = [c[0] for c in self.columns]
else:
raise Exception('columns must consist of column strings or (column string, type string) tuples (was %r ...)' % (self.columns[0],))
raise ValueError('columns must consist of column strings or (column string, type string) tuples (was %r ...)' % (self.columns[0],))
cursor.copy_from(file, self.table, null=r'\\N', sep=self.column_separator, columns=column_names)
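For reference, psycopg2's cursor.copy_from() streams a file-like object into a table over the COPY protocol; this hunk only changes the exception raised for malformed column specs from Exception to ValueError. A rough usage sketch (connection parameters, table and column names are illustrative):

import io
import psycopg2

connection = psycopg2.connect(dbname="example")  # placeholder parameters
data = io.StringIO("1\tfoo\n2\tbar\n")           # two tab-separated rows

with connection, connection.cursor() as cursor:
    cursor.copy_from(data, "example_table", sep="\t", columns=("id", "name"))
connection.close()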

def run(self):
@@ -299,52 +299,50 @@ def run(self):
if not (self.table and self.columns):
raise Exception("table and columns need to be specified")

connection = self.output().connect()
# transform all data generated by rows() using map_column and write data
# to a temporary file for import using postgres COPY
tmp_dir = luigi.configuration.get_config().get('postgres', 'local-tmp-dir', None)
tmp_file = tempfile.TemporaryFile(dir=tmp_dir)
n = 0
for row in self.rows():
n += 1
if n % 100000 == 0:
logger.info("Wrote %d lines", n)
rowstr = self.column_separator.join(self.map_column(val) for val in row)
rowstr += "\n"
tmp_file.write(rowstr.encode('utf-8'))

logger.info("Done writing, importing at %s", datetime.datetime.now())
tmp_file.seek(0)

# attempt to copy the data into postgres
# if it fails because the target table doesn't exist
# try to create it by running self.create_table
for attempt in range(2):
try:
cursor = connection.cursor()
self.init_copy(connection)
self.copy(cursor, tmp_file)
self.post_copy(connection)
if self.enable_metadata_columns:
self.post_copy_metacolumns(cursor)
except psycopg2.ProgrammingError as e:
if e.pgcode == psycopg2.errorcodes.UNDEFINED_TABLE and attempt == 0:
# if first attempt fails with "relation not found", try creating table
logger.info("Creating table %s", self.table)
connection.reset()
self.create_table(connection)
else:
raise
else:
break

# mark as complete in same transaction
self.output().touch(connection)

# commit and clean up
connection.commit()
connection.close()
tmp_file.close()
with contextlib.closing(self.output().connect()) as connection:
# transform all data generated by rows() using map_column and
# write data to a temporary file for import using postgres COPY
tmp_dir = luigi.configuration.get_config().get('postgres', 'local-tmp-dir', None)
with tempfile.TemporaryFile(dir=tmp_dir) as tmp_file:
n = 0
for row in self.rows():
n += 1
if n % 100000 == 0:
logger.info("Wrote %d lines", n)
rowstr = self.column_separator.join(self.map_column(val) for val in row)
rowstr += "\n"
tmp_file.write(rowstr.encode('utf-8'))

logger.info(
"Done writing, importing at %s", datetime.datetime.now())
tmp_file.seek(0)

with connection:
# attempt to copy the data into postgres
# if it fails because the target table doesn't exist
# try to create it by running self.create_table
for attempt in range(2):
try:
with connection.cursor() as cursor:
self.init_copy(connection)
self.copy(cursor, tmp_file)
self.post_copy(connection)
if self.enable_metadata_columns:
self.post_copy_metacolumns(cursor)
except psycopg2.ProgrammingError as e:
if e.pgcode == psycopg2.errorcodes.UNDEFINED_TABLE and attempt == 0:
# if first attempt fails with "relation not found", try creating table
logger.info("Creating table %s", self.table)
connection.reset()
self.create_table(connection)
else:
raise
else:
break

# mark as complete in same transaction
self.output().touch(connection)
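Note on the new `with connection:` block: in psycopg2 this manages only the transaction — it commits on a clean exit and rolls back if the body raises — and it does not close the connection, which is why the surrounding contextlib.closing() is still needed. Together with the tempfile context manager, these blocks replace the explicit connection.commit() / connection.close() / tmp_file.close() calls of the old version. A rough sketch of the combined pattern, assuming a psycopg2-style connection (parameters and table are illustrative):

import contextlib
import psycopg2

# Placeholder connection parameters, for illustration only.
with contextlib.closing(psycopg2.connect(dbname="example")) as connection:
    with connection:                            # commit on success, rollback on error
        with connection.cursor() as cursor:     # cursor is closed on exit
            cursor.execute("INSERT INTO example_table VALUES (%s)", (1,))
# Here the transaction has been committed and the connection closed.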


class PostgresQuery(rdbms.Query):
80 changes: 77 additions & 3 deletions test/contrib/postgres_test.py
@@ -26,12 +26,48 @@ def datetime_to_epoch(dt):
return td.days * 86400 + td.seconds + td.microseconds / 1E6


class MockPostgresCursor(mock.Mock):
class MockContextManager(mock.Mock):

def __init__(self, *args, **kwargs):
super(MockContextManager, self).__init__(*args, **kwargs)
self.context_counter = 0
self.all_context_counter = 0

def __enter__(self):
self.context_counter += 1
self.all_context_counter += 1
return self

def __exit__(self, exc_type, exc_value, exc_traceback):
self.context_counter -= 1

def _get_child_mock(self, **kwargs):
"""Child mocks are plain mock.Mock instances, not context managers."""
return mock.Mock(**kwargs)
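The helper above counts __enter__/__exit__ calls so the tests can assert that every `with` block that was entered was also exited (context_counter back to 0) and that at least one was entered at all (all_context_counter > 0). A small usage sketch, assuming the MockContextManager class defined above:

mgr = MockContextManager()

try:
    with mgr:
        raise RuntimeError("boom")        # simulate a failure inside the block
except RuntimeError:
    pass

assert mgr.context_counter == 0           # __exit__ ran despite the error
assert mgr.all_context_counter == 1       # __enter__ ran exactly once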


class MockPostgresConnection(MockContextManager):
def __init__(self, existing_update_ids, *args, **kwargs):
super(MockPostgresConnection, self).__init__(*args, **kwargs)
self.existing = existing_update_ids
self.is_open = False
self.was_open = False

def cursor(self):
self.is_open = True
self.was_open = True
return MockPostgresCursor(existing_update_ids=self.existing)

def close(self):
self.is_open = False


class MockPostgresCursor(MockContextManager):
"""
Keeps state to simulate executing SELECT queries and fetching results.
"""
def __init__(self, existing_update_ids):
super(MockPostgresCursor, self).__init__()
def __init__(self, existing_update_ids, *args, **kwargs):
super(MockPostgresCursor, self).__init__(*args, **kwargs)
self.existing = existing_update_ids

def execute(self, query, params):
@@ -82,6 +118,44 @@ def test_bulk_complete(self, mock_connect):
]))
self.assertFalse(task.complete())

@mock.patch('psycopg2.connect')
@mock.patch("luigi.contrib.postgres.CopyToTable.rows", return_value=['row1', 'row2'])
def test_cleanup_on_error(self, mock_rows, mock_connect):
"""
Test cleanup behavior of CopyToTable in case of an error.

When an error occurs while the connection is open, the connection
should be closed again so that subsequent tasks do not fail because
of an unclosed connection.
"""
task = DummyPostgresImporter(date=datetime.datetime(2021, 4, 15))

mock_connection = MockPostgresConnection([task.task_id])
mock_connect.return_value = mock_connection
mock_cursor = MockPostgresCursor([task.task_id])

original_cursor = mock_connection.cursor

def get_mock_cursor():
original_cursor()
return mock_cursor

mock_connection.cursor = mock.MagicMock(side_effect=get_mock_cursor)

task = DummyPostgresImporter(date=datetime.datetime(2021, 4, 15))
task.columns = [(42,)] # inject defect

with self.assertRaisesRegex(ValueError, "columns"):
task.run()

self.assertEqual(mock_connection.context_counter, 0)
self.assertTrue(mock_connection.all_context_counter)
self.assertFalse(mock_connection.is_open)
self.assertTrue(mock_connection.was_open)

self.assertEqual(mock_cursor.context_counter, 0)
self.assertTrue(mock_cursor.all_context_counter)


class DummyPostgresQuery(luigi.contrib.postgres.PostgresQuery):
date = luigi.DateParameter()