diff --git a/package/etc/pylib/parser_source_cache.py b/package/etc/pylib/parser_source_cache.py index dc573a344..61e670d44 100644 --- a/package/etc/pylib/parser_source_cache.py +++ b/package/etc/pylib/parser_source_cache.py @@ -17,6 +17,40 @@ class LogParser: class LogDestination: pass +import builtins +import io +import pickle +from base64 import b64decode + +safe_builtins = { + 'range', + 'complex', + 'set', + 'frozenset', + 'slice', +} + +class RestrictedUnpickler(pickle.Unpickler): + def find_class(self, module, name): + # Only allow safe classes from builtins. + if module == "builtins" and name in safe_builtins: + return getattr(builtins, name) + # Forbid everything else. + raise pickle.UnpicklingError("global '%s.%s' is forbidden" % + (module, name)) + +def restricted_loads(s): + """Helper function analogous to pickle.loads().""" + return RestrictedUnpickler(io.BytesIO(s)).load() + +def restricted_decode(obj): + """Overwrite sqlitedict.decode to prevent code injection.""" + return restricted_loads(bytes(obj)) + +def restricted_decode_key(key): + """Overwrite sqlitedict.decode_key to prevent code injection.""" + return restricted_loads(b64decode(key.encode("ascii"))) + def ip2int(addr): ip4_to_int = lambda addr: struct.unpack("!I", socket.inet_aton(addr))[0] @@ -54,7 +88,7 @@ def int_to_ip6(num): class psc_parse(LogParser): def init(self, options): self.logger = syslogng.Logger() - self.db = SqliteDict(f"{hostdict}.sqlite") + self.db = SqliteDict(f"{hostdict}.sqlite", decode=restricted_decode, decode_key=restricted_decode_key) return True def deinit(self): @@ -82,7 +116,7 @@ class psc_dest(LogDestination): def init(self, options): self.logger = syslogng.Logger() try: - self.db = SqliteDict(f"{hostdict}.sqlite", autocommit=True) + self.db = SqliteDict(f"{hostdict}.sqlite", autocommit=True, decode=restricted_decode, decode_key=restricted_decode_key) except Exception: exc_type, exc_value, exc_traceback = sys.exc_info() lines = traceback.format_exception(exc_type, exc_value, exc_traceback) @@ -123,7 +157,7 @@ def flush(self): if __name__ == "__main__": - db = SqliteDict(f"{hostdict}.sqlite", autocommit=True) + db = SqliteDict(f"{hostdict}.sqlite", autocommit=True, decode=restricted_decode, decode_key=restricted_decode_key) db[0] = "seed" db.commit() db.close() \ No newline at end of file diff --git a/package/etc/pylib/parser_vps_cache.py b/package/etc/pylib/parser_vps_cache.py index a95162862..d55d5ef04 100644 --- a/package/etc/pylib/parser_vps_cache.py +++ b/package/etc/pylib/parser_vps_cache.py @@ -18,13 +18,47 @@ class LogDestination: pass +import builtins +import io +import pickle +from base64 import b64decode + +safe_builtins = { + 'range', + 'complex', + 'set', + 'frozenset', + 'slice', +} + +class RestrictedUnpickler(pickle.Unpickler): + def find_class(self, module, name): + # Only allow safe classes from builtins. + if module == "builtins" and name in safe_builtins: + return getattr(builtins, name) + # Forbid everything else. + raise pickle.UnpicklingError("global '%s.%s' is forbidden" % + (module, name)) + +def restricted_loads(s): + """Helper function analogous to pickle.loads().""" + return RestrictedUnpickler(io.BytesIO(s)).load() + +def restricted_decode(obj): + """Overwrite sqlitedict.decode to prevent code injection.""" + return restricted_loads(bytes(obj)) + +def restricted_decode_key(key): + """Overwrite sqlitedict.decode_key to prevent code injection.""" + return restricted_loads(b64decode(key.encode("ascii"))) + hostdict = str("/var/lib/syslog-ng/vps") class vpsc_parse(LogParser): def init(self, options): self.logger = syslogng.Logger() - self.db = SqliteDict(f"{hostdict}.sqlite") + self.db = SqliteDict(f"{hostdict}.sqlite", decode=restricted_decode, decode_key=restricted_decode_key) return True def deinit(self): @@ -52,7 +86,7 @@ class vpsc_dest(LogDestination): def init(self, options): self.logger = syslogng.Logger() try: - self.db = SqliteDict(f"{hostdict}.sqlite", autocommit=True) + self.db = SqliteDict(f"{hostdict}.sqlite", autocommit=True, decode=restricted_decode, decode_key=restricted_decode_key) except Exception: exc_type, exc_value, exc_traceback = sys.exc_info() lines = traceback.format_exception(exc_type, exc_value, exc_traceback)