Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ipnetwork and ipaddress hashable #148

Merged
merged 6 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 37 additions & 18 deletions flow/record/fieldtypes/net/ip.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,75 @@
from ipaddress import ip_address, ip_network
from __future__ import annotations

from ipaddress import (
IPv4Address,
IPv4Network,
IPv6Address,
IPv6Network,
ip_address,
ip_network,
)
from typing import Union

Miauwkeru marked this conversation as resolved.
Show resolved Hide resolved
from flow.record.base import FieldType
from flow.record.fieldtypes import defang

_IPNetwork = Union[IPv4Network, IPv6Network]
_IPAddress = Union[IPv4Address, IPv6Address]


class ipaddress(FieldType):
val = None
_type = "net.ipaddress"

def __init__(self, addr):
def __init__(self, addr: str | int | bytes):
self.val = ip_address(addr)

def __eq__(self, b):
def __eq__(self, b: str | int | bytes | _IPAddress) -> bool:
try:
return self.val == ip_address(b)
except ValueError:
return False

def __str__(self):
def __hash__(self) -> int:
return hash(self.val)

def __str__(self) -> str:
return str(self.val)

def __repr__(self):
return "{}({!r})".format(self._type, str(self))
def __repr__(self) -> str:
return f"{self._type}({str(self)!r})"

def __format__(self, spec):
def __format__(self, spec: str) -> str:
if spec == "defang":
return defang(str(self))
return str.__format__(str(self), spec)

def _pack(self):
def _pack(self) -> int:
return int(self.val)

@staticmethod
def _unpack(data):
def _unpack(data: int) -> ipaddress:
return ipaddress(data)


class ipnetwork(FieldType):
val = None
_type = "net.ipnetwork"

def __init__(self, addr):
def __init__(self, addr: str | int | bytes):
self.val = ip_network(addr)

def __eq__(self, b):
def __eq__(self, b: str | int | bytes | _IPNetwork) -> bool:
try:
return self.val == ip_network(b)
except ValueError:
return False

def __hash__(self) -> int:
return hash(self.val)

@staticmethod
def _is_subnet_of(a, b):
def _is_subnet_of(a: _IPNetwork, b: _IPNetwork) -> bool:
try:
# Always false if one is v4 and the other is v6.
if a._version != b._version:
Expand All @@ -59,23 +78,23 @@ def _is_subnet_of(a, b):
except AttributeError:
raise TypeError("Unable to test subnet containment " "between {} and {}".format(a, b))

Miauwkeru marked this conversation as resolved.
Show resolved Hide resolved
def __contains__(self, b):
def __contains__(self, b: str | int | bytes | _IPAddress) -> bool:
try:
return self._is_subnet_of(ip_network(b), self.val)
except (ValueError, TypeError):
return False

def __str__(self):
def __str__(self) -> str:
return str(self.val)

def __repr__(self):
return "{}({!r})".format(self._type, str(self))
def __repr__(self) -> str:
return f"{self._type}({str(self)!r})"

def _pack(self):
def _pack(self) -> str:
return self.val.compressed

@staticmethod
def _unpack(data):
def _unpack(data: str) -> ipnetwork:
return ipnetwork(data)


Expand Down
15 changes: 15 additions & 0 deletions tests/test_fieldtype_ip.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,19 @@ def test_record_ipaddress():
assert TestRecord("0.0.0.0").ip == "0.0.0.0"
assert TestRecord("192.168.0.1").ip == "192.168.0.1"
assert TestRecord("255.255.255.255").ip == "255.255.255.255"
assert hash(TestRecord("192.168.0.1").ip) == hash(net.ipaddress("192.168.0.1"))

# ipv6
assert TestRecord("::1").ip == "::1"
assert TestRecord("2001:4860:4860::8888").ip == "2001:4860:4860::8888"
assert TestRecord("2001:4860:4860::4444").ip == "2001:4860:4860::4444"

# Test whether it functions in a set
data = {TestRecord(ip).ip for ip in ["192.168.0.1", "192.168.0.1", "::1", "::1"]}
assert len(data) == 2
assert net.ipaddress("::1") in data
assert net.ipaddress("192.168.0.1") in data

# instantiate from different types
assert TestRecord(1).ip == "0.0.0.1"
assert TestRecord(0x7F0000FF).ip == "127.0.0.255"
Expand Down Expand Up @@ -90,6 +97,7 @@ def test_record_ipnetwork():
assert "192.168.1.1" not in r.subnet
assert isinstance(r.subnet, net.ipnetwork)
assert repr(r.subnet) == "net.ipnetwork('192.168.0.0/24')"
assert hash(r.subnet) == hash(net.ipnetwork("192.168.0.0/24"))
yunzheng marked this conversation as resolved.
Show resolved Hide resolved

r = TestRecord("192.168.1.1/32")
assert r.subnet == "192.168.1.1"
Expand All @@ -111,6 +119,13 @@ def test_record_ipnetwork():
assert "64:ff9b::0.0.0.0" in r.subnet
assert "64:ff9b::255.255.255.255" in r.subnet

# Test whether it functions in a set
data = {TestRecord(x).subnet for x in ["192.168.0.0/24", "192.168.0.0/24", "::1", "::1"]}
assert len(data) == 2
assert net.ipnetwork("::1") in data
assert net.ipnetwork("192.168.0.0/24") in data
assert "::1" not in data


@pytest.mark.parametrize("PSelector", [Selector, CompiledSelector])
def test_selector_ipaddress(PSelector):
Expand Down
Loading