From 879a002bb0b7652f9588e2849cb417840e23e374 Mon Sep 17 00:00:00 2001 From: Allison Suarez Miranda <22477579+allisonsuarez@users.noreply.github.com> Date: Wed, 27 Jul 2022 19:59:24 -0700 Subject: [PATCH] fix: driver object pickle error (#1944) * fix: driver object pickle error Signed-off-by: Allison Suarez Miranda * always used conf w fallback on neo4j extractor Signed-off-by: Allison Suarez Miranda --- .../databuilder/extractor/neo4j_extractor.py | 63 +++++++++---------- .../publisher/neo4j_csv_publisher.py | 54 +++++++--------- .../task/neo4j_staleness_removal_task.py | 53 +++++++--------- databuilder/setup.py | 2 +- 4 files changed, 79 insertions(+), 93 deletions(-) diff --git a/databuilder/databuilder/extractor/neo4j_extractor.py b/databuilder/databuilder/extractor/neo4j_extractor.py index 16b8551e88..4dc5553234 100644 --- a/databuilder/databuilder/extractor/neo4j_extractor.py +++ b/databuilder/databuilder/extractor/neo4j_extractor.py @@ -35,7 +35,6 @@ class Neo4jExtractor(Extractor): """NEO4J_ENCRYPTED is a boolean indicating whether to use SSL/TLS when connecting.""" NEO4J_VALIDATE_SSL = 'neo4j_validate_ssl' """NEO4J_VALIDATE_SSL is a boolean indicating whether to validate the server's SSL/TLS cert against system CAs.""" - NEO4J_DRIVER = 'neo4j_driver' DEFAULT_CONFIG = ConfigFactory.from_dict({ NEO4J_MAX_CONN_LIFE_TIME_SEC: 50, @@ -48,41 +47,39 @@ def init(self, conf: ConfigTree) -> None: :param conf: """ self.conf = conf.with_fallback(Neo4jExtractor.DEFAULT_CONFIG) - self.graph_url = conf.get_string(Neo4jExtractor.GRAPH_URL_CONFIG_KEY) - self.cypher_query = conf.get_string(Neo4jExtractor.CYPHER_QUERY_CONFIG_KEY) + self.graph_url = self.conf.get_string(Neo4jExtractor.GRAPH_URL_CONFIG_KEY) + self.cypher_query = self.conf.get_string(Neo4jExtractor.CYPHER_QUERY_CONFIG_KEY) self.db_name = self.conf.get_string(Neo4jExtractor.NEO4J_DATABASE_NAME) - driver = conf.get(Neo4jExtractor.NEO4J_DRIVER, None) - if driver: - self.driver = driver - else: - uri = conf.get_string(Neo4jExtractor.GRAPH_URL_CONFIG_KEY) - driver_args = { - 'uri': uri, - 'max_connection_lifetime': self.conf.get_int(Neo4jExtractor.NEO4J_MAX_CONN_LIFE_TIME_SEC), - 'auth': (conf.get_string(Neo4jExtractor.NEO4J_AUTH_USER), - conf.get_string(Neo4jExtractor.NEO4J_AUTH_PW)), - } - - # if URI scheme not secure set `trust`` and `encrypted` to default values - # https://neo4j.com/docs/api/python-driver/current/api.html#uri - _, security_type, _ = parse_neo4j_uri(uri=uri) - if security_type not in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE]: - default_security_conf = {'trust': neo4j.TRUST_ALL_CERTIFICATES, 'encrypted': True} - driver_args.update(default_security_conf) - - # if NEO4J_VALIDATE_SSL or NEO4J_ENCRYPTED are set in config pass them to the driver - validate_ssl_conf = conf.get(Neo4jExtractor.NEO4J_VALIDATE_SSL, None) - encrypted_conf = conf.get(Neo4jExtractor.NEO4J_ENCRYPTED, None) - if validate_ssl_conf is not None: - driver_args['trust'] = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES if validate_ssl_conf \ - else neo4j.TRUST_ALL_CERTIFICATES - if encrypted_conf is not None: - driver_args['encrypted'] = encrypted_conf - - self.driver = GraphDatabase.driver(**driver_args) + + uri = self.conf.get_string(Neo4jExtractor.GRAPH_URL_CONFIG_KEY) + driver_args = { + 'uri': uri, + 'max_connection_lifetime': self.conf.get_int(Neo4jExtractor.NEO4J_MAX_CONN_LIFE_TIME_SEC), + 'auth': (self.conf.get_string(Neo4jExtractor.NEO4J_AUTH_USER), + self.conf.get_string(Neo4jExtractor.NEO4J_AUTH_PW)), + } + + # if URI scheme not secure set `trust`` and `encrypted` to default values + # https://neo4j.com/docs/api/python-driver/current/api.html#uri + _, security_type, _ = parse_neo4j_uri(uri=uri) + if security_type not in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE]: + default_security_conf = {'trust': neo4j.TRUST_ALL_CERTIFICATES, 'encrypted': True} + driver_args.update(default_security_conf) + + # if NEO4J_VALIDATE_SSL or NEO4J_ENCRYPTED are set in config pass them to the driver + validate_ssl_conf = self.conf.get(Neo4jExtractor.NEO4J_VALIDATE_SSL, None) + encrypted_conf = self.conf.get(Neo4jExtractor.NEO4J_ENCRYPTED, None) + if validate_ssl_conf is not None: + driver_args['trust'] = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES if validate_ssl_conf \ + else neo4j.TRUST_ALL_CERTIFICATES + if encrypted_conf is not None: + driver_args['encrypted'] = encrypted_conf + + self.driver = GraphDatabase.driver(**driver_args) + self._extract_iter: Union[None, Iterator] = None - model_class = conf.get(Neo4jExtractor.MODEL_CLASS_CONFIG_KEY, None) + model_class = self.conf.get(Neo4jExtractor.MODEL_CLASS_CONFIG_KEY, None) if model_class: module_name, class_name = model_class.rsplit(".", 1) mod = importlib.import_module(module_name) diff --git a/databuilder/databuilder/publisher/neo4j_csv_publisher.py b/databuilder/databuilder/publisher/neo4j_csv_publisher.py index 8fb8e84155..b1d07449ef 100644 --- a/databuilder/databuilder/publisher/neo4j_csv_publisher.py +++ b/databuilder/databuilder/publisher/neo4j_csv_publisher.py @@ -57,8 +57,6 @@ # in Neo4j (v4.0+), we can create and use more than one active database at the same time NEO4J_DATABASE_NAME = 'neo4j_database' -NEO4J_DRIVER = 'neo4j_driver' - # NEO4J_ENCRYPTED is a boolean indicating whether to use SSL/TLS when connecting NEO4J_ENCRYPTED = 'neo4j_encrypted' # NEO4J_VALIDATE_SSL is a boolean indicating whether to validate the server's SSL/TLS @@ -154,34 +152,30 @@ def init(self, conf: ConfigTree) -> None: self._relation_files = self._list_files(conf, RELATION_FILES_DIR) self._relation_files_iter = iter(self._relation_files) - driver = conf.get(NEO4J_DRIVER, None) - if driver: - self._driver = driver - else: - uri = conf.get_string(NEO4J_END_POINT_KEY) - driver_args = { - 'uri': uri, - 'max_connection_lifetime': conf.get_int(NEO4J_MAX_CONN_LIFE_TIME_SEC), - 'auth': (conf.get_string(NEO4J_USER), conf.get_string(NEO4J_PASSWORD)), - } - - # if URI scheme not secure set `trust`` and `encrypted` to default values - # https://neo4j.com/docs/api/python-driver/current/api.html#uri - _, security_type, _ = parse_neo4j_uri(uri=uri) - if security_type not in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE]: - default_security_conf = {'trust': neo4j.TRUST_ALL_CERTIFICATES, 'encrypted': True} - driver_args.update(default_security_conf) - - # if NEO4J_VALIDATE_SSL or NEO4J_ENCRYPTED are set in config pass them to the driver - validate_ssl_conf = conf.get(NEO4J_VALIDATE_SSL, None) - encrypted_conf = conf.get(NEO4J_ENCRYPTED, None) - if validate_ssl_conf is not None: - driver_args['trust'] = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES if validate_ssl_conf \ - else neo4j.TRUST_ALL_CERTIFICATES - if encrypted_conf is not None: - driver_args['encrypted'] = encrypted_conf - - self._driver = GraphDatabase.driver(**driver_args) + uri = conf.get_string(NEO4J_END_POINT_KEY) + driver_args = { + 'uri': uri, + 'max_connection_lifetime': conf.get_int(NEO4J_MAX_CONN_LIFE_TIME_SEC), + 'auth': (conf.get_string(NEO4J_USER), conf.get_string(NEO4J_PASSWORD)), + } + + # if URI scheme not secure set `trust`` and `encrypted` to default values + # https://neo4j.com/docs/api/python-driver/current/api.html#uri + _, security_type, _ = parse_neo4j_uri(uri=uri) + if security_type not in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE]: + default_security_conf = {'trust': neo4j.TRUST_ALL_CERTIFICATES, 'encrypted': True} + driver_args.update(default_security_conf) + + # if NEO4J_VALIDATE_SSL or NEO4J_ENCRYPTED are set in config pass them to the driver + validate_ssl_conf = conf.get(NEO4J_VALIDATE_SSL, None) + encrypted_conf = conf.get(NEO4J_ENCRYPTED, None) + if validate_ssl_conf is not None: + driver_args['trust'] = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES if validate_ssl_conf \ + else neo4j.TRUST_ALL_CERTIFICATES + if encrypted_conf is not None: + driver_args['encrypted'] = encrypted_conf + + self._driver = GraphDatabase.driver(**driver_args) self._db_name = conf.get_string(NEO4J_DATABASE_NAME) self._session = self._driver.session(database=self._db_name) diff --git a/databuilder/databuilder/task/neo4j_staleness_removal_task.py b/databuilder/databuilder/task/neo4j_staleness_removal_task.py index cbaae87333..4f54f2d70a 100644 --- a/databuilder/databuilder/task/neo4j_staleness_removal_task.py +++ b/databuilder/databuilder/task/neo4j_staleness_removal_task.py @@ -26,7 +26,6 @@ NEO4J_PASSWORD = 'neo4j_password' # in Neo4j (v4.0+), we can create and use more than one active database at the same time NEO4J_DATABASE_NAME = 'neo4j_database' -NEO4J_DRIVER = 'neo4j_driver' NEO4J_ENCRYPTED = 'neo4j_encrypted' """NEO4J_ENCRYPTED is a boolean indicating whether to use SSL/TLS when connecting.""" NEO4J_VALIDATE_SSL = 'neo4j_validate_ssl' @@ -131,34 +130,30 @@ def init(self, conf: ConfigTree) -> None: else: self.marker = conf.get_string(JOB_PUBLISH_TAG) - driver = conf.get(NEO4J_DRIVER, None) - if driver: - self._driver = driver - else: - uri = conf.get_string(NEO4J_END_POINT_KEY) - driver_args = { - 'uri': uri, - 'max_connection_lifetime': conf.get_int(NEO4J_MAX_CONN_LIFE_TIME_SEC), - 'auth': (conf.get_string(NEO4J_USER), conf.get_string(NEO4J_PASSWORD)), - } - - # if URI scheme not secure set `trust`` and `encrypted` to default values - # https://neo4j.com/docs/api/python-driver/current/api.html#uri - _, security_type, _ = parse_neo4j_uri(uri=uri) - if security_type not in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE]: - default_security_conf = {'trust': neo4j.TRUST_ALL_CERTIFICATES, 'encrypted': True} - driver_args.update(default_security_conf) - - # if NEO4J_VALIDATE_SSL or NEO4J_ENCRYPTED are set in config pass them to the driver - validate_ssl_conf = conf.get(NEO4J_VALIDATE_SSL, None) - encrypted_conf = conf.get(NEO4J_ENCRYPTED, None) - if validate_ssl_conf is not None: - driver_args['trust'] = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES if validate_ssl_conf \ - else neo4j.TRUST_ALL_CERTIFICATES - if encrypted_conf is not None: - driver_args['encrypted'] = encrypted_conf - - self._driver = GraphDatabase.driver(**driver_args) + uri = conf.get_string(NEO4J_END_POINT_KEY) + driver_args = { + 'uri': uri, + 'max_connection_lifetime': conf.get_int(NEO4J_MAX_CONN_LIFE_TIME_SEC), + 'auth': (conf.get_string(NEO4J_USER), conf.get_string(NEO4J_PASSWORD)), + } + + # if URI scheme not secure set `trust`` and `encrypted` to default values + # https://neo4j.com/docs/api/python-driver/current/api.html#uri + _, security_type, _ = parse_neo4j_uri(uri=uri) + if security_type not in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE]: + default_security_conf = {'trust': neo4j.TRUST_ALL_CERTIFICATES, 'encrypted': True} + driver_args.update(default_security_conf) + + # if NEO4J_VALIDATE_SSL or NEO4J_ENCRYPTED are set in config pass them to the driver + validate_ssl_conf = conf.get(NEO4J_VALIDATE_SSL, None) + encrypted_conf = conf.get(NEO4J_ENCRYPTED, None) + if validate_ssl_conf is not None: + driver_args['trust'] = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES if validate_ssl_conf \ + else neo4j.TRUST_ALL_CERTIFICATES + if encrypted_conf is not None: + driver_args['encrypted'] = encrypted_conf + + self._driver = GraphDatabase.driver(**driver_args) self.db_name = conf.get(NEO4J_DATABASE_NAME) diff --git a/databuilder/setup.py b/databuilder/setup.py index 3d86c2f7bd..c7df053ef8 100644 --- a/databuilder/setup.py +++ b/databuilder/setup.py @@ -5,7 +5,7 @@ from setuptools import find_packages, setup -__version__ = '7.1.0' +__version__ = '7.1.1' requirements_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'requirements.txt')