Commit 89c40a0

Use run_graph_locally.py for test_run_job_to_s3. Open-EO/openeo-geotr…
EmileSonneveld committed Nov 6, 2024
1 parent 9bb8a6b commit 89c40a0
Showing 3 changed files with 103 additions and 107 deletions.
103 changes: 78 additions & 25 deletions openeogeotrellis/deploy/local.py
@@ -18,7 +18,14 @@
 _log = logging.getLogger(__name__)
 
 
-def setup_local_spark(additional_jar_dirs=[]):
+def is_port_free(port: int) -> bool:
+    import socket
+
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        return s.connect_ex(("localhost", port)) != 0
+
+
+def setup_local_spark(verbosity=0):
     # TODO: make this more reusable (e.g. also see `_setup_local_spark` in tests/conftest.py)
     from pyspark import SparkContext, find_spark_home
@@ -29,21 +36,44 @@ def setup_local_spark(additional_jar_dirs=[]):
     _log.debug("sys.path: {p!r}".format(p=sys.path))
     master_str = "local[2]"
 
-    OPENEO_LOCAL_DEBUGGING = smart_bool(os.environ.get("OPENEO_LOCAL_DEBUGGING", "false"))
+    if "PYSPARK_PYTHON" not in os.environ:
+        os.environ["PYSPARK_PYTHON"] = sys.executable
 
     from geopyspark import geopyspark_conf
     from pyspark import SparkContext
 
+    # Make sure geopyspark can find the custom jars (e.g. geotrellis-extension)
+    # even if test suite is not run from project root (e.g. "run this test" functionality in an IDE like PyCharm)
+    additional_jar_dirs = [
+        Path(__file__).parent.parent.parent / "jars",
+    ]
+
     conf = geopyspark_conf(
-        master=master_str, appName="openeo-geotrellis-local", additional_jar_dirs=additional_jar_dirs
+        master=master_str,
+        appName="openeo-geopyspark-driver",
+        additional_jar_dirs=additional_jar_dirs,
     )
 
+    spark_jars = conf.get("spark.jars").split(",")
+    # geotrellis-extensions needs to be loaded first to avoid "java.lang.NoClassDefFoundError: shapeless/lazily$"
+    spark_jars.sort(key=lambda x: "geotrellis-extensions" not in x)
+    conf.set(key="spark.jars", value=",".join(spark_jars))
+
     # Use UTC timezone by default when formatting/parsing dates (e.g. CSV export of timeseries)
     conf.set("spark.sql.session.timeZone", "UTC")
 
     conf.set("spark.kryoserializer.buffer.max", value="1G")
-    conf.set(key="spark.kryo.registrator", value="geotrellis.spark.store.kryo.KryoRegistrator")
+    conf.set("spark.kryo.registrator", "geotrellis.spark.store.kryo.KryoRegistrator")
     conf.set(
         key="spark.kryo.classesToRegister",
         value="ar.com.hjg.pngj.ImageInfo,ar.com.hjg.pngj.ImageLineInt,geotrellis.raster.RasterRegion$GridBoundsRasterRegion",
     )
+    # Only show spark progress bars for high verbosity levels
+    conf.set("spark.ui.showConsoleProgress", verbosity >= 3)
 
+    conf.set(key="spark.driver.memory", value="2G")
+    conf.set(key="spark.executor.memory", value="2G")
+    OPENEO_LOCAL_DEBUGGING = smart_bool(os.environ.get("OPENEO_LOCAL_DEBUGGING", "false"))
     conf.set("spark.ui.enabled", OPENEO_LOCAL_DEBUGGING)
 
     jars = []
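
The spark.jars reordering above works because Python's sort is stable and False < True: paths containing "geotrellis-extensions" map to False and therefore move to the front of the classpath. A standalone illustration (hypothetical jar names):

    jars = ["a.jar", "geotrellis-extensions-2.jar", "b.jar"]
    jars.sort(key=lambda x: "geotrellis-extensions" not in x)
    assert jars == ["geotrellis-extensions-2.jar", "a.jar", "b.jar"]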
@@ -72,32 +102,55 @@ def setup_local_spark(additional_jar_dirs=[]):
                 "${sys:openeo.logging.threshold}", "DEBUG"
             )
         )
 
-    # 'agentlib' to allow attaching a Java debugger to running Spark driver
-    extra_options = f"-Dlog4j2.configurationFile=file:{sparkSubmitLog4jConfigurationFile}"
-    extra_options += " -Dgeotrellis.jts.precision.type=fixed -Dgeotrellis.jts.simplification.scale=1e10"
+    # Some options to allow attaching a Java debugger to running Spark driver
+    # got some options from 'sparkDriverJavaOptions'
+    sparkDriverJavaOptions = f"-Dlog4j2.configurationFile=file:{sparkSubmitLog4jConfigurationFile}\
+     -Dscala.concurrent.context.numThreads=6 \
+     -Dsoftware.amazon.awssdk.http.service.impl=software.amazon.awssdk.http.urlconnection.UrlConnectionSdkHttpService\
+     -Dtsservice.layersConfigClass=ProdLayersConfiguration -Dtsservice.sparktasktimeout=600"
+    sparkDriverJavaOptions += " -Dgeotrellis.jts.precision.type=fixed -Dgeotrellis.jts.simplification.scale=1e10"
     if OPENEO_LOCAL_DEBUGGING:
-        extra_options += f" -agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=5009"
-    conf.set("spark.driver.extraJavaOptions", extra_options)
-    # conf.set('spark.executor.extraJavaOptions', extra_options)  # Seems not needed
 
-    conf.set(key="spark.driver.memory", value="2G")
-    conf.set(key="spark.executor.memory", value="2G")
+        for port in [5005, 5009]:
+            if is_port_free(port):
+                # 'agentlib' to allow attaching a Java debugger to running Spark driver
+                # IntelliJ IDEA: Run -> Edit Configurations -> Remote JVM Debug uses 5005 by default
+                sparkDriverJavaOptions += f" -agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=*:{port}"
+                break
+    conf.set("spark.driver.extraJavaOptions", sparkDriverJavaOptions)
+
+    sparkExecutorJavaOptions = f"-Dlog4j2.configurationFile=file:{sparkSubmitLog4jConfigurationFile}\
+     -Dsoftware.amazon.awssdk.http.service.impl=software.amazon.awssdk.http.urlconnection.UrlConnectionSdkHttpService\
+     -Dscala.concurrent.context.numThreads=8"
+    conf.set("spark.executor.extraJavaOptions", sparkExecutorJavaOptions)
+
+    _log.info("[conftest.py] SparkContext.getOrCreate with {c!r}".format(c=conf.getAll()))
+    context = SparkContext.getOrCreate(conf)
+    context.setLogLevel("INFO")
+    _log.info(
+        "[conftest.py] JVM info: {d!r}".format(
+            d={
+                f: context._jvm.System.getProperty(f)
+                for f in [
+                    "java.version",
+                    "java.vendor",
+                    "java.home",
+                    "java.class.version",
+                    # "java.class.path",
+                ]
+            }
+        )
+    )
 
-    if "PYSPARK_PYTHON" not in os.environ:
-        os.environ["PYSPARK_PYTHON"] = sys.executable
+    if OPENEO_LOCAL_DEBUGGING:
+        # TODO: Activate default logging for this message
+        print("Spark web UI: " + str(context.uiWebUrl))
 
-    _log.info("Creating Spark context with config:")
-    for k, v in conf.getAll():
-        _log.info("Spark config: {k!r}: {v!r}".format(k=k, v=v))
-    pysc = SparkContext.getOrCreate(conf)
-    pysc.setLogLevel("INFO")
-    _log.info("Created Spark Context {s}".format(s=pysc))
-    if OPENEO_LOCAL_DEBUGGING:
-        _log.info("Spark web UI: http://localhost:{p}/".format(p=pysc.getConf().get("spark.ui.port") or 4040))
+    _log.info("[conftest.py] Validating the Spark context")
+    dummy = context._jvm.org.openeo.geotrellis.OpenEOProcesses()
+    answer = context.parallelize([9, 10, 11, 12]).sum()
+    _log.info("[conftest.py] " + repr((answer, dummy)))
 
-    return pysc
+    return context
 
 
 def on_started() -> None:
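
For context: after this change, the debugger hook and the Spark web UI are both gated on the OPENEO_LOCAL_DEBUGGING environment variable (parsed with smart_bool). A hedged usage sketch, assuming the openeogeotrellis package and its jars are installed locally:

    import os

    os.environ["OPENEO_LOCAL_DEBUGGING"] = "true"  # read inside setup_local_spark()

    from openeogeotrellis.deploy.local import setup_local_spark

    context = setup_local_spark()
    # With debugging enabled, a JDWP agent listens on 5005 or 5009 (whichever
    # was free), matching IntelliJ IDEA's default "Remote JVM Debug" port 5005.
    print(context.uiWebUrl)  # Spark web UI, since spark.ui.enabled is now true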
42 changes: 0 additions & 42 deletions tests/conftest.py
@@ -212,48 +212,6 @@ def _setup_local_spark(out: TerminalReporter, verbosity=0):
     return context
 
 
-# noinspection PyProtectedMember
-def restart_spark_context():
-    from pyspark import SparkContext
-
-    with SparkContext._lock:
-        # Need to shut down before creating a new SparkConf (Before SparkContext is not enough)
-        # Like this, the new environment variables are available inside the JVM
-        if SparkContext._active_spark_context:
-            SparkContext._active_spark_context.stop()
-        SparkContext._gateway.shutdown()
-        SparkContext._gateway = None
-        SparkContext._jvm = None
-
-    class TerminalReporterMock:
-        @staticmethod
-        def write_line(message):
-            print(message)
-
-    # noinspection PyTypeChecker
-    _setup_local_spark(TerminalReporterMock(), 0)
-
-
-@pytest.fixture
-def custom_spark_context_restart_instant():
-    """
-    Add this fixture at the end of your argument list.
-    The restarted JVM will pick up your environment variables
-    https://docs.pytest.org/en/6.2.x/fixture.html#yield-fixtures-recommended
-    """
-    restart_spark_context()
-
-
-@pytest.fixture
-def custom_spark_context_restart_delayed():
-    """
-    Add this fixture at the beginning of your argument list.
-    The JVM will be restarted when all mocking is cleaned up.
-    """
-    yield "Spark context is globally accesible now"
-    restart_spark_context()
-
-
 @pytest.fixture(params=["1.0.0"])
 def api_version(request):
     return request.param
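
These fixtures existed only to restart the JVM so that freshly set environment variables would become visible to Spark; the commit replaces that trick with running the job in a separate subprocess (see the test change below). A minimal sketch of the replacement pattern, assuming run_graph_locally.py is on the PATH and a hypothetical process_graph.json exists:

    import os
    import subprocess

    env = {**os.environ, "KUBE": "TRUE"}  # visible to the JVM from the moment it starts
    subprocess.check_call(["run_graph_locally.py", "process_graph.json"], env=env)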
65 changes: 25 additions & 40 deletions tests/deploy/test_batch_job.py
@@ -2,6 +2,7 @@
 import logging
 import re
 import shutil
+import subprocess
 import tempfile
 import textwrap
 import zipfile
@@ -13,7 +14,7 @@
 from openeo_driver.delayed_vector import DelayedVector
 from openeo_driver.dry_run import DryRunDataTracer
 from openeo_driver.save_result import ImageCollectionResult
-from openeo_driver.testing import DictSubSet
+from openeo_driver.testing import DictSubSet, ListSubSet
 from openeo_driver.utils import read_json
 from osgeo import gdal
 from pytest import approx
@@ -1301,69 +1302,53 @@ def test_run_job_get_projection_extension_metadata_assets_in_s3_multiple_assets(
     )
 
 
-@mock.patch(
-    "openeogeotrellis.configparams.ConfigParams.is_kube_deploy",
-    new_callable=mock.PropertyMock,
-)
 def test_run_job_to_s3(
-    mock_config_is_kube_deploy,
-    custom_spark_context_restart_delayed,
     tmp_path,
     mock_s3_bucket,
     moto_server,
-    custom_spark_context_restart_instant,
     monkeypatch,
 ):
-    mock_config_is_kube_deploy.return_value = True
+    monkeypatch.setenv("KUBE", "TRUE")
+    spatial_extent_tap = {
+        "east": 5.08,
+        "north": 51.22,
+        "south": 51.215,
+        "west": 5.07,
+    }
     process_graph = {
         "lc": {
             "process_id": "load_collection",
             "arguments": {
-                "id": "TestCollection-LonLat4x4",
+                "id": "TestCollection-LonLat16x16",
                 "temporal_extent": ["2021-01-01", "2021-01-10"],
-                "spatial_extent": {
-                    "east": 5.08,
-                    "north": 51.22,
-                    "south": 51.215,
-                    "west": 5.07,
-                },
+                "spatial_extent": spatial_extent_tap,
                 "bands": ["Longitude", "Latitude", "Day"],
             },
         },
+        "resamplespatial1": {
+            "process_id": "resample_spatial",
+            "arguments": {
+                "align": "upper-left",
+                "data": {"from_node": "lc"},
+                "method": "bilinear",
+                "projection": 4326,
+                "resolution": 0.000297619047619,
+            },
+        },
         "save": {
             "process_id": "save_result",
             "arguments": {"data": {"from_node": "lc"}, "format": "GTiff"},
             "result": True,
         },
     }
 
-    run_job(
-        job_specification={
-            "process_graph": process_graph,
-        },
-        output_file=tmp_path / "out",
-        metadata_file=tmp_path / "metadata.json",
-        api_version="2.0.0",
-        job_dir=tmp_path,
-        dependencies=[],
-        user_id="jenkins",
-    )
+    json_path = tmp_path / "process_graph.json"
+    json.dump(process_graph, json_path.open("wt"))
+
+    cmd = ["run_graph_locally.py", json_path]
+    try:
+        # Run in separate subprocess so that all environment variables are
+        # set correctly at the moment the SparkContext is created:
+        output = subprocess.check_output(cmd, stderr=subprocess.STDOUT, universal_newlines=True)
+    except subprocess.CalledProcessError as e:
+        error, output, returncode = e, e.output, e.returncode
+    print(output)
 
     s3_instance = s3_client()
     from openeogeotrellis.config import get_backend_config
 
     files = {o["Key"] for o in s3_instance.list_objects(Bucket=get_backend_config().s3_bucket_name)["Contents"]}
-    files = {f[len(str(tmp_path)) :] for f in files}
-    assert files == {"collection.json", "metadata.json", "openEO_2021-01-05Z.tif", "openEO_2021-01-05Z.tif.json"}
+    files = [f[len(str(tmp_path)) :] for f in files]
+    assert files == ListSubSet(["collection.json", "openEO_2021-01-05Z.tif", "openEO_2021-01-05Z.tif.json"])
 
 
 # TODO: Update this test to include statistics or not? Would need to update the json file.
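
On the relaxed assertion: metadata.json is no longer required, and ListSubSet (imported above from openeo_driver.testing) appears to match when all listed items are present while tolerating extra entries, by analogy with DictSubSet (an assumption about its semantics, not verified here). Roughly:

    from openeo_driver.testing import ListSubSet

    # Passes even though "extra.log" is not in the expected subset (assumed semantics).
    assert ["collection.json", "extra.log", "openEO.tif"] == ListSubSet(["collection.json", "openEO.tif"])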
