diff --git a/src/nasdaq_protocols/fix/codegen.py b/src/nasdaq_protocols/fix/codegen.py index 7ae6223..489636b 100644 --- a/src/nasdaq_protocols/fix/codegen.py +++ b/src/nasdaq_protocols/fix/codegen.py @@ -13,11 +13,14 @@ @click.option('--prefix', type=click.STRING, default='') @click.option('--op-dir', type=click.Path(exists=True, writable=True)) @click.option('--init-file/--no-init-file', show_default=True, default=True) -def generate(spec_file, app_name, op_dir, prefix, init_file): +@click.option('--fix-version', + type=click.Choice(['4.2', '4.4', '5.0', '5.0SP2']), + default='5.0SP2') +def generate(spec_file, app_name, op_dir, prefix, init_file, fix_version): try: generator = Generator( - parse(spec_file), + parse(spec_file, fix_version), app_name, op_dir, prefix, diff --git a/src/nasdaq_protocols/fix/parser/definitions.py b/src/nasdaq_protocols/fix/parser/definitions.py index 017662a..34d0190 100644 --- a/src/nasdaq_protocols/fix/parser/definitions.py +++ b/src/nasdaq_protocols/fix/parser/definitions.py @@ -19,8 +19,6 @@ 'Definitions' ] -from nasdaq_protocols.fix.parser.version_types import Version - @attrs.define class FieldDef: @@ -127,7 +125,7 @@ def get_codegen_context(self, definitions): @attrs.define class Definitions: - version: Version + version: str fields: dict[str, FieldDef] = attrs.field(init=False, factory=dict) components: dict[str, Component] = attrs.field(kw_only=True, factory=dict) header: EntryContainer = attrs.field(kw_only=True, factory=EntryContainer) @@ -157,8 +155,8 @@ def get_codegen_context(self): } def _client_session(self): - if self.version == Version.FIX_4_4: + if self.version == '4.4': return 'Fix44Session' - if self.version in (Version.FIX_5_0, Version.FIX_5_0_2): + if self.version in ('5.0', '5.0SP2'): return 'Fix50Session' raise ValueError(f'Version {self.version} is not supported') diff --git a/src/nasdaq_protocols/fix/parser/parser.py b/src/nasdaq_protocols/fix/parser/parser.py index 3ab9f03..2d9d041 100644 --- a/src/nasdaq_protocols/fix/parser/parser.py +++ b/src/nasdaq_protocols/fix/parser/parser.py @@ -13,7 +13,6 @@ ) from .version_types import ( get_supported_types, - Version, SupportedTypes ) @@ -24,24 +23,13 @@ LOG = logging.getLogger(__name__) -def parse(file: str) -> Definitions: +def parse(file: str, version: str) -> Definitions: tree = e_tree.parse(file) root = tree.getroot() if root.tag != 'fix': raise ValueError('root tag is not fix') - version_str = f'{root.get("major")}{root.get("minor")}' - servicepack = int(root.get('servicepack', '0')) - if servicepack > 0: - version_str += f'{servicepack}' - version = int(version_str) - - try: - version = Version(version) - except ValueError as v_error: - raise ValueError(f'Version {version} is not supported') from v_error - handlers = { 'fields': partial(_handle_fields, get_supported_types(version)), 'components': _handle_components, diff --git a/src/nasdaq_protocols/fix/parser/version_types.py b/src/nasdaq_protocols/fix/parser/version_types.py index 1b35800..789f01d 100644 --- a/src/nasdaq_protocols/fix/parser/version_types.py +++ b/src/nasdaq_protocols/fix/parser/version_types.py @@ -1,30 +1,20 @@ -import enum - from .. import types from ...common import TypeDefinition __all__ = [ 'SupportedTypes', - 'Version', 'get_supported_types' ] SupportedTypes = dict[str, TypeDefinition] -class Version(enum.IntEnum): - FIX_4_2 = 42 - FIX_4_4 = 44 - FIX_5_0 = 50 - FIX_5_0_2 = 502 - - -def get_supported_types(version: Version) -> SupportedTypes: +def get_supported_types(version: str) -> SupportedTypes: version_map = { - Version.FIX_4_2: fix_42_version_types, - Version.FIX_4_4: fix_44_version_types, - Version.FIX_5_0: fix_50_version_types, - Version.FIX_5_0_2: fix_502_version_types + '4.2': _fix_42_version_types, + '4.4': _fix_44_version_types, + '5.0': _fix_50_version_types, + '5.0SP2': _fix_502_version_types } try: return version_map[version]() @@ -32,7 +22,7 @@ def get_supported_types(version: Version) -> SupportedTypes: raise ValueError(f'Version {version} not supported') from k_error -def fix_42_version_types(): +def _fix_42_version_types(): return { 'AMT': types.FixAmount, 'BOOLEAN': types.FixBool, @@ -61,8 +51,8 @@ def fix_42_version_types(): } -def fix_44_version_types(): - fix_44_types = fix_42_version_types() +def _fix_44_version_types(): + fix_44_types = _fix_42_version_types() fix_44_types.update({ 'SEQNUM': types.FixInt, 'NUMINGROUP': types.FixInt, @@ -70,8 +60,8 @@ def fix_44_version_types(): return fix_44_types -def fix_50_version_types(): - fix_50_types = fix_42_version_types() +def _fix_50_version_types(): + fix_50_types = _fix_42_version_types() fix_50_types.update({ 'FIXSTRING': types.FixString, 'MULTIPLECHARVALUE': types.FixString, @@ -81,8 +71,8 @@ def fix_50_version_types(): return fix_50_types -def fix_502_version_types(): - fix_52_types = fix_50_version_types() +def _fix_502_version_types(): + fix_52_types = _fix_50_version_types() fix_52_types.update({ 'LOCALMKTDATE': types.FixLocalMktDate, 'TZTIMEONLY': types.FixTzTimeonly, diff --git a/tests/conftest.py b/tests/conftest.py index e5bb94f..a9e04e9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -40,13 +40,14 @@ async def mock_server_session(unused_tcp_port): @pytest.fixture(scope='function') def codegen_invoker(capsys, tmp_path): - def generator(codegen, xml_content, app_name, generate_init_file, prefix, output_dir=None): + def generator(codegen, xml_content, app_name, generate_init_file, prefix, output_dir=None, extra_args=None): runner = CliRunner() with capsys.disabled(), runner.isolated_filesystem(temp_dir=tmp_path): with open('spec.xml', 'w') as spec_file: spec_file.write(xml_content) output_dir = output_dir or 'output' Path(output_dir).mkdir(parents=True, exist_ok=True) + extra_args = extra_args or [] result = runner.invoke( codegen, [ @@ -54,7 +55,8 @@ def generator(codegen, xml_content, app_name, generate_init_file, prefix, output '--app-name', app_name, '--op-dir', output_dir, '--prefix', prefix, - '--init-file' if generate_init_file else '--no-init-file' + '--init-file' if generate_init_file else '--no-init-file', + *extra_args ] ) assert result.exit_code == 0 diff --git a/tests/test_fix_codegen.py b/tests/test_fix_codegen.py index 0b890e6..ad455e9 100644 --- a/tests/test_fix_codegen.py +++ b/tests/test_fix_codegen.py @@ -13,7 +13,7 @@ @pytest.fixture(scope='function') def fix_44_definitions(tmp_file_writer): file = tmp_file_writer(TEST_FIX_44_XML) - definitions = parse(file) + definitions = parse(file, '4.4') assert definitions is not None yield definitions @@ -26,7 +26,8 @@ def test__no_init_file__no_prefix__code_generated(codegen_invoker): TEST_FIX_44_XML, app_name, generate_init_file=False, - prefix=prefix + prefix=prefix, + extra_args=['--fix-version', '4.4'] ) assert len(generated_files) == 5 @@ -42,7 +43,8 @@ def test__init_file__no_prefix__code_generated(fix_44_definitions, codegen_invok app_name, generate_init_file=True, prefix=prefix, - output_dir=output_dir + output_dir=output_dir, + extra_args=['--fix-version', '4.4'] ) assert len(generated_files) == 6 diff --git a/tests/test_fix_parser.py b/tests/test_fix_parser.py index ad9d3a6..84ead86 100644 --- a/tests/test_fix_parser.py +++ b/tests/test_fix_parser.py @@ -8,7 +8,7 @@ @pytest.fixture(scope='function') def fix_44_definitions(tmp_file_writer): file = tmp_file_writer(TEST_FIX_44_XML) - definitions = parse(file) + definitions = parse(file, '4.4') assert definitions is not None yield definitions @@ -21,7 +21,7 @@ def test__fix_parser__parse__invalid_root_tag(tmp_file_writer): file = tmp_file_writer(invalid_xml) with pytest.raises(ValueError) as e: - parse(file) + parse(file, '4.4') assert str(e.value) == 'root tag is not fix' @@ -37,7 +37,7 @@ def test__fix_parser__parse__invalid_tag_in_fields(tmp_file_writer): file = tmp_file_writer(invalid_xml) with pytest.raises(ValueError) as e: - parse(file) + parse(file, '4.4') assert str(e.value) == 'expected field tag, got invalid' @@ -53,9 +53,9 @@ def test__fix_parser__parse__unsupported_version(tmp_file_writer): file = tmp_file_writer(invalid_xml) with pytest.raises(ValueError) as e: - parse(file) + parse(file, '2.1') - assert str(e.value) == 'Version 21 is not supported' + assert str(e.value) == 'Version 2.1 not supported' def test__fix_parser__parse__component_not_found(tmp_file_writer): @@ -73,7 +73,7 @@ def test__fix_parser__parse__component_not_found(tmp_file_writer): file = tmp_file_writer(invalid_xml) with pytest.raises(ValueError) as e: - parse(file) + parse(file, '4.4') assert str(e.value) == 'Component definition for NotFoundComponent not found' @@ -91,7 +91,7 @@ def test__fix_parser__parse__field_not_found(tmp_file_writer): file = tmp_file_writer(invalid_xml) with pytest.raises(ValueError) as e: - parse(file) + parse(file, '4.4') assert str(e.value) == 'Field definition for NotFound not found' @@ -103,8 +103,8 @@ def test__fix_parser__parse__xml_with_service_pack(tmp_file_writer): ''' file = tmp_file_writer(fix_502) - definitions = parse(file) - assert definitions.version == 502 + definitions = parse(file, '5.0SP2') + assert definitions.version == '5.0SP2' def test__fix_parser__parse__xml_with_keywords__keywords_are_transformed(tmp_file_writer): @@ -120,8 +120,8 @@ def test__fix_parser__parse__xml_with_keywords__keywords_are_transformed(tmp_fil ''' file = tmp_file_writer(fix_502) - definitions = parse(file) - assert definitions.version == 502 + definitions = parse(file, '5.0SP2') + assert definitions.version == '5.0SP2' context = definitions.fields['MsgType'].get_codegen_context(None) assert context['values'][0]['f_value'] == 'None_' assert context['values'][1]['f_value'] == 'if_'