Skip to content

Commit

Permalink
Be robust against invalid utf-8 byte sequences and surrogateescape th…
Browse files Browse the repository at this point in the history
…em when en- or decoding (#144)

This commit also takes the opportunity to remove Python 2 string
compatibility code.

It will also remove the final left-over Python 2 compatibility
in the test cases.

Co-authored-by: Yun Zheng Hu <[email protected]>
  • Loading branch information
pyrco and yunzheng authored Oct 3, 2024
1 parent 8d6fe37 commit 4e1a285
Show file tree
Hide file tree
Showing 16 changed files with 133 additions and 140 deletions.
2 changes: 1 addition & 1 deletion flow/record/adapter/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def record_to_document(self, record: Record, index: str) -> dict:
}

if self.hash_record:
document["_id"] = hashlib.md5(document["_source"].encode()).hexdigest()
document["_id"] = hashlib.md5(document["_source"].encode(errors="surrogateescape")).hexdigest()

return document

Expand Down
2 changes: 1 addition & 1 deletion flow/record/adapter/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def write(self, rec: Record) -> None:
for key, value in rdict.items():
if rdict_types:
key = f"{key} ({rdict_types[key]})"
self.fp.write(fmt.format(key, value).encode())
self.fp.write(fmt.format(key, value).encode(errors="surrogateescape"))

def flush(self) -> None:
if self.fp:
Expand Down
2 changes: 1 addition & 1 deletion flow/record/adapter/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def read_table(self, table_name: str) -> Iterator[Record]:
if value == 0:
row[idx] = None
elif isinstance(value, str):
row[idx] = value.encode("utf-8")
row[idx] = value.encode(errors="surrogateescape")
yield descriptor_cls.init_from_dict(dict(zip(fnames, row)))

def __iter__(self) -> Iterator[Record]:
Expand Down
2 changes: 1 addition & 1 deletion flow/record/adapter/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def write(self, rec):
buf = self.format_spec.format_map(DefaultMissing(rec._asdict()))
else:
buf = repr(rec)
self.fp.write(buf.encode() + b"\n")
self.fp.write(buf.encode(errors="surrogateescape") + b"\n")

# because stdout is usually line buffered we force flush here if wanted
if self.auto_flush:
Expand Down
4 changes: 2 additions & 2 deletions flow/record/adapter/xlsx.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def sanitize_fieldvalues(values: Iterator[Any]) -> Iterator[Any]:
elif isinstance(value, bytes):
base64_encode = False
try:
new_value = 'b"' + value.decode() + '"'
new_value = 'b"' + value.decode(errors="surrogateescape") + '"'
if ILLEGAL_CHARACTERS_RE.search(new_value):
base64_encode = True
else:
Expand Down Expand Up @@ -142,7 +142,7 @@ def __iter__(self):
if field_types[idx] == "bytes":
if value[1] == '"': # If so, we know this is b""
# Cut of the b" at the start and the trailing "
value = value[2:-1].encode()
value = value[2:-1].encode(errors="surrogateescape")
else:
# If not, we know it is base64 encoded (so we cut of the starting 'base64:')
value = b64decode(value[7:])
Expand Down
4 changes: 2 additions & 2 deletions flow/record/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@

from collections import OrderedDict

from .utils import to_native_str, to_str
from .utils import to_str
from .whitelist import WHITELIST, WHITELIST_TREE

log = logging.getLogger(__package__)
Expand Down Expand Up @@ -513,7 +513,7 @@ def __init__(self, name: str, fields: Optional[Sequence[tuple[str, str]]] = None
name, fields = parse_def(name)

self.name = name
self._field_tuples = tuple([(to_native_str(k), to_str(v)) for k, v in fields])
self._field_tuples = tuple([(to_str(k), to_str(v)) for k, v in fields])
self.recordType = _generate_record_class(name, self._field_tuples)
self.recordType._desc = self

Expand Down
29 changes: 2 additions & 27 deletions flow/record/fieldtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from flow.record.base import FieldType

RE_NORMALIZE_PATH = re.compile(r"[\\/]+")
NATIVE_UNICODE = isinstance("", str)

UTC = timezone.utc

Expand Down Expand Up @@ -207,10 +206,7 @@ def _pack(self):
class string(string_type, FieldType):
def __new__(cls, value):
if isinstance(value, bytes_type):
value = cls._decode(value, "utf-8")
if isinstance(value, bytes_type):
# Still bytes, so decoding failed (Python 2)
return bytes(value)
value = value.decode(errors="surrogateescape")
return super().__new__(cls, value)

def _pack(self):
Expand All @@ -221,27 +217,6 @@ def __format__(self, spec):
return defang(self)
return str.__format__(self, spec)

@classmethod
def _decode(cls, data, encoding):
"""Decode a byte-string into a unicode-string.
Python 3: When `data` contains invalid unicode characters a `UnicodeDecodeError` is raised.
Python 2: When `data` contains invalid unicode characters the original byte-string is returned.
"""
if NATIVE_UNICODE:
# Raises exception on decode error
return data.decode(encoding)
try:
return data.decode(encoding)
except UnicodeDecodeError:
# Fallback to bytes (Python 2 only)
preview = data[:16].encode("hex_codec") + (".." if len(data) > 16 else "")
warnings.warn(
"Got binary data in string field (hex: {}). Compatibility is not guaranteed.".format(preview),
RuntimeWarning,
)
return data


# Alias for backwards compatibility
wstring = string
Expand Down Expand Up @@ -278,7 +253,7 @@ def __new__(cls, *args, **kwargs):
if len(args) == 1 and not kwargs:
arg = args[0]
if isinstance(arg, bytes_type):
arg = arg.decode("utf-8")
arg = arg.decode(errors="surrogateescape")
if isinstance(arg, string_type):
# If we are on Python 3.11 or newer, we can use fromisoformat() to parse the string (fast path)
#
Expand Down
7 changes: 0 additions & 7 deletions flow/record/fieldtypes/net/ipv4.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import warnings

from flow.record import FieldType
from flow.record.utils import to_native_str


def addr_long(s):
Expand Down Expand Up @@ -45,9 +44,6 @@ def __init__(self, addr, netmask=None):
DeprecationWarning,
stacklevel=5,
)
if isinstance(addr, type("")):
addr = to_native_str(addr)

if not isinstance(addr, str):
raise TypeError("Subnet() argument 1 must be string, not {}".format(type(addr).__name__))

Expand All @@ -67,9 +63,6 @@ def __contains__(self, addr):
if addr is None:
return False

if isinstance(addr, type("")):
addr = to_native_str(addr)

if isinstance(addr, str):
addr = addr_long(addr)

Expand Down
6 changes: 1 addition & 5 deletions flow/record/jsonpacker.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,8 @@ def pack_obj(self, obj):
serial["_recorddescriptor"] = obj._desc.identifier

for field_type, field_name in obj._desc.get_field_tuples():
# PYTHON2: Because "bytes" are also "str" we have to handle this here
if field_type == "bytes" and isinstance(serial[field_name], str):
serial[field_name] = base64.b64encode(serial[field_name]).decode()

# Boolean field types should be cast to a bool instead of staying ints
elif field_type == "boolean" and isinstance(serial[field_name], int):
if field_type == "boolean" and isinstance(serial[field_name], int):
serial[field_name] = bool(serial[field_name])

return serial
Expand Down
40 changes: 18 additions & 22 deletions flow/record/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,10 @@
import base64
import os
import sys
import warnings
from functools import wraps
from typing import BinaryIO, TextIO

_native = str
_unicode = type("")
_bytes = type(b"")


def get_stdout(binary: bool = False) -> TextIO | BinaryIO:
"""Return the stdout stream as binary or text stream.
Expand Down Expand Up @@ -50,33 +47,32 @@ def is_stdout(fp: TextIO | BinaryIO) -> bool:

def to_bytes(value):
"""Convert a value to a byte string."""
if value is None or isinstance(value, _bytes):
if value is None or isinstance(value, bytes):
return value
if isinstance(value, _unicode):
return value.encode("utf-8")
return _bytes(value)
if isinstance(value, str):
return value.encode(errors="surrogateescape")
return bytes(value)


def to_str(value):
"""Convert a value to a unicode string."""
if value is None or isinstance(value, _unicode):
if value is None or isinstance(value, str):
return value
if isinstance(value, _bytes):
return value.decode("utf-8")
return _unicode(value)
if isinstance(value, bytes):
return value.decode(errors="surrogateescape")
return str(value)


def to_native_str(value):
"""Convert a value to a native `str`."""
if value is None or isinstance(value, _native):
return value
if isinstance(value, _unicode):
# Python 2: unicode -> str
return value.encode("utf-8")
if isinstance(value, _bytes):
# Python 3: bytes -> str
return value.decode("utf-8")
return _native(value)
warnings.warn(
(
"The to_native_str() function is deprecated, "
"this function will be removed in flow.record 3.20, "
"use to_str() instead"
),
DeprecationWarning,
)
return to_str(value)


def to_base64(value):
Expand Down
29 changes: 29 additions & 0 deletions tests/test_adapter_line.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from io import BytesIO

from flow.record import RecordDescriptor
from flow.record.adapter.line import LineWriter


def test_line_writer_write_surrogateescape():
output = BytesIO()

lw = LineWriter(
path=output,
fields="name",
)

TestRecord = RecordDescriptor(
"test/string",
[
("string", "name"),
],
)

# construct from 'bytes' but with invalid unicode bytes
record = TestRecord(b"R\xc3\xa9\xeamy")
lw.write(record)

output.seek(0)
data = output.read()

assert data == b"--[ RECORD 1 ]--\nname = R\xc3\xa9\xeamy\n"
28 changes: 28 additions & 0 deletions tests/test_adapter_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from io import BytesIO

from flow.record import RecordDescriptor
from flow.record.adapter.text import TextWriter


def test_text_writer_write_surrogateescape():
output = BytesIO()

tw = TextWriter(
path=output,
)

TestRecord = RecordDescriptor(
"test/string",
[
("string", "name"),
],
)

# construct from 'bytes' but with invalid unicode bytes
record = TestRecord(b"R\xc3\xa9\xeamy")
tw.write(record)

output.seek(0)
data = output.read()

assert data == b"<test/string name='R\xc3\xa9\\udceamy'>\n"
11 changes: 2 additions & 9 deletions tests/test_fieldtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,8 @@ def test_string():
assert r.name == "Rémy"

# construct from 'bytes' but with invalid unicode bytes
if isinstance("", str):
# Python 3
with pytest.raises(UnicodeDecodeError):
TestRecord(b"R\xc3\xa9\xeamy")
else:
# Python 2
with pytest.warns(RuntimeWarning):
r = TestRecord(b"R\xc3\xa9\xeamy")
assert r.name
r = TestRecord(b"R\xc3\xa9\xeamy")
assert r.name == "Ré\udceamy"


def test_wstring():
Expand Down
20 changes: 20 additions & 0 deletions tests/test_json_packer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,23 @@ def test_record_pack_bool_regression() -> None:

# pack the json string back to a record and make sure it is the same as before
assert packer.unpack(data) == record


def test_record_pack_surrogateescape() -> None:
TestRecord = RecordDescriptor(
"test/string",
[
("string", "name"),
],
)

record = TestRecord(b"R\xc3\xa9\xeamy")
packer = JsonRecordPacker()

data = packer.pack(record)

# pack to json string and check if the 3rd and 4th byte are properly surrogate escaped
assert data.startswith('{"name": "R\\u00e9\\udceamy",')

# pack the json string back to a record and make sure it is the same as before
assert packer.unpack(data) == record
29 changes: 25 additions & 4 deletions tests/test_record.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import importlib
import inspect
import os
import sys
from unittest.mock import patch
Expand Down Expand Up @@ -27,8 +28,6 @@
from flow.record.exceptions import RecordDescriptorError
from flow.record.stream import RecordFieldRewriter

from . import utils_inspect as inspect


def test_record_creation():
TestRecord = RecordDescriptor(
Expand Down Expand Up @@ -288,8 +287,30 @@ def isatty():
writer.write(record)

out, err = capsys.readouterr()
modifier = "" if isinstance("", str) else "u"
expected = "<test/a a_string={u}'hello' common={u}'world' a_count=10>\n".format(u=modifier)
expected = "<test/a a_string='hello' common='world' a_count=10>\n"
assert out == expected


def test_record_printer_stdout_surrogateescape(capsys):
Record = RecordDescriptor(
"test/a",
[
("string", "name"),
],
)
record = Record(b"R\xc3\xa9\xeamy")

# fake capsys to be a tty.
def isatty():
return True

capsys._capture.out.tmpfile.isatty = isatty

writer = RecordPrinter(getattr(sys.stdout, "buffer", sys.stdout))
writer.write(record)

out, err = capsys.readouterr()
expected = "<test/a name='Ré\\udceamy'>\n"
assert out == expected


Expand Down
Loading

0 comments on commit 4e1a285

Please sign in to comment.