Process unique directories for concurrent work
calum-chamberlain committed Oct 24, 2023
1 parent 181d2ed commit a1e22dc
Showing 3 changed files with 39 additions and 19 deletions.
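
The core of this commit is a single pattern: every temporary directory and chunk file that a detection run writes now carries the process ID in its name, so two runs started in the same working directory cannot read, overwrite or delete each other's files. A minimal sketch of that pattern follows; the helper names (write_chunk, cleanup) are illustrative only and not part of the EQcorrscan API.

import os
import pickle
import shutil

def _stream_dir() -> str:
    # Per-process temporary directory, e.g. ".streams_12345"
    return f".streams_{os.getpid()}"

def write_chunk(chunk, starttime_str: str) -> str:
    # Write a chunk under a PID-suffixed path so concurrent runs cannot collide
    os.makedirs(_stream_dir(), exist_ok=True)
    chunk_file = os.path.join(
        _stream_dir(), f"chunk_{starttime_str}_{os.getpid()}.pkl")
    with open(chunk_file, "wb") as f:
        pickle.dump(chunk, f)
    return chunk_file

def cleanup():
    # Remove only this process's directory; other runs are left alone
    if os.path.isdir(_stream_dir()):
        shutil.rmtree(_stream_dir())
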
9 changes: 8 additions & 1 deletion conftest.py
@@ -1,5 +1,6 @@
import os
import shutil
import glob
from os.path import join, dirname

import pytest
@@ -58,7 +59,8 @@ def clean_up_test_files():
'dt.cc2',
'dt.ct',
'dt.ct2',
'phase.dat'
'phase.dat',
'eqcorrscan_temporary_party.pkl'
]

yield
@@ -85,14 +87,19 @@ def clean_up_test_directories():
'test_tar_write',
'tmp1',
'cc_exported',
'.streams',
'.parties'
]
directories_to_kill.extend(glob.glob(".template_db_*"))
directories_to_kill.extend(glob.glob(".streams_*"))

yield

# remove files
for directory in directories_to_kill:
if os.path.isdir(directory):
try:
print(f"Removing directory {directory}")
shutil.rmtree(directory)
except Exception as e:
print("Could not find directory, already cleaned?")
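
Because the directory names now embed a PID, the teardown fixture cannot list them up front and instead matches the per-process prefixes with glob. A stripped-down version of that cleanup, written here only to show the idea, would be:

import glob
import os
import shutil

# The fixture cannot know which PIDs ran, so it matches the
# per-process prefixes rather than a fixed list of names.
for directory in glob.glob(".streams_*") + glob.glob(".template_db_*"):
    if os.path.isdir(directory):
        shutil.rmtree(directory)
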
41 changes: 31 additions & 10 deletions eqcorrscan/core/match_filter/tribe.py
@@ -234,6 +234,11 @@ def __unique_ids(self, other=None):
f"{', '.join(non_unique_names)}")
return

@property
def _stream_dir(self):
""" Location for temporary streams """
return f".streams_{os.getpid()}"

def sort(self):
"""
Sort the tribe, sorts by template name.
@@ -877,6 +882,8 @@ def _detect_serial(
Logger.info(f"Added party from {_chunk_file}, party now "
f"contains {len(party)} detections")

if os.path.isdir(self._stream_dir):
shutil.rmtree(self._stream_dir)
return party

def _detect_concurrent(
@@ -928,6 +935,7 @@ def _detect_concurrent(
target=_pre_processor,
kwargs=dict(
stream_queue=stream,
temp_stream_dir=self._stream_dir,
template_ids=template_ids,
pre_processed=pre_processed,
filt_order=self.templates[0].filt_order,
@@ -1087,6 +1095,8 @@ def _detect_concurrent(
# Clean the template db
if os.path.isdir(template_dir):
shutil.rmtree(template_dir)
if os.path.isdir(self._stream_dir):
shutil.rmtree(self._stream_dir)
self._on_error(internal_error)

# Shut down the processes and close the queues
@@ -1098,6 +1108,8 @@ def _detect_concurrent(
self._close_processes()
if os.path.isdir(template_dir):
shutil.rmtree(template_dir)
if os.path.isdir(self._stream_dir):
shutil.rmtree(self._stream_dir)
return party

def client_detect(self, client, starttime, endtime, threshold,
@@ -1343,6 +1355,7 @@ def client_detect(self, client, starttime, endtime, threshold,
buff=buff,
out_queue=stream_queue,
poison_queue=poison_queue,
temp_stream_dir=self._stream_dir,
full_stream_dir=full_stream_dir,
pre_process=True, parallel_process=parallel_process,
process_cores=process_cores, daylong=daylong,
@@ -1994,7 +2007,9 @@ def _make_party(
chunk_file_str = os.path.join(
chunk_dir,
"chunk_party_{chunk_start_str}"
"_{chunk_id}.pkl")
"_{chunk_id}_{pid}.pkl")
# Process ID included in chunk file to avoid multiple processes writing
# and reading and removing the same files.

# Get the results out of the end!
Logger.info(f"Made {len(detections)} detections")
@@ -2045,7 +2060,7 @@ def _make_party(
chunk_file = chunk_file_str.format(
chunk_start_str=chunk_start.strftime("%Y-%m-%dT%H-%M-%S"),
chunk_start=chunk_start,
chunk_id=chunk_id)
chunk_id=chunk_id, pid=os.getpid())
with open(chunk_file, "wb") as _f:
pickle.dump(chunk_party, _f)
Logger.info("Completed party processing")
@@ -2116,6 +2131,7 @@ def _get_detection_stream(
buff: float,
out_queue: Queue,
poison_queue: Queue,
temp_stream_dir: str,
full_stream_dir: str = None,
pre_process: bool = False,
parallel_process: bool = True,
@@ -2193,12 +2209,14 @@
Logger.info(f"After processing stream has {len(chunk)} traces:")
for tr in chunk:
Logger.info(tr)
if not os.path.isdir(".streams"):
os.makedirs(".streams")
if not os.path.isdir(temp_stream_dir):
os.makedirs(temp_stream_dir)
chunk_file = os.path.join(
".streams",
temp_stream_dir,
f"chunk_{len(chunk)}_"
f"{chunk[0].stats.starttime.strftime('%Y-%m-%dT%H-%M-%S')}.pkl")
f"{chunk[0].stats.starttime.strftime('%Y-%m-%dT%H-%M-%S')}"
f"_{os.getpid()}.pkl")
# Add PID to cope with multiple instances operating at once
_pickle_stream(chunk, chunk_file)
out_queue.put(chunk_file)
del chunk
@@ -2212,6 +2230,7 @@

def _pre_processor(
stream_queue: Queue,
temp_stream_dir: str,
template_ids: set,
pre_processed: bool,
filt_order: int,
@@ -2251,13 +2270,15 @@ def _pre_processor(
samp_rate, process_length, parallel, cores, daylong,
ignore_length, ignore_bad_data, overlap)
for chunk in st_chunks:
if not os.path.isdir(".streams"):
os.makedirs(".streams")
if not os.path.isdir(temp_stream_dir):
os.makedirs(temp_stream_dir)
chunk_files = []
chunk_file = os.path.join(
".streams",
temp_stream_dir,
f"chunk_{len(chunk)}_"
f"{chunk[0].stats.starttime.strftime('%Y-%m-%dT%H-%M-%S')}.pkl")
f"{chunk[0].stats.starttime.strftime('%Y-%m-%dT%H-%M-%S')}"
f"_{os.getpid()}.pkl")
# Add PID to cope with multiple instances operating at once
_pickle_stream(chunk, chunk_file)
output_queue.put(chunk_file)
del chunk
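
As a quick sanity check of the approach (not part of this commit), two worker processes launched from the same working directory always derive different stream directories, because their PIDs differ:

import os
from multiprocessing import Process, Queue

def report_stream_dir(queue: Queue):
    # Each worker derives its temporary directory from its own PID
    queue.put(f".streams_{os.getpid()}")

if __name__ == "__main__":
    queue = Queue()
    workers = [Process(target=report_stream_dir, args=(queue,))
               for _ in range(2)]
    for worker in workers:
        worker.start()
    stream_dirs = {queue.get() for _ in workers}  # drain before joining
    for worker in workers:
        worker.join()
    print(stream_dirs)  # two distinct directories, one per process
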
8 changes: 0 additions & 8 deletions eqcorrscan/tests/match_filter_test.py
@@ -726,12 +726,6 @@ def setUpClass(cls):
cls.unproc_st, cls.tribe, cls.onehztribe, cls.st, cls.party = (
unproc_st, tribe, onehztribe, st, party)

@classmethod
def tearDownClass(cls):
for f in ['eqcorrscan_temporary_party.pkl']:
if os.path.isfile(f):
os.remove(f)

def test_tribe_detect(self):
"""Test the detect method on Tribe objects"""
for conc_proc in [True, False]:
@@ -821,7 +815,6 @@ def test_tribe_detect_save_progress(self):
for pf in party_files:
saved_party += Party().read(pf)
self.assertEqual(party, saved_party)
shutil.rmtree(".parties")

@pytest.mark.serial
def test_tribe_detect_masked_data(self):
@@ -890,7 +883,6 @@ def test_client_detect_save_progress(self):
for pf in party_files:
saved_party += Party().read(pf)
self.assertEqual(party, saved_party)
shutil.rmtree(".parties")
compare_families(
party=party, party_in=self.party, float_tol=0.05,
check_event=False)
