Skip to content

Commit

Permalink
Merge branch 'master' into HF_Echidna
Browse files Browse the repository at this point in the history
  • Loading branch information
Jim8y authored Nov 20, 2024
2 parents 650c9e9 + 842fa52 commit 331541a
Show file tree
Hide file tree
Showing 15 changed files with 328 additions and 43 deletions.
82 changes: 82 additions & 0 deletions src/Neo.Extensions/Collections/CollectionExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright (C) 2015-2024 The Neo Project.
//
// CollectionExtensions.cs file belongs to the neo project and is free
// software distributed under the MIT software license, see the
// accompanying file LICENSE in the main directory of the
// repository or http://www.opensource.org/licenses/mit-license.php
// for more details.
//
// Redistribution and use in source and binary forms with or without
// modifications are permitted.


using System;
using System.Collections.Generic;


namespace Neo.Extensions
{
public static class CollectionExtensions
{
/// <summary>
/// Removes the key-value pairs from the dictionary that match the specified predicate.
/// </summary>
/// <typeparam name="TKey">The type of the keys in the dictionary.</typeparam>
/// <typeparam name="TValue">The type of the values in the dictionary.</typeparam>
/// <param name="dict">The dictionary to remove key-value pairs from.</param>
/// <param name="predicate">The predicate to match key-value pairs.</param>
/// <param name="afterRemoved">An action to perform after each key-value pair is removed.</param>
public static void RemoveWhere<TKey, TValue>(
this IDictionary<TKey, TValue> dict,
Func<KeyValuePair<TKey, TValue>, bool> predicate,
Action<KeyValuePair<TKey, TValue>>? afterRemoved = null)
{
var items = new List<KeyValuePair<TKey, TValue>>();
foreach (var item in dict) // avoid linq
{
if (predicate(item))
items.Add(item);
}

foreach (var item in items)
{
if (dict.Remove(item.Key))
afterRemoved?.Invoke(item);
}
}

/// <summary>
/// Chunks the source collection into chunks of the specified size.
/// For example, if the source collection is [1, 2, 3, 4, 5] and the chunk size is 3, the result will be [[1, 2, 3], [4, 5]].
/// </summary>
/// <typeparam name="T">The type of the elements in the collection.</typeparam>
/// <param name="source">The collection to chunk.</param>
/// <param name="chunkSize">The size of each chunk.</param>
/// <returns>An enumerable of arrays, each containing a chunk of the source collection.</returns>
/// <exception cref="ArgumentNullException">Thrown when the source collection is null.</exception>
/// <exception cref="ArgumentOutOfRangeException">Thrown when the chunk size is less than or equal to 0.</exception>
public static IEnumerable<T[]> Chunk<T>(this IReadOnlyCollection<T> source, int chunkSize)
{
if (source is null)
throw new ArgumentNullException(nameof(source));

if (chunkSize <= 0)
throw new ArgumentOutOfRangeException(nameof(chunkSize), "Chunk size must > 0.");

using IEnumerator<T> enumerator = source.GetEnumerator();
for (var remain = source.Count; remain > 0;)
{
var chunk = new T[Math.Min(remain, chunkSize)];
for (var i = 0; i < chunk.Length; i++)
{
if (!enumerator.MoveNext()) // Additional checks
throw new InvalidOperationException("unexpected end of sequence");
chunk[i] = enumerator.Current;
}

remain -= chunk.Length;
yield return chunk;
}
}
}
}
2 changes: 1 addition & 1 deletion src/Neo/Network/P2P/Message.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ internal static int TryDeserialize(ByteString data, out Message msg)
payloadIndex += 8;
}

if (length > PayloadMaxSize) throw new FormatException();
if (length > PayloadMaxSize) throw new FormatException($"Invalid payload length: {length}.");

if (data.Count < (int)length + payloadIndex) return 0;

Expand Down
9 changes: 6 additions & 3 deletions src/Neo/Network/P2P/Payloads/ExtensiblePayload.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,17 @@ Witness[] IVerifiable.Witnesses
}
set
{
if (value.Length != 1) throw new ArgumentException();
if (value.Length != 1) throw new ArgumentException($"Expected 1 witness, got {value.Length}.");
Witness = value[0];
}
}

void ISerializable.Deserialize(ref MemoryReader reader)
{
((IVerifiable)this).DeserializeUnsigned(ref reader);
if (reader.ReadByte() != 1) throw new FormatException();
var count = reader.ReadByte();
if (count != 1)
throw new FormatException($"Expected 1 witness, got {count}.");
Witness = reader.ReadSerializable<Witness>();
}

Expand All @@ -103,7 +105,8 @@ void IVerifiable.DeserializeUnsigned(ref MemoryReader reader)
Category = reader.ReadVarString(32);
ValidBlockStart = reader.ReadUInt32();
ValidBlockEnd = reader.ReadUInt32();
if (ValidBlockStart >= ValidBlockEnd) throw new FormatException();
if (ValidBlockStart >= ValidBlockEnd)
throw new FormatException($"Invalid valid block range: {ValidBlockStart} >= {ValidBlockEnd}.");
Sender = reader.ReadSerializable<UInt160>();
Data = reader.ReadVarMemory();
}
Expand Down
2 changes: 1 addition & 1 deletion src/Neo/Network/P2P/Payloads/GetBlockByIndexPayload.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ void ISerializable.Deserialize(ref MemoryReader reader)
IndexStart = reader.ReadUInt32();
Count = reader.ReadInt16();
if (Count < -1 || Count == 0 || Count > HeadersPayload.MaxHeadersCount)
throw new FormatException();
throw new FormatException($"Invalid count: {Count}/{HeadersPayload.MaxHeadersCount}.");
}

void ISerializable.Serialize(BinaryWriter writer)
Expand Down
3 changes: 2 additions & 1 deletion src/Neo/Network/P2P/Payloads/GetBlocksPayload.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ void ISerializable.Deserialize(ref MemoryReader reader)
{
HashStart = reader.ReadSerializable<UInt256>();
Count = reader.ReadInt16();
if (Count < -1 || Count == 0) throw new FormatException();
if (Count < -1 || Count == 0)
throw new FormatException($"Invalid count: {Count}.");
}

void ISerializable.Serialize(BinaryWriter writer)
Expand Down
8 changes: 3 additions & 5 deletions src/Neo/Network/P2P/Payloads/InvPayload.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,14 @@ public static InvPayload Create(InventoryType type, params UInt256[] hashes)
/// <param name="type">The type of the inventories.</param>
/// <param name="hashes">The hashes of the inventories.</param>
/// <returns>The created payloads.</returns>
public static IEnumerable<InvPayload> CreateGroup(InventoryType type, UInt256[] hashes)
public static IEnumerable<InvPayload> CreateGroup(InventoryType type, IReadOnlyCollection<UInt256> hashes)
{
for (int i = 0; i < hashes.Length; i += MaxHashesCount)
foreach (var chunk in hashes.Chunk(MaxHashesCount))
{
int endIndex = i + MaxHashesCount;
if (endIndex > hashes.Length) endIndex = hashes.Length;
yield return new InvPayload
{
Type = type,
Hashes = hashes[i..endIndex]
Hashes = chunk,
};
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/Neo/Network/P2P/Payloads/OracleResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ protected override void DeserializeWithoutType(ref MemoryReader reader)
Id = reader.ReadUInt64();
Code = (OracleResponseCode)reader.ReadByte();
if (!Enum.IsDefined(typeof(OracleResponseCode), Code))
throw new FormatException();
throw new FormatException($"Invalid response code: {Code}.");

Result = reader.ReadVarMemory(MaxResultSize);
if (Code != OracleResponseCode.Success && Result.Length > 0)
throw new FormatException();
Expand Down
46 changes: 45 additions & 1 deletion src/Neo/Network/P2P/Payloads/Signer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;

namespace Neo.Network.P2P.Payloads
{
/// <summary>
/// Represents a signer of a <see cref="Transaction"/>.
/// </summary>
public class Signer : IInteroperable, ISerializable
public class Signer : IInteroperable, ISerializable, IEquatable<Signer>
{
// This limits maximum number of AllowedContracts or AllowedGroups here
private const int MaxSubitems = 16;
Expand Down Expand Up @@ -66,6 +67,31 @@ public class Signer : IInteroperable, ISerializable
/*AllowedGroups*/ (Scopes.HasFlag(WitnessScope.CustomGroups) ? AllowedGroups.GetVarSize() : 0) +
/*Rules*/ (Scopes.HasFlag(WitnessScope.WitnessRules) ? Rules.GetVarSize() : 0);

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public bool Equals(Signer other)
{
if (ReferenceEquals(this, other))
return true;
if (other is null) return false;
return Account == other.Account &&
Scopes == other.Scopes &&
AllowedContracts.AsSpan().SequenceEqual(other.AllowedContracts.AsSpan()) &&
AllowedGroups.AsSpan().SequenceEqual(other.AllowedGroups.AsSpan()) &&
Rules.AsEnumerable().SequenceEqual(other.Rules.AsEnumerable());
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public override bool Equals(object obj)
{
if (obj == null) return false;
return obj is Signer signerObj && Equals(signerObj);
}

public override int GetHashCode()
{
return HashCode.Combine(Account.GetHashCode(), Scopes);
}

public void Deserialize(ref MemoryReader reader)
{
Account = reader.ReadSerializable<UInt160>();
Expand Down Expand Up @@ -202,5 +228,23 @@ VM.Types.StackItem IInteroperable.ToStackItem(IReferenceCounter referenceCounter
Scopes.HasFlag(WitnessScope.WitnessRules) ? new VM.Types.Array(referenceCounter, Rules.Select(u => u.ToStackItem(referenceCounter))) : new VM.Types.Array(referenceCounter)
]);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool operator ==(Signer left, Signer right)
{
if (left is null || right is null)
return Equals(left, right);

return left.Equals(right);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool operator !=(Signer left, Signer right)
{
if (left is null || right is null)
return !Equals(left, right);

return !left.Equals(right);
}
}
}
13 changes: 9 additions & 4 deletions src/Neo/Network/P2P/Payloads/Transaction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -231,13 +231,18 @@ private static Signer[] DeserializeSigners(ref MemoryReader reader, int maxCount
public void DeserializeUnsigned(ref MemoryReader reader)
{
Version = reader.ReadByte();
if (Version > 0) throw new FormatException();
if (Version > 0) throw new FormatException($"Invalid version: {Version}.");

Nonce = reader.ReadUInt32();
SystemFee = reader.ReadInt64();
if (SystemFee < 0) throw new FormatException();
if (SystemFee < 0) throw new FormatException($"Invalid system fee: {SystemFee}.");

NetworkFee = reader.ReadInt64();
if (NetworkFee < 0) throw new FormatException();
if (SystemFee + NetworkFee < SystemFee) throw new FormatException();
if (NetworkFee < 0) throw new FormatException($"Invalid network fee: {NetworkFee}.");

if (SystemFee + NetworkFee < SystemFee)
throw new FormatException($"Invalid fee: {SystemFee} + {NetworkFee} < {SystemFee}.");

ValidUntilBlock = reader.ReadUInt32();
Signers = DeserializeSigners(ref reader, MaxTransactionAttributes);
Attributes = DeserializeAttributes(ref reader, MaxTransactionAttributes - Signers.Length);
Expand Down
7 changes: 4 additions & 3 deletions src/Neo/Network/P2P/Payloads/TransactionAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ public abstract class TransactionAttribute : ISerializable

public void Deserialize(ref MemoryReader reader)
{
if (reader.ReadByte() != (byte)Type)
throw new FormatException();
var type = reader.ReadByte();
if (type != (byte)Type)
throw new FormatException($"Expected {Type}, got {type}.");
DeserializeWithoutType(ref reader);
}

Expand All @@ -52,7 +53,7 @@ public static TransactionAttribute DeserializeFrom(ref MemoryReader reader)
{
TransactionAttributeType type = (TransactionAttributeType)reader.ReadByte();
if (ReflectionCache<TransactionAttributeType>.CreateInstance(type) is not TransactionAttribute attribute)
throw new FormatException();
throw new FormatException($"Invalid attribute type: {type}.");
attribute.DeserializeWithoutType(ref reader);
return attribute;
}
Expand Down
5 changes: 2 additions & 3 deletions src/Neo/Network/P2P/Payloads/WitnessRule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ void ISerializable.Deserialize(ref MemoryReader reader)
{
Action = (WitnessRuleAction)reader.ReadByte();
if (Action != WitnessRuleAction.Allow && Action != WitnessRuleAction.Deny)
throw new FormatException();
throw new FormatException($"Invalid action: {Action}.");
Condition = WitnessCondition.DeserializeFrom(ref reader, WitnessCondition.MaxNestingDepth);
}

Expand All @@ -81,9 +81,8 @@ void ISerializable.Serialize(BinaryWriter writer)
public static WitnessRule FromJson(JObject json)
{
WitnessRuleAction action = Enum.Parse<WitnessRuleAction>(json["action"].GetString());

if (action != WitnessRuleAction.Allow && action != WitnessRuleAction.Deny)
throw new FormatException();
throw new FormatException($"Invalid action: {action}.");

return new()
{
Expand Down
2 changes: 1 addition & 1 deletion src/Neo/Network/P2P/RemoteNode.ProtocolHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ private void OnGetDataMessageReceived(InvPayload payload)

if (notFound.Count > 0)
{
foreach (InvPayload entry in InvPayload.CreateGroup(payload.Type, notFound.ToArray()))
foreach (InvPayload entry in InvPayload.CreateGroup(payload.Type, notFound))
EnqueueMessage(Message.Create(MessageCommand.NotFound, entry));
}
}
Expand Down
27 changes: 8 additions & 19 deletions src/Neo/Network/P2P/TaskManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ private void OnNewTasks(InvPayload payload)
session.InvTasks[hash] = TimeProvider.Current.UtcNow;
}

foreach (InvPayload group in InvPayload.CreateGroup(payload.Type, hashes.ToArray()))
foreach (InvPayload group in InvPayload.CreateGroup(payload.Type, hashes))
Sender.Tell(Message.Create(MessageCommand.GetData, group));
}

Expand Down Expand Up @@ -317,18 +317,9 @@ private void OnTimer()
{
foreach (TaskSession session in sessions.Values)
{
foreach (var (hash, time) in session.InvTasks.ToArray())
if (TimeProvider.Current.UtcNow - time > TaskTimeout)
{
if (session.InvTasks.Remove(hash))
DecrementGlobalTask(hash);
}
foreach (var (index, time) in session.IndexTasks.ToArray())
if (TimeProvider.Current.UtcNow - time > TaskTimeout)
{
if (session.IndexTasks.Remove(index))
DecrementGlobalTask(index);
}
var now = TimeProvider.Current.UtcNow;
session.InvTasks.RemoveWhere(p => now - p.Value > TaskTimeout, p => DecrementGlobalTask(p.Key));
session.IndexTasks.RemoveWhere(p => now - p.Value > TaskTimeout, p => DecrementGlobalTask(p.Key));
}
foreach (var (actor, session) in sessions)
RequestTasks(actor, session);
Expand Down Expand Up @@ -365,15 +356,13 @@ private void RequestTasks(IActorRef remoteNode, TaskSession session)
HashSet<UInt256> hashes = new(session.AvailableTasks);
if (hashes.Count > 0)
{
foreach (UInt256 hash in hashes.ToArray())
{
if (!IncrementGlobalTask(hash))
hashes.Remove(hash);
}
hashes.RemoveWhere(p => !IncrementGlobalTask(p));
session.AvailableTasks.Remove(hashes);

foreach (UInt256 hash in hashes)
session.InvTasks[hash] = DateTime.UtcNow;
foreach (InvPayload group in InvPayload.CreateGroup(InventoryType.Block, hashes.ToArray()))

foreach (InvPayload group in InvPayload.CreateGroup(InventoryType.Block, hashes))
remoteNode.Tell(Message.Create(MessageCommand.GetData, group));
return;
}
Expand Down
Loading

0 comments on commit 331541a

Please sign in to comment.