-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmlton_server.sml
152 lines (133 loc) · 6.33 KB
/
mlton_server.sml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
structure AthenaServer =
struct
open TopEnv;
open TextIO
structure SU = SockUtil
(* Passes the request text str, together with Semantics.top_val_env, to
   TopEnv.processPhraseDirectlyFromString and returns that function's result unchanged.
   For tracing, print str before the call and the result after it. *)
fun processString(str) =
  TopEnv.processPhraseDirectlyFromString(str, Semantics.top_val_env)
(* Largest number of bytes requested from a socket in a single receive call: 80 * 1024 = 81920. *)
val max_bytes_at_once = 80 * 1024
(*======================================================================================================
The readAllSimple function is straightforward and does not require encoding the length of the payload into
the first 4 bytes of the client's request. It is, however, predicated on the assumption that the client will
shut down the connection after transmitting the payload. Thus, a "natural EOF" can be detected on the server
side simply when Socket.recvVec returns a vector of length 0, at which point we know that the entire payload
has been received. This is clean and efficient (esp. with a large value of max_bytes_at_once), but again,
it does require the client to close the connection after each send. This means that there are no persistent
socket connections, which is fine for the use cases envisioned here (remote code execution).
** Important Note: If readAllSimple is used instead of readAll, then the corresponding sending function on
the Python client side should be send_request_to_server_simple (rather than send_request_to_server), to
ensure that the client does not encode the length of the payload into the first 4 bytes.
======================================================================================================*)
(* Reads from conn until Socket.recvVec returns an empty vector, then returns everything
   received so far, in arrival order, as one string. Each receive call asks for at most
   max_bytes_at_once bytes. *)
fun readAllSimple(conn) =
  let fun gather(chunks) =
        let val chunk = Socket.recvVec(conn, max_bytes_at_once)
        in
          case Word8Vector.length(chunk) of
              (* An empty read marks the end of the message; it contributes no bytes. *)
              0 => rev(chunks)
            | _ => gather(chunk::chunks)
        end
  in
    Byte.bytesToString(Word8Vector.concat(gather([])))
  end;
(* Receives exactly n bytes from conn and returns them as a single Word8Vector.
   Repeated receive calls are issued until n bytes have accumulated, each one asking only
   for the bytes still missing. If a receive returns an empty vector before n bytes have
   arrived, Basic.fail is invoked. For n <= 0 no receive is attempted and the result is
   the empty vector. *)
fun readExactly(conn, n) =
  let fun collect(chunks, received) =
        if received < n
        then
          let val chunk = Socket.recvVec(conn, n - received)
          in
            case Word8Vector.length(chunk) of
                0 => Basic.fail("Premature connection closure, unable to extract the length of the payload from the first 4 bytes.")
              | got => collect(chunk::chunks, received + got)
          end
        else rev(chunks)
  in
    Word8Vector.concat(collect([], 0))
  end
(*======================================================================================================
This implementation of readAll assumes that the client has encoded the length of the actual payload into
the first 4 bytes of the request. (This means that the largest possible payload that a client can send to
an Athena TCP server has a size that can be represented by 4 bytes; that comes to roughly 4 gigabytes.)
Since TCP might break up a message into arbitrarily many chunks, it's conceivable (though extremely unlikely)
that even the first 4 bytes of the incoming request are broken up into more than one part (e.g., 2 bytes and
then another 2 bytes). For that reason, readExactly is used to extract the first 4 bytes (though this is
likely overkill; in the vast majority of cases, the internal loop of readExactly should only perform
one single iteration).
======================================================================================================*)
(* Reads one length-prefixed request from conn and returns the payload as a string.

   The first 4 bytes (obtained with readExactly) are decoded as an unsigned big-endian
   integer, most significant byte first; that value is the number of payload bytes that
   follow. The payload is then received in chunks of at most max_bytes_at_once bytes
   until exactly that many bytes have arrived.

   Fails via Basic.fail if the peer closes the connection before the announced number of
   payload bytes has been received. Without that check, a zero-length receive would leave
   bytes_read unchanged and the loop would never terminate.

   NOTE(review): the length is accumulated in the default Int type; if that type is
   narrower than 33 bits, very large announced lengths would presumably raise Overflow.
   Confirm against the compiler's default integer width. *)
fun readAll(conn) =
  let val length_vec = readExactly(conn,4)
      (******
        val _ = print("\nReceived length bytes: ")
        val _ = Word8Vector.app (fn b => print(Int.toString(Word8.toInt b) ^ " "))
                                length_vec
      ******)
      (* Exact inverse of the encoding the client is expected to apply. *)
      val expected_length = Word8Vector.foldl (fn (b, acc) => acc * 256 + Word8.toInt b)
                                              0
                                              length_vec
      (******
        val _ = print("\nExpected length: " ^ Int.toString expected_length)
      ******)
      fun loop(vector_list, bytes_read) =
        if bytes_read >= expected_length
        then rev(vector_list)
        else
          (* Never ask for more than what is still outstanding, so that bytes belonging
             to anything after this payload are not consumed. *)
          let val in_vector = Socket.recvVec(conn, Int.min(max_bytes_at_once, expected_length - bytes_read))
              val len = Word8Vector.length(in_vector)
              (******
                val _ = print("\nReceived chunk of length: " ^ Int.toString len)
              ******)
          in
            (* An empty vector means the peer closed the connection; bail out instead of
               retrying forever. *)
            if len = 0
            then Basic.fail("Premature connection closure: received " ^ Int.toString bytes_read ^
                            " of " ^ Int.toString expected_length ^ " expected payload bytes.")
            else loop(in_vector::vector_list, bytes_read + len)
          end
  in
    Byte.bytesToString(Word8Vector.concat(loop([], 0)))
  end
(* Creates a TCP socket, sets NODELAY and REUSEADDR on it, binds it to server_port on
   INetSock.any, listens with a backlog of 9, and then accepts connections indefinitely.
   Every accepted connection is passed to Thread.spawn, where the request is read with
   readAll, evaluated with processString, and the reply is written back with SU.sendStr.
   The connection is closed afterwards; if any of those steps raises, the connection is
   closed and the exception is re-raised. *)
fun acceptLoop(server_port:int) =
  let fun respond(conn) =
        (SU.sendStr(conn, processString(readAll(conn)));
         Socket.close(conn))
        handle x => (Socket.close(conn); raise x)
      fun accept(listener) =
        let val (conn, _) = Socket.accept(listener)
        in
          Thread.spawn(fn () => (respond(conn); ()));
          accept(listener)
        end
      val backlog = 9
      val listener = INetSock.TCP.socket()
  in
    INetSock.TCP.setNODELAY(listener, true);
    Socket.Ctl.setREUSEADDR(listener, true);
    Socket.bind(listener, INetSock.any server_port);
    Socket.listen(listener, backlog);
    accept(listener)
  end
(* Entry point for the server: runs Repl.init on file_name_option, prints a banner, and
   then enters acceptLoop on the given port.

   No socket is created here. acceptLoop creates and owns the listening socket; an
   additional socket opened in this function would never be bound, used, or closed and
   would only leak a descriptor. *)
fun startServer(port,file_name_option) =
  let val _ = Repl.init(file_name_option)
  in
    print "\nEntering accept loop...\n";
    acceptLoop(port)
  end
(* Wraps f in a thunk. When the thunk is called it runs f(); if f raises, the exception's
   name and message are printed and RunCML.shutdown is called with OS.Process.failure.
   Otherwise (or if that call returns) RunCML.shutdown is called with OS.Process.success. *)
fun noisy(f) =
  let fun report(e) =
        TextIO.print ("Exception: " ^ (exnName e)
                      ^ " Message: " ^ (exnMessage e) ^ "\n")
  in
    fn () =>
      (((f() : unit)
        handle e => (report(e); RunCML.shutdown(OS.Process.failure)));
       RunCML.shutdown(OS.Process.success))
  end
end;