-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
anikaweinmann
committed
Jan 31, 2022
1 parent
5a38a04
commit 7075f4f
Showing
10 changed files
with
631 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
[flake8] | ||
|
||
# E501 line too long | ||
# F821 undefined name '_' | ||
|
||
exclude = .git | ||
max-line-length = 88 | ||
per-file-ignores = | ||
./r.learn.parallel.predict/r.learn.parallel.predict.py: F821, E501 | ||
./r.learn.predict.worker/r.learn.predict.worker.py: F821, E501 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
name: Python Flake8 Code Quality | ||
|
||
on: | ||
- push | ||
- pull_request | ||
|
||
jobs: | ||
flake8: | ||
runs-on: ubuntu-20.04 | ||
|
||
steps: | ||
- uses: actions/checkout@v2 | ||
|
||
- name: Set up Python | ||
uses: actions/setup-python@v2 | ||
with: | ||
python-version: 3.8 | ||
|
||
- name: Install | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install flake8==3.8.4 | ||
- name: Run Flake8 | ||
run: | | ||
flake8 --config=.flake8 --count --statistics --show-source --jobs=$(nproc) . |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MODULE_TOPDIR = ../.. | ||
|
||
PGM = r.learn.predict.parallel | ||
|
||
# note: to deactivate a module, just place a file "DEPRECATED" in the subdir | ||
ALL_SUBDIRS := ${sort ${dir ${wildcard */.}}} | ||
DEPRECATED_SUBDIRS := ${sort ${dir ${wildcard */DEPRECATED}}} | ||
RM_SUBDIRS := bin/ docs/ etc/ scripts/ | ||
SUBDIRS_1 := $(filter-out $(DEPRECATED_SUBDIRS), $(ALL_SUBDIRS)) | ||
SUBDIRS := $(filter-out $(RM_SUBDIRS), $(SUBDIRS_1)) | ||
|
||
# $(warning ALL_SUBDIRS is $(ALL_SUBDIRS)) | ||
# $(warning DEPRECATED_SUBDIRS is $(DEPRECATED_SUBDIRS)) | ||
# $(warning SUBDIRS is $(SUBDIRS)) | ||
|
||
include $(MODULE_TOPDIR)/include/Make/Dir.make | ||
|
||
default: parsubdirs htmldir | ||
|
||
install: installsubdirs | ||
$(INSTALL_DATA) $(PGM).html $(INST_DIR)/docs/html/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
MODULE_TOPDIR = ../.. | ||
|
||
PGM = r.learn.parallel.predict | ||
|
||
include $(MODULE_TOPDIR)/include/Make/Script.make | ||
|
||
default: script |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
<h2>DESCRIPTION</h2> | ||
|
||
<em>r.learn.parallel.predict</em> applies the classification model in | ||
parallel using <em>r.learn.predict</em>. | ||
|
||
<h2>SEE ALSO</h2> | ||
|
||
<em> | ||
<a href="r.learn.predict.html">r.learn.predict</a> | ||
</em> | ||
|
||
<h2>AUTHOR</h2> | ||
|
||
Anika Weinmann, <a href="https://www.mundialis.de/">mundialis GmbH & Co. KG</a> | ||
|
||
<!-- | ||
<p> | ||
<i>Last changed: $Date$</i> | ||
--> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,273 @@ | ||
#!/usr/bin/env python3 | ||
############################################################################ | ||
# | ||
# MODULE: r.learn.parallel.predict | ||
# AUTHOR(S): Anika Weinmann | ||
# PURPOSE: Applies the classification model parallel using r.learn.predict | ||
# COPYRIGHT: (C) 2020-2022 by mundialis GmbH & Co. KG and the GRASS | ||
# Development Team | ||
# | ||
# This program is free software; you can redistribute it and/or modify | ||
# it under the terms of the GNU General Public License as published by | ||
# the Free Software Foundation; either version 2 of the License, or | ||
# (at your option) any later version. | ||
# | ||
# This program is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# GNU General Public License for more details. | ||
# | ||
############################################################################ | ||
|
||
# %module | ||
# % description: Applies a classification model in parallel using r.learn.predict. | ||
# % keyword: raster | ||
# % keyword: classification | ||
# % keyword: regression | ||
# % keyword: machine learning | ||
# % keyword: scikit-learn | ||
# % keyword: prediction | ||
# % keyword: parallel | ||
# %end | ||
# %flag | ||
# % key: p | ||
# % label: Output class membership probabilities | ||
# % description: A raster layer is created for each class. For the case of a binary classification, only the positive (maximum) class is output | ||
# % guisection: Optional | ||
# %end | ||
# %flag | ||
# % key: z | ||
# % label: Only predict class probabilities | ||
# % guisection: Optional | ||
# %end | ||
# %flag | ||
# % key: v | ||
# % label: Create a VRT (Virtual Raster) as output | ||
# % guisection: Optional | ||
# %end | ||
|
||
# %option G_OPT_I_GROUP | ||
# % key: group | ||
# % label: Group of raster layers used for prediction | ||
# % description: GRASS imagery group of raster maps representing feature variables to be used in the machine learning model | ||
# % required: yes | ||
# % multiple: no | ||
# %end | ||
|
||
# %option G_OPT_F_INPUT | ||
# % key: load_model | ||
# % label: Load model from file | ||
# % description: File representing pickled scikit-learn estimator model | ||
# % required: yes | ||
# % guisection: Required | ||
# %end | ||
|
||
# %option G_OPT_R_OUTPUT | ||
# % key: output | ||
# % label: Output Map | ||
# % description: Raster layer name to store result from classification or regression model. The name will also used as a perfix if class probabilities or intermediate of cross-validation results are ordered as maps. | ||
# % guisection: Required | ||
# % required: yes | ||
# %end | ||
|
||
# %option | ||
# % key: chunksize | ||
# % type: integer | ||
# % label: Number of pixels to pass to the prediction method | ||
# % description: Number of pixels to pass to the prediction method. GRASS GIS reads raster by-row so chunksize is rounded down based on the number of columns | ||
# % answer: 100000 | ||
# % guisection: Optional | ||
# %end | ||
|
||
# %option G_OPT_M_NPROCS | ||
# % label: Number of parallel processes used for band importing in sen2cor | ||
# % description: Number of cores for multiprocessing, -2 is n_cores-1 | ||
# % answer: -2 | ||
# % guisection: Optional | ||
# %end | ||
|
||
# %option | ||
# % key: grid | ||
# % type: integer | ||
# % required: no | ||
# % multiple: no | ||
# % key_desc: rows,columns | ||
# % description: Number of rows and columns in grid | ||
# %end | ||
|
||
import atexit | ||
import sys | ||
import os | ||
import multiprocessing as mp | ||
|
||
import grass.script as grass | ||
from grass.pygrass.modules import Module, ParallelModuleQueue | ||
|
||
# initialize global vars | ||
rm_regions = [] | ||
rm_vectors = [] | ||
rm_rasters = [] | ||
|
||
|
||
def cleanup(): | ||
nuldev = open(os.devnull, "w") | ||
kwargs = {"flags": "f", "quiet": True, "stderr": nuldev} | ||
for rmr in rm_regions: | ||
if rmr in [x for x in grass.parse_command("g.list", type="region")]: | ||
grass.run_command("g.remove", type="region", name=rmr, **kwargs) | ||
for rmv in rm_vectors: | ||
if grass.find_file(name=rmv, element="vector")["file"]: | ||
grass.run_command("g.remove", type="vector", name=rmv, **kwargs) | ||
for rmrast in rm_rasters: | ||
if grass.find_file(name=rmrast, element="raster")["file"]: | ||
grass.run_command("g.remove", type="raster", name=rmrast, **kwargs) | ||
|
||
|
||
def set_test_nprocs(nprocs): | ||
# Test nprocs settings | ||
nprocs_real = mp.cpu_count() | ||
if nprocs == -2: | ||
procs = nprocs_real - 1 | ||
grass.info("Using %d parallel processes" % (procs)) | ||
return procs | ||
else: | ||
if nprocs > nprocs_real: | ||
grass.warning( | ||
"Using %d parallel processes but only %d CPUs available." | ||
% (nprocs, nprocs_real) | ||
) | ||
return nprocs | ||
|
||
|
||
def main(): | ||
|
||
global rm_regions, rm_rasters, rm_vectors | ||
|
||
# parallelization parameter | ||
n_jobs = set_test_nprocs(int(options["n_jobs"])) | ||
|
||
# parameter of r.learn.predict | ||
group = options["group"] | ||
output = options["output"] | ||
load_model = options["load_model"] | ||
chunksize = options["chunksize"] | ||
flags_str = "" | ||
for flag in flags: | ||
if flags[flag] and not flag == "v": | ||
flags_str += flag | ||
|
||
if options["grid"]: | ||
grid_rows_cols = options["grid"] | ||
else: | ||
grid_rows_cols = "%d,%d" % (n_jobs, n_jobs) | ||
|
||
# set some common environmental variables, like: | ||
os.environ.update( | ||
dict( | ||
GRASS_COMPRESS_NULLS="1", | ||
GRASS_COMPRESSOR="ZSTD", | ||
GRASS_MESSAGE_FORMAT="plain", | ||
) | ||
) | ||
|
||
# test if r.learn.predict is installed | ||
if not grass.find_program("r.learn.predict", "--help"): | ||
grass.fatal( | ||
_( | ||
"The 'r.learn.predict' module was not found, install it first:" | ||
+ "\n" | ||
+ "g.extension r.learn.ml2" | ||
) | ||
) | ||
|
||
if n_jobs > 1: | ||
grass.message(_("Generating grid to for parallelization ...")) | ||
grid = "tmp_grid_%s" % os.getpid() | ||
grass.run_command("v.mkgrid", map=grid, grid=grid_rows_cols) | ||
rm_vectors.append(grid) | ||
|
||
reg = grass.region() | ||
cats = list( | ||
grass.parse_command("v.category", input=grid, option="print").keys() | ||
) | ||
|
||
grass.message(_("Predict parallel on the grid cells ...")) | ||
# save current mapset | ||
env = grass.gisenv() | ||
# start_gisdbase = env['GISDBASE'] | ||
# start_location = env['LOCATION_NAME'] | ||
start_cur_mapset = env["MAPSET"] | ||
|
||
queue = ParallelModuleQueue(nprocs=n_jobs) | ||
classifications = [] | ||
for cat in cats: | ||
new_mapset = "tmp_mapset_rlearnpredict_%s" % cat | ||
tmp_output = "%s_%s" % (output, cat) | ||
# Module | ||
r_grid_predict = Module( | ||
"r.learn.predict.worker", | ||
area=grid, | ||
where="cat=%s" % cat, | ||
mapset=new_mapset, | ||
nsres=reg["nsres"], | ||
ewres=reg["ewres"], | ||
group=group, | ||
output=tmp_output, | ||
load_model=load_model, | ||
run_=False, | ||
chunksize=chunksize, | ||
) | ||
classifications.append("%s@%s" % (tmp_output, new_mapset)) | ||
queue.put(r_grid_predict) | ||
queue.wait() | ||
|
||
# verify that switchiing the mapset worked | ||
env = grass.gisenv() | ||
gisdbase = env["GISDBASE"] | ||
location = env["LOCATION_NAME"] | ||
cur_mapset = env["MAPSET"] | ||
if cur_mapset != start_cur_mapset: | ||
grass.fatal( | ||
"new mapset is %s, but should be %s" % (cur_mapset, start_cur_mapset) | ||
) | ||
|
||
for classification in classifications: | ||
name_mapset = classification.split("@") | ||
grass.run_command( | ||
"g.copy", raster="%s,%s" % (classification, name_mapset[0]) | ||
) | ||
grass.utils.try_rmdir(os.path.join(gisdbase, location, name_mapset[1])) | ||
|
||
# patching | ||
grass.message(_("Patching the tiles ...")) | ||
grass.message(_("Current region for patching:\n%s") % grass.region()) | ||
|
||
if len(classifications) > 1: | ||
module = "r.patch" | ||
if flags["v"]: | ||
module = "r.buildvrt" | ||
all_classified = [x.split("@")[0] for x in classifications] | ||
grass.run_command(module, input=all_classified, output=output) | ||
if not flags["v"]: | ||
rm_rasters.extend(all_classified) | ||
else: | ||
all_classified = [x.split("@")[0] for x in classifications][0] | ||
grass.run_command("g.copy", raster=all_classified + "," + output) | ||
|
||
grass.message(_("Patching the tiles done")) | ||
else: | ||
grass.run_command( | ||
"r.learn.predict", | ||
group=group, | ||
output=output, | ||
load_model=load_model, | ||
chunksize=chunksize, | ||
flags=flags_str, | ||
) | ||
return 0 | ||
|
||
|
||
if __name__ == "__main__": | ||
options, flags = grass.parser() | ||
atexit.register(cleanup) | ||
sys.exit(main()) |
Oops, something went wrong.