-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.go
248 lines (219 loc) · 6.7 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
package main
import (
	"bytes"
	"compress/flate"
	"compress/gzip"
	"io"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/andybalholm/brotli"
	"github.com/gorilla/websocket"
	"github.com/sirupsen/logrus"
)
// init configures the global logrus logger from the LOG_LEVEL environment
// variable. Recognised values are "debug", "info", "warning" and "error";
// an unset or unrecognised value selects the info level. Log lines are
// written with full timestamps.
func init() {
	levelByName := map[string]logrus.Level{
		"debug":   logrus.DebugLevel,
		"info":    logrus.InfoLevel,
		"warning": logrus.WarnLevel,
		"error":   logrus.ErrorLevel,
	}
	level, known := levelByName[os.Getenv("LOG_LEVEL")]
	if !known {
		level = logrus.InfoLevel
	}
	logrus.SetLevel(level)
	logrus.SetFormatter(&logrus.TextFormatter{FullTimestamp: true})
}
// replaceDomainInResponse rewrites, in place, every occurrence of the upstream
// host "<replaceSubdomain>.lunaroasis.net" in buffer to the client-facing host
// "<originalSubdomain>.<originalDomain>".
func replaceDomainInResponse(originalSubdomain, replaceSubdomain, originalDomain string, buffer *bytes.Buffer) {
	// The statescale/snapscale upstreams always live under lunaroasis.net,
	// while the domain the client used varies per request.
	upstreamHost := []byte(replaceSubdomain + ".lunaroasis.net")
	clientHost := []byte(originalSubdomain + "." + originalDomain)

	// bytes.ReplaceAll always returns a fresh slice, so resetting the buffer
	// before writing the result back cannot clobber it.
	rewritten := bytes.ReplaceAll(buffer.Bytes(), upstreamHost, clientHost)
	buffer.Reset()
	buffer.Write(rewritten)
}
// proxyRequest forwards r (method, headers and body) to
// https://<fullSubdomain>.lunaroasis.net<path> and writes the upstream
// response body into buffer. A body with Content-Encoding "br", "gzip" or
// "deflate" is decoded first, so callers can rewrite it as plain bytes; any
// other encoding is copied through unchanged.
//
// It returns the upstream status code and the response headers flattened to
// one value per key. The last value wins, so repeated headers such as
// Set-Cookie keep only one entry. On error the status code is 0, the header
// map is nil, and buffer may hold a partial body.
func proxyRequest(fullSubdomain, path string, buffer *bytes.Buffer, r *http.Request) (int, map[string]string, error) {
	target := "https://" + fullSubdomain + ".lunaroasis.net" + path
	// Bind the upstream call to the inbound request so it is abandoned when the
	// client goes away. NOTE(review): no client-side timeout is configured.
	newReq, err := http.NewRequestWithContext(r.Context(), r.Method, target, r.Body)
	if err != nil {
		logrus.Errorf("Failed to create request: %v", err)
		return 0, nil, err
	}
	// Copy instead of aliasing the inbound header map; the same *http.Request
	// may be proxied more than once (primary, then fallback upstream).
	newReq.Header = r.Header.Clone()

	resp, err := http.DefaultClient.Do(newReq)
	if err != nil {
		logrus.Errorf("Failed to send request: %v", err)
		return 0, nil, err
	}
	defer resp.Body.Close()

	headers := make(map[string]string, len(resp.Header))
	for key, values := range resp.Header {
		if len(values) > 0 {
			headers[key] = values[len(values)-1]
		}
	}

	// HEAD, 204 and 304 responses (and explicitly zero-length ones) carry no
	// content; handing an empty stream to a decompressor would fail.
	if newReq.Method == http.MethodHead ||
		resp.StatusCode == http.StatusNoContent ||
		resp.StatusCode == http.StatusNotModified ||
		resp.ContentLength == 0 {
		return resp.StatusCode, headers, nil
	}

	var reader io.Reader = resp.Body
	switch resp.Header.Get("Content-Encoding") {
	case "br":
		reader = brotli.NewReader(resp.Body)
	case "gzip":
		gz, err := gzip.NewReader(resp.Body)
		if err != nil {
			logrus.Errorf("Failed to create gzip reader: %v", err)
			return 0, nil, err
		}
		defer gz.Close()
		reader = gz
	case "deflate":
		// NOTE(review): HTTP "deflate" is specified as zlib-wrapped DEFLATE,
		// but this decodes raw DEFLATE; confirm what the upstreams emit.
		fl := flate.NewReader(resp.Body)
		defer fl.Close()
		reader = fl
	}

	// A failed read means the body is truncated or corrupt; report it instead
	// of passing a partial response off as a success.
	if _, err := io.Copy(buffer, reader); err != nil {
		logrus.Errorf("Failed to read response body: %v", err)
		return 0, nil, err
	}
	return resp.StatusCode, headers, nil
}
// handleHttpRequest proxies a request for "<sub>.<domain>". It first tries
// <sub>-statescale.lunaroasis.net. It falls back to
// <sub>-snapscale.lunaroasis.net when that fails or answers with a status
// >= 400. Upstream hostnames in the response body are rewritten back to the
// client-facing host. WebSocket upgrade requests are handed to
// proxyWebSocketRequest instead.
func handleHttpRequest(w http.ResponseWriter, r *http.Request) {
	logrus.Infof("Received request from %s", r.Host)
	hostParts := strings.Split(r.Host, ".")
	if len(hostParts) < 3 {
		logrus.Errorf("Invalid domain: %s", r.Host)
		http.Error(w, "Invalid domain", http.StatusBadRequest)
		return
	}
	subdomain := hostParts[0]
	// Everything after the first label is the domain the client used; it is
	// substituted back into rewritten response bodies.
	originalDomain := strings.Join(hostParts[1:], ".")

	if strings.ToLower(r.Header.Get("Upgrade")) == "websocket" {
		proxyWebSocketRequest(subdomain, w, r)
		return
	}

	// The request body may have to be sent twice (primary, then fallback), but
	// r.Body can only be read once. Buffer it and rewind before each attempt.
	// NOTE(review): the whole body is held in memory; consider a size cap.
	reqBody, err := io.ReadAll(r.Body)
	if err != nil {
		logrus.Errorf("Failed to read request body: %v", err)
		http.Error(w, "Failed to read request body", http.StatusBadRequest)
		return
	}
	rewindBody := func() {
		if len(reqBody) == 0 {
			r.Body = http.NoBody
		} else {
			r.Body = io.NopCloser(bytes.NewReader(reqBody))
		}
	}

	primary := subdomain + "-statescale"
	buffer := new(bytes.Buffer)
	logrus.Debugf("Proxying request to %s", primary)
	rewindBody()
	statusCode, headers, err := proxyRequest(primary, r.RequestURI, buffer, r)
	logrus.Debugf("Received status code %d", statusCode)
	if err == nil && statusCode < 400 {
		writeProxiedResponse(w, subdomain, primary, originalDomain, statusCode, headers, buffer)
		return
	}

	backup := subdomain + "-snapscale"
	backupBuffer := new(bytes.Buffer)
	logrus.Debugf("Proxying request to %s", backup)
	rewindBody()
	backupStatusCode, backupHeaders, err := proxyRequest(backup, r.RequestURI, backupBuffer, r)
	logrus.Debugf("Received status code %d", backupStatusCode)
	if err != nil {
		// Both upstreams failed. The status code is 0 here, and WriteHeader
		// would panic on it, so report the outage explicitly.
		http.Error(w, "Bad gateway", http.StatusBadGateway)
		return
	}
	writeProxiedResponse(w, subdomain, backup, originalDomain, backupStatusCode, backupHeaders, backupBuffer)
}

// writeProxiedResponse sends one upstream response to the client. It rewrites
// upstream hostnames in body, re-applies the upstream Content-Encoding to the
// rewritten bytes, and writes headers, status and body to w.
func writeProxiedResponse(w http.ResponseWriter, subdomain, upstream, originalDomain string, statusCode int, headers map[string]string, body *bytes.Buffer) {
	// An empty body (e.g. HEAD, 204, 304) is not rewritten or re-encoded, so
	// the upstream headers, Content-Length included, still describe it.
	hasBody := body.Len() > 0
	for key, value := range headers {
		// A rewritten and re-encoded body no longer matches the upstream length;
		// let net/http frame the response itself.
		if hasBody && key == "Content-Length" {
			continue
		}
		w.Header().Set(key, value)
	}
	w.WriteHeader(statusCode)
	if !hasBody {
		return
	}

	replaceDomainInResponse(subdomain, upstream, originalDomain, body)
	// The body was decoded by proxyRequest; re-encode it so it matches the
	// Content-Encoding header forwarded above.
	out := compressData(body, headers["Content-Encoding"])
	if _, err := io.Copy(w, out); err != nil {
		logrus.Debugf("Failed to write response to client: %v", err)
	}
}
// upgrader upgrades inbound client connections to WebSocket using the default
// buffer sizes, accepting requests from any Origin.
//
// NOTE(review): accepting every Origin disables the Upgrader's same-origin
// check, which guards against cross-site WebSocket hijacking. Confirm this is
// intended for this proxy.
var upgrader = websocket.Upgrader{
	CheckOrigin: allowAnyOrigin,
}

// allowAnyOrigin is the Upgrader.CheckOrigin hook; it accepts every request.
func allowAnyOrigin(*http.Request) bool {
	return true
}
// proxyWebSocketRequest relays a WebSocket session between the client and
// wss://<subdomain>-snapscale.lunaroasis.net<RequestURI>.
//
// The upstream is dialled before the client is upgraded, so an unreachable
// upstream can still be reported with a plain HTTP 500. Messages are then
// pumped in both directions. The function returns as soon as either direction
// stops, and its deferred Close calls tear down both connections.
func proxyWebSocketRequest(subdomain string, w http.ResponseWriter, r *http.Request) {
	fullSubdomain := subdomain + "-snapscale"
	target := "wss://" + fullSubdomain + ".lunaroasis.net" + r.RequestURI

	// NOTE(review): no client headers (cookies, auth, subprotocols) are
	// forwarded to the upstream handshake.
	dialer := websocket.Dialer{}
	targetConn, resp, err := dialer.Dial(target, nil)
	if err != nil {
		logrus.Errorf("Failed to connect to target: %v", err)
		if resp != nil {
			logrus.Errorf("Handshake response status: %s", resp.Status)
			for k, v := range resp.Header {
				logrus.Errorf("%s: %s", k, v)
			}
		}
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}
	defer targetConn.Close()

	clientConn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade has already replied to the client with an HTTP error.
		logrus.Errorf("Failed to upgrade client connection: %v", err)
		return
	}
	defer clientConn.Close()

	// Each pump signals here when it stops. Capacity 2 lets the second pump
	// send without a receiver after this function has returned, so it never leaks.
	done := make(chan struct{}, 2)

	// upstream -> client
	go func() {
		defer func() { done <- struct{}{} }()
		for {
			messageType, message, err := targetConn.ReadMessage()
			if err != nil {
				logrus.Errorf("Failed to read from target: %v", err)
				return
			}
			if err := clientConn.WriteMessage(messageType, message); err != nil {
				logrus.Errorf("Failed to write to client: %v", err)
				return
			}
		}
	}()

	// client -> upstream
	go func() {
		defer func() { done <- struct{}{} }()
		for {
			messageType, message, err := clientConn.ReadMessage()
			if err != nil {
				logrus.Errorf("Failed to read from client: %v", err)
				return
			}
			if err := targetConn.WriteMessage(messageType, message); err != nil {
				logrus.Errorf("Failed to write to target: %v", err)
				return
			}
		}
	}()

	// Wait for the first direction to stop. Returning runs the deferred
	// Close calls, which makes the other pump's blocked ReadMessage fail so
	// that goroutine exits too.
	<-done
}
// compressData encodes the contents of buffer with the HTTP content coding
// named by encoding ("br", "gzip" or "deflate"). It returns a new buffer
// holding the compressed bytes and drains buffer in the process. For any other
// encoding value, buffer itself is returned unchanged.
//
// NOTE(review): "deflate" is produced as raw DEFLATE (compress/flate), while
// HTTP specifies zlib-wrapped data for that token; confirm clients accept it.
func compressData(buffer *bytes.Buffer, encoding string) *bytes.Buffer {
	var out bytes.Buffer
	var enc io.WriteCloser
	switch encoding {
	case "gzip":
		enc = gzip.NewWriter(&out)
	case "br":
		enc = brotli.NewWriterLevel(&out, brotli.DefaultCompression)
	case "deflate":
		// flate.NewWriter only errors for an out-of-range level, and
		// DefaultCompression is always valid.
		fw, _ := flate.NewWriter(&out, flate.DefaultCompression)
		enc = fw
	}
	if enc == nil {
		return buffer
	}
	// Results are ignored: the destination is an in-memory buffer. Close
	// flushes the encoder's trailer into out.
	buffer.WriteTo(enc)
	enc.Close()
	return &out
}
// handleRequest is the catch-all handler registered on "/"; every request is
// delegated unchanged to handleHttpRequest.
func handleRequest(w http.ResponseWriter, r *http.Request) {
	delegate := http.HandlerFunc(handleHttpRequest)
	delegate.ServeHTTP(w, r)
}
// main registers the proxy handler on "/" of the default mux and serves plain
// HTTP on :8080. If the server cannot start (e.g. the port is taken) or stops
// with an error, it logs the error and exits with a non-zero status.
func main() {
	logrus.Info("Starting server on :8080")
	http.HandleFunc("/", handleRequest)
	server := &http.Server{
		Addr: ":8080", // nil Handler means http.DefaultServeMux
		// Bound how long a client may take to send its request headers, so slow
		// or idle connections cannot pin server resources. This does not limit
		// request bodies, responses, or upgraded WebSocket sessions.
		ReadHeaderTimeout: 10 * time.Second,
	}
	if err := server.ListenAndServe(); err != nil {
		logrus.Fatalf("Server stopped: %v", err)
	}
}