Skip to content

Commit

Permalink
add post sanitycheck hook for cuDNN
Browse files Browse the repository at this point in the history
  • Loading branch information
truib committed May 5, 2024
1 parent 2050a89 commit 1f0206f
Showing 1 changed file with 57 additions and 0 deletions.
57 changes: 57 additions & 0 deletions eb_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,62 @@ def post_sanitycheck_cuda(self, *args, **kwargs):
raise EasyBuildError("CUDA-specific hook triggered for non-CUDA easyconfig?!")


def post_sanitycheck_cuDNN(self, *args, **kwargs):
"""
Remove files from cuDNN installation that we are not allowed to ship,
and replace them with a symlink to a corresponding installation under host_injections.
"""
if self.name == 'cuDNN':
print_msg("Replacing files in cuDNN installation that we can not ship with symlinks to host_injections...")

allowlist = ['LICENSE']

# read cuDNN LICENSE, construct allowlist based on section 2.6 that specifies list of files that can be shipped
license_path = os.path.join(self.installdir, 'LICENSE')
search_string = "2. Distribution. The following portions of the SDK are distributable under the Agreement:"
with open(license_path) as infile:
for line in infile:
if line.strip().startswidth(search_string):
# remove search string, split into words, remove trailing
# dots '.' and only retain words starting with a dot '.'
distributable = line[len(search_string):]
for word in distributable.split():
if word[0] == '.':
allowlist.append(word.rstrip('.'))

allowlist = sorted(set(allowlist))
self.log.info("Allowlist for files in cuDNN installation that can be redistributed: " + ', '.join(allowlist))

# iterate over all files in the CUDA installation directory
for dir_path, _, files in os.walk(self.installdir):
for filename in files:
full_path = os.path.join(dir_path, filename)
# we only really care about real files, i.e. not symlinks
if not os.path.islink(full_path):
# check if the current file is part of the allowlist
basename = filename.split('.')[0]
if '.' in filename:
extension = '.' + filename.split('.')[1]
if basename in allowlist:
self.log.debug("%s is found in allowlist, so keeping it: %s", basename, full_path)
elif '.' in filename and extension in allowlist:
self.log.debug("%s is found in allowlist, so keeping it: %s", extension, full_path)
else:
self.log.debug("%s is not found in allowlist, so replacing it with symlink: %s",
filename, full_path)
# if it is not in the allowlist, delete the file and create a symlink to host_injections
host_inj_path = full_path.replace('versions', 'host_injections')
# make sure source and target of symlink are not the same
if full_path == host_inj_path:
raise EasyBuildError("Source (%s) and target (%s) are the same location, are you sure you "
"are using this hook for a NESSI installation?",
full_path, host_inj_path)
remove_file(full_path)
symlink(host_inj_path, full_path)
else:
raise EasyBuildError("cuDNN-specific hook triggered for non-cuDNN easyconfig?!")


def inject_gpu_property(ec):
"""
Add 'gpu' property, via modluafooter easyconfig parameter
Expand Down Expand Up @@ -768,4 +824,5 @@ def inject_gpu_property(ec):

POST_SANITYCHECK_HOOKS = {
'CUDA': post_sanitycheck_cuda,
'cuDNN': post_sanitycheck_cuDNN,
}

0 comments on commit 1f0206f

Please sign in to comment.