Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix codegen #34

Merged
merged 1 commit into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/nasdaq_protocols/fix/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
@click.option('--prefix', type=click.STRING, default='')
@click.option('--op-dir', type=click.Path(exists=True, writable=True))
@click.option('--init-file/--no-init-file', show_default=True, default=True)
def generate(spec_file, app_name, op_dir, prefix, init_file):
@click.option('--fix-version',
type=click.Choice(['4.2', '4.4', '5.0', '5.0SP2']),
default='5.0SP2')
def generate(spec_file, app_name, op_dir, prefix, init_file, fix_version):

try:
generator = Generator(
parse(spec_file),
parse(spec_file, fix_version),
app_name,
op_dir,
prefix,
Expand Down
8 changes: 3 additions & 5 deletions src/nasdaq_protocols/fix/parser/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
'Definitions'
]

from nasdaq_protocols.fix.parser.version_types import Version


@attrs.define
class FieldDef:
Expand Down Expand Up @@ -127,7 +125,7 @@ def get_codegen_context(self, definitions):

@attrs.define
class Definitions:
version: Version
version: str
fields: dict[str, FieldDef] = attrs.field(init=False, factory=dict)
components: dict[str, Component] = attrs.field(kw_only=True, factory=dict)
header: EntryContainer = attrs.field(kw_only=True, factory=EntryContainer)
Expand Down Expand Up @@ -157,8 +155,8 @@ def get_codegen_context(self):
}

def _client_session(self):
if self.version == Version.FIX_4_4:
if self.version == '4.4':
return 'Fix44Session'
if self.version in (Version.FIX_5_0, Version.FIX_5_0_2):
if self.version in ('5.0', '5.0SP2'):
return 'Fix50Session'
raise ValueError(f'Version {self.version} is not supported')
14 changes: 1 addition & 13 deletions src/nasdaq_protocols/fix/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
)
from .version_types import (
get_supported_types,
Version,
SupportedTypes
)

Expand All @@ -24,24 +23,13 @@
LOG = logging.getLogger(__name__)


def parse(file: str) -> Definitions:
def parse(file: str, version: str) -> Definitions:
tree = e_tree.parse(file)
root = tree.getroot()

if root.tag != 'fix':
raise ValueError('root tag is not fix')

version_str = f'{root.get("major")}{root.get("minor")}'
servicepack = int(root.get('servicepack', '0'))
if servicepack > 0:
version_str += f'{servicepack}'
version = int(version_str)

try:
version = Version(version)
except ValueError as v_error:
raise ValueError(f'Version {version} is not supported') from v_error

handlers = {
'fields': partial(_handle_fields, get_supported_types(version)),
'components': _handle_components,
Expand Down
34 changes: 12 additions & 22 deletions src/nasdaq_protocols/fix/parser/version_types.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,28 @@
import enum

from .. import types
from ...common import TypeDefinition


__all__ = [
'SupportedTypes',
'Version',
'get_supported_types'
]
SupportedTypes = dict[str, TypeDefinition]


class Version(enum.IntEnum):
FIX_4_2 = 42
FIX_4_4 = 44
FIX_5_0 = 50
FIX_5_0_2 = 502


def get_supported_types(version: Version) -> SupportedTypes:
def get_supported_types(version: str) -> SupportedTypes:
version_map = {
Version.FIX_4_2: fix_42_version_types,
Version.FIX_4_4: fix_44_version_types,
Version.FIX_5_0: fix_50_version_types,
Version.FIX_5_0_2: fix_502_version_types
'4.2': _fix_42_version_types,
'4.4': _fix_44_version_types,
'5.0': _fix_50_version_types,
'5.0SP2': _fix_502_version_types
}
try:
return version_map[version]()
except KeyError as k_error:
raise ValueError(f'Version {version} not supported') from k_error


def fix_42_version_types():
def _fix_42_version_types():
return {
'AMT': types.FixAmount,
'BOOLEAN': types.FixBool,
Expand Down Expand Up @@ -61,17 +51,17 @@ def fix_42_version_types():
}


def fix_44_version_types():
fix_44_types = fix_42_version_types()
def _fix_44_version_types():
fix_44_types = _fix_42_version_types()
fix_44_types.update({
'SEQNUM': types.FixInt,
'NUMINGROUP': types.FixInt,
})
return fix_44_types


def fix_50_version_types():
fix_50_types = fix_42_version_types()
def _fix_50_version_types():
fix_50_types = _fix_42_version_types()
fix_50_types.update({
'FIXSTRING': types.FixString,
'MULTIPLECHARVALUE': types.FixString,
Expand All @@ -81,8 +71,8 @@ def fix_50_version_types():
return fix_50_types


def fix_502_version_types():
fix_52_types = fix_50_version_types()
def _fix_502_version_types():
fix_52_types = _fix_50_version_types()
fix_52_types.update({
'LOCALMKTDATE': types.FixLocalMktDate,
'TZTIMEONLY': types.FixTzTimeonly,
Expand Down
6 changes: 4 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,23 @@ async def mock_server_session(unused_tcp_port):

@pytest.fixture(scope='function')
def codegen_invoker(capsys, tmp_path):
def generator(codegen, xml_content, app_name, generate_init_file, prefix, output_dir=None):
def generator(codegen, xml_content, app_name, generate_init_file, prefix, output_dir=None, extra_args=None):
runner = CliRunner()
with capsys.disabled(), runner.isolated_filesystem(temp_dir=tmp_path):
with open('spec.xml', 'w') as spec_file:
spec_file.write(xml_content)
output_dir = output_dir or 'output'
Path(output_dir).mkdir(parents=True, exist_ok=True)
extra_args = extra_args or []
result = runner.invoke(
codegen,
[
'--spec-file', 'spec.xml',
'--app-name', app_name,
'--op-dir', output_dir,
'--prefix', prefix,
'--init-file' if generate_init_file else '--no-init-file'
'--init-file' if generate_init_file else '--no-init-file',
*extra_args
]
)
assert result.exit_code == 0
Expand Down
8 changes: 5 additions & 3 deletions tests/test_fix_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
@pytest.fixture(scope='function')
def fix_44_definitions(tmp_file_writer):
file = tmp_file_writer(TEST_FIX_44_XML)
definitions = parse(file)
definitions = parse(file, '4.4')
assert definitions is not None
yield definitions

Expand All @@ -26,7 +26,8 @@ def test__no_init_file__no_prefix__code_generated(codegen_invoker):
TEST_FIX_44_XML,
app_name,
generate_init_file=False,
prefix=prefix
prefix=prefix,
extra_args=['--fix-version', '4.4']
)

assert len(generated_files) == 5
Expand All @@ -42,7 +43,8 @@ def test__init_file__no_prefix__code_generated(fix_44_definitions, codegen_invok
app_name,
generate_init_file=True,
prefix=prefix,
output_dir=output_dir
output_dir=output_dir,
extra_args=['--fix-version', '4.4']
)

assert len(generated_files) == 6
Expand Down
22 changes: 11 additions & 11 deletions tests/test_fix_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
@pytest.fixture(scope='function')
def fix_44_definitions(tmp_file_writer):
file = tmp_file_writer(TEST_FIX_44_XML)
definitions = parse(file)
definitions = parse(file, '4.4')
assert definitions is not None
yield definitions

Expand All @@ -21,7 +21,7 @@ def test__fix_parser__parse__invalid_root_tag(tmp_file_writer):
file = tmp_file_writer(invalid_xml)

with pytest.raises(ValueError) as e:
parse(file)
parse(file, '4.4')

assert str(e.value) == 'root tag is not fix'

Expand All @@ -37,7 +37,7 @@ def test__fix_parser__parse__invalid_tag_in_fields(tmp_file_writer):
file = tmp_file_writer(invalid_xml)

with pytest.raises(ValueError) as e:
parse(file)
parse(file, '4.4')

assert str(e.value) == 'expected field tag, got invalid'

Expand All @@ -53,9 +53,9 @@ def test__fix_parser__parse__unsupported_version(tmp_file_writer):
file = tmp_file_writer(invalid_xml)

with pytest.raises(ValueError) as e:
parse(file)
parse(file, '2.1')

assert str(e.value) == 'Version 21 is not supported'
assert str(e.value) == 'Version 2.1 not supported'


def test__fix_parser__parse__component_not_found(tmp_file_writer):
Expand All @@ -73,7 +73,7 @@ def test__fix_parser__parse__component_not_found(tmp_file_writer):
file = tmp_file_writer(invalid_xml)

with pytest.raises(ValueError) as e:
parse(file)
parse(file, '4.4')

assert str(e.value) == 'Component definition for NotFoundComponent not found'

Expand All @@ -91,7 +91,7 @@ def test__fix_parser__parse__field_not_found(tmp_file_writer):
file = tmp_file_writer(invalid_xml)

with pytest.raises(ValueError) as e:
parse(file)
parse(file, '4.4')

assert str(e.value) == 'Field definition for NotFound not found'

Expand All @@ -103,8 +103,8 @@ def test__fix_parser__parse__xml_with_service_pack(tmp_file_writer):
'''
file = tmp_file_writer(fix_502)

definitions = parse(file)
assert definitions.version == 502
definitions = parse(file, '5.0SP2')
assert definitions.version == '5.0SP2'


def test__fix_parser__parse__xml_with_keywords__keywords_are_transformed(tmp_file_writer):
Expand All @@ -120,8 +120,8 @@ def test__fix_parser__parse__xml_with_keywords__keywords_are_transformed(tmp_fil
'''
file = tmp_file_writer(fix_502)

definitions = parse(file)
assert definitions.version == 502
definitions = parse(file, '5.0SP2')
assert definitions.version == '5.0SP2'
context = definitions.fields['MsgType'].get_codegen_context(None)
assert context['values'][0]['f_value'] == 'None_'
assert context['values'][1]['f_value'] == 'if_'
Expand Down
Loading