Skip to content

Commit

Permalink
fix(io): also check root tags when testing for file types
Browse files Browse the repository at this point in the history
Fixes #435
  • Loading branch information
sanjayankur31 committed Nov 6, 2024
1 parent 1a909ba commit 302d22c
Showing 1 changed file with 37 additions and 20 deletions.
57 changes: 37 additions & 20 deletions pyneuroml/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import lems.model.model as lems_model
import neuroml.loaders as loaders
import neuroml.writers as writers
from lxml import etree
from neuroml import NeuroMLDocument

from pyneuroml.errors import ARGUMENT_ERR, FILE_NOT_FOUND_ERR, NMLFileTypeError
Expand All @@ -27,15 +28,6 @@
logger.setLevel(logging.INFO)


# extension: standard
pynml_file_type_dict = {
"xml": "LEMS",
"nml": "NeuroML",
"sedml": "SED-ML",
"sbml": "SBML",
}


def read_neuroml2_file(
nml2_file_name: str,
include_includes: bool = False,
Expand Down Expand Up @@ -256,7 +248,7 @@ def confirm_neuroml_file(filename: str, sys_error: bool = False) -> None:
)

try:
confirm_file_type(filename, ["nml"])
confirm_file_type(filename, ["nml"], "neuroml")
except NMLFileTypeError as e:
if filename.startswith("LEMS_"):
logger.warning(error_string)
Expand All @@ -276,9 +268,6 @@ def confirm_lems_file(filename: str, sys_error: bool = False) -> None:
:param sys_error: toggle whether function should exit or raise exception
:type sys_error: bool
"""
# print('Checking file: %s'%filename)
# Some conditions to check if a LEMS file was entered
# TODO: Ideally we'd like to check the root node: checking file extensions is brittle
error_string = textwrap.dedent(
"""
*************************************************************************************
Expand All @@ -288,7 +277,7 @@ def confirm_lems_file(filename: str, sys_error: bool = False) -> None:
"""
)
try:
confirm_file_type(filename, ["xml"])
confirm_file_type(filename, ["xml"], "lems")
except NMLFileTypeError as e:
if filename.endswith("nml"):
logger.warning(error_string)
Expand All @@ -302,15 +291,22 @@ def confirm_lems_file(filename: str, sys_error: bool = False) -> None:
def confirm_file_type(
filename: str,
file_exts: typing.List[str],
root_tag: typing.Optional[str] = None,
error_str: typing.Optional[str] = None,
sys_error: bool = False,
) -> None:
"""Confirm that a file exists and has the necessary extension
"""Confirm that a file exists and is of the provided type.
First we rely on file extensions to test for type, since this is the
simplest way. If this test fails, we read the full file and test the root
tag if one has been provided.
:param filename: filename to confirm
:type filename: str
:param file_exts: list of valid file extensions, without the leading dot
:type file_exts: list of strings
:param root_tag: root tag for file, used if extensions do not match
:type root_tag: str
:param error_str: an optional error string to print along with the thrown
exception
:type error_str: string (optional)
Expand All @@ -320,11 +316,32 @@ def confirm_file_type(
"""
confirm_file_exists(filename)
filename_ext = filename.split(".")[-1]
file_types = [f"{x} ({pynml_file_type_dict[x]})" for x in file_exts]
if filename_ext not in file_exts:
error_string = (
f"Expected file extension(s): {', '.join(file_types)}; got {filename_ext}"
)

matched = False

if filename_ext in file_exts:
matched = True

got_root_tag = None

if matched is False:
if root_tag is not None:
with open(filename) as i_file:
xml_tree = etree.parse(i_file)
tree_root = xml_tree.getroot()
got_root_tag = tree_root.tag

if got_root_tag.lower() == root_tag.lower():
matched = True

if matched is False:
error_string = f"Expected file extension does not match: {', '.join(file_exts)}; got {filename_ext}."

if root_tag is not None:
error_string += (
f" Expected root tag does not match: {root_tag}; got {got_root_tag}"
)

if error_str is not None:
error_string += "\n" + error_str
if sys_error is True:
Expand Down

0 comments on commit 302d22c

Please sign in to comment.