Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sockets cleanup #1351

Draft
wants to merge 5 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 0 additions & 50 deletions src/Renci.SshNet/Abstractions/SocketAbstraction.Async.cs

This file was deleted.

253 changes: 27 additions & 226 deletions src/Renci.SshNet/Abstractions/SocketAbstraction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,146 +6,38 @@
using System.Threading.Tasks;

using Renci.SshNet.Common;
using Renci.SshNet.Messages.Transport;

namespace Renci.SshNet.Abstractions
{
internal static partial class SocketAbstraction
{
public static bool CanRead(Socket socket)
{
if (socket.Connected)
{
return socket.Poll(-1, SelectMode.SelectRead) && socket.Available > 0;
}

return false;
}

/// <summary>
/// Returns a value indicating whether the specified <see cref="Socket"/> can be used
/// to send data.
/// </summary>
/// <param name="socket">The <see cref="Socket"/> to check.</param>
/// <returns>
/// <see langword="true"/> if <paramref name="socket"/> can be written to; otherwise, <see langword="false"/>.
/// </returns>
public static bool CanWrite(Socket socket)
{
if (socket != null && socket.Connected)
{
return socket.Poll(-1, SelectMode.SelectWrite);
}

return false;
}

public static Socket Connect(IPEndPoint remoteEndpoint, TimeSpan connectTimeout)
{
var socket = new Socket(remoteEndpoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp) { NoDelay = true };
ConnectCore(socket, remoteEndpoint, connectTimeout, ownsSocket: true);
return socket;
}

public static void Connect(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout)
{
ConnectCore(socket, remoteEndpoint, connectTimeout, ownsSocket: false);
}

public static async Task ConnectAsync(Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken)
{
await socket.ConnectAsync(remoteEndpoint, cancellationToken).ConfigureAwait(false);
}

private static void ConnectCore(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout, bool ownsSocket)
{
var connectCompleted = new ManualResetEvent(initialState: false);
var args = new SocketAsyncEventArgs
{
UserToken = connectCompleted,
RemoteEndPoint = remoteEndpoint
};
args.Completed += ConnectCompleted;
using var connectCompleted = new ManualResetEventSlim(initialState: false);
using var args = new SocketAsyncEventArgs
{
RemoteEndPoint = remoteEndpoint
};
args.Completed += (_, _) => connectCompleted.Set();

if (socket.ConnectAsync(args))
{
if (!connectCompleted.WaitOne(connectTimeout))
if (!connectCompleted.Wait(connectTimeout))
{
// avoid ObjectDisposedException in ConnectCompleted
args.Completed -= ConnectCompleted;
if (ownsSocket)
{
// dispose Socket
socket.Dispose();
}

// dispose ManualResetEvent
connectCompleted.Dispose();

// dispose SocketAsyncEventArgs
args.Dispose();
socket.Dispose();

throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
"Connection failed to establish within {0:F0} milliseconds.",
connectTimeout.TotalMilliseconds));
}
}

// dispose ManualResetEvent
connectCompleted.Dispose();

if (args.SocketError != SocketError.Success)
{
var socketError = (int) args.SocketError;

if (ownsSocket)
{
// dispose Socket
socket.Dispose();
}

// dispose SocketAsyncEventArgs
args.Dispose();

throw new SocketException(socketError);
}

// dispose SocketAsyncEventArgs
args.Dispose();
}

public static void ClearReadBuffer(Socket socket)
{
var timeout = TimeSpan.FromMilliseconds(500);
var buffer = new byte[256];
int bytesReceived;

do
{
bytesReceived = ReadPartial(socket, buffer, 0, buffer.Length, timeout);
}
while (bytesReceived > 0);
}

public static int ReadPartial(Socket socket, byte[] buffer, int offset, int size, TimeSpan timeout)
{
socket.ReceiveTimeout = timeout.AsTimeout(nameof(timeout));

try
{
return socket.Receive(buffer, offset, size, SocketFlags.None);
}
catch (SocketException ex)
{
if (ex.SocketErrorCode == SocketError.TimedOut)
{
throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
"Socket read operation has timed out after {0:F0} milliseconds.",
timeout.TotalMilliseconds));
}

throw;
}
}

public static void ReadContinuous(Socket socket, byte[] buffer, int offset, int size, Action<byte[], int, int> processReceivedBytesAction)
Expand All @@ -167,11 +59,6 @@ public static void ReadContinuous(Socket socket, byte[] buffer, int offset, int
}
catch (SocketException ex)
{
if (IsErrorResumable(ex.SocketErrorCode))
{
continue;
}

#pragma warning disable IDE0010 // Add missing cases
switch (ex.SocketErrorCode)
{
Expand Down Expand Up @@ -212,41 +99,6 @@ public static int ReadByte(Socket socket, TimeSpan timeout)
return buffer[0];
}

/// <summary>
/// Sends a byte using the specified <see cref="Socket"/>.
/// </summary>
/// <param name="socket">The <see cref="Socket"/> to write to.</param>
/// <param name="value">The value to send.</param>
/// <exception cref="SocketException">The write failed.</exception>
public static void SendByte(Socket socket, byte value)
{
var buffer = new[] { value };
Send(socket, buffer, 0, 1);
}

/// <summary>
/// Receives data from a bound <see cref="Socket"/>.
/// </summary>
/// <param name="socket">The <see cref="Socket"/> to read from.</param>
/// <param name="size">The number of bytes to receive.</param>
/// <param name="timeout">Specifies the amount of time after which the call will time out.</param>
/// <returns>
/// The bytes received.
/// </returns>
/// <remarks>
/// If no data is available for reading, the <see cref="Read(Socket, int, TimeSpan)"/> method will
/// block until data is available or the time-out value is exceeded. If the time-out value is exceeded, the
/// <see cref="Read(Socket, int, TimeSpan)"/> call will throw a <see cref="SshOperationTimeoutException"/>.
/// If you are in non-blocking mode, and there is no data available in the in the protocol stack buffer, the
/// <see cref="Read(Socket, int, TimeSpan)"/> method will complete immediately and throw a <see cref="SocketException"/>.
/// </remarks>
public static byte[] Read(Socket socket, int size, TimeSpan timeout)
{
var buffer = new byte[size];
_ = Read(socket, buffer, 0, size, timeout);
return buffer;
}

/// <summary>
/// Receives data from a bound <see cref="Socket"/> into a receive buffer.
/// </summary>
Expand All @@ -264,10 +116,6 @@ public static byte[] Read(Socket socket, int size, TimeSpan timeout)
/// block until data is available or the time-out value is exceeded. If the time-out value is exceeded, the
/// <see cref="Read(Socket, byte[], int, int, TimeSpan)"/> call will throw a <see cref="SshOperationTimeoutException"/>.
/// </para>
/// <para>
/// If you are in non-blocking mode, and there is no data available in the in the protocol stack buffer, the
/// <see cref="Read(Socket, byte[], int, int, TimeSpan)"/> method will complete immediately and throw a <see cref="SocketException"/>.
/// </para>
/// </remarks>
public static int Read(Socket socket, byte[] buffer, int offset, int size, TimeSpan readTimeout)
{
Expand All @@ -288,94 +136,47 @@ public static int Read(Socket socket, byte[] buffer, int offset, int size, TimeS

totalBytesRead += bytesRead;
}
catch (SocketException ex)
catch (SocketException ex) when (ex.SocketErrorCode == SocketError.TimedOut)
{
if (IsErrorResumable(ex.SocketErrorCode))
{
ThreadAbstraction.Sleep(30);
continue;
}

if (ex.SocketErrorCode == SocketError.TimedOut)
{
throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
"Socket read operation has timed out after {0:F0} milliseconds.",
readTimeout.TotalMilliseconds));
}

throw;
throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
"Socket read operation has timed out after {0:F0} milliseconds.",
readTimeout.TotalMilliseconds),
ex);
}
}
while (totalBytesRead < totalBytesToRead);

return totalBytesRead;
}

#if NET6_0_OR_GREATER == false
public static Task<int> ReadAsync(Socket socket, byte[] buffer, CancellationToken cancellationToken)
{
return socket.ReceiveAsync(buffer, 0, buffer.Length, cancellationToken);
}
#endif

public static void Send(Socket socket, byte[] data)
public static async ValueTask<int> ReadAsync(Socket socket, byte[] buffer, int offset, int size, CancellationToken cancellationToken)
{
Send(socket, data, 0, data.Length);
}

public static void Send(Socket socket, byte[] data, int offset, int size)
{
var totalBytesSent = 0; // how many bytes are already sent
var totalBytesToSend = size;
var totalBytesRead = 0;
var totalBytesToRead = size;

do
{
try
{
var bytesSent = socket.Send(data, offset + totalBytesSent, totalBytesToSend - totalBytesSent, SocketFlags.None);
if (bytesSent == 0)
var bytesRead = await socket.ReceiveAsync(new ArraySegment<byte>(buffer, offset + totalBytesRead, totalBytesToRead - totalBytesRead), SocketFlags.None, cancellationToken).ConfigureAwait(false);
if (bytesRead == 0)
{
throw new SshConnectionException("An established connection was aborted by the server.",
DisconnectReason.ConnectionLost);
return 0;
}

totalBytesSent += bytesSent;
totalBytesRead += bytesRead;
}
catch (SocketException ex)
catch (SocketException ex) when (ex.SocketErrorCode == SocketError.TimedOut)
{
if (IsErrorResumable(ex.SocketErrorCode))
{
// socket buffer is probably full, wait and try again
ThreadAbstraction.Sleep(30);
}
else
{
throw; // any serious error occurr
}
throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
"Socket read operation has timed out after {0:F0} milliseconds.",
socket.ReceiveTimeout),
ex);
}
}
while (totalBytesSent < totalBytesToSend);
}

public static bool IsErrorResumable(SocketError socketError)
{
#pragma warning disable IDE0010 // Add missing cases
switch (socketError)
{
case SocketError.WouldBlock:
case SocketError.IOPending:
case SocketError.NoBufferSpaceAvailable:
return true;
default:
return false;
}
#pragma warning restore IDE0010 // Add missing cases
}
while (totalBytesRead < totalBytesToRead);

private static void ConnectCompleted(object sender, SocketAsyncEventArgs e)
{
var eventWaitHandle = (ManualResetEvent) e.UserToken;
_ = eventWaitHandle?.Set();
return totalBytesRead;
}
}
}
Loading