diff --git a/.gitignore b/.gitignore
index d37a3977..b12d3742 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,6 +8,7 @@
 /cache
 /data
 /logs
+/error.log
 /.tmp
 /__hijack
 /oss_mirror
diff --git a/README.MD b/README.MD
index 9be92024..65654c78 100644
--- a/README.MD
+++ b/README.MD
@@ -86,10 +86,12 @@ cluster_id: ${CLUSTER_ID}
 cluster_secret: ${CLUSTER_SECRET}
 # File sync interval (minutes)
 sync_interval: 10
-# Maximum number of connections to open while syncing files
-download_max_conn: 64
 # Connection timeout (seconds); raise it if your network is poor
 connect_timeout: 10
+# Maximum number of connections to open while syncing files
+download_max_conn: 64
+# Whether to enable gzip compression
+use_gzip: false
 # Server upload limit
 serve_limit:
   # Whether to enable the upload limit
diff --git a/build.sh b/build.sh
index 70e22fd6..cc861011 100755
--- a/build.sh
+++ b/build.sh
@@ -12,7 +12,7 @@ outputdir=output
 
 mkdir -p "$outputdir"
 
-[ -n "$TAG" ] || TAG=$(git describe --tags --match v[0-9]* --abbrev=0 2>/dev/null || git log -1 --format="dev-%H")
+[ -n "$TAG" ] || TAG=$(git describe --tags --match v[0-9]* --abbrev=0 --candidates=0 2>/dev/null || git log -1 --format="dev-%H")
 
 echo "Detected tag: $TAG"
diff --git a/cluster.go b/cluster.go
index 56c8a1d4..d686e957 100644
--- a/cluster.go
+++ b/cluster.go
@@ -20,6 +20,8 @@ package main
 
 import (
+	"compress/gzip"
+	"compress/zlib"
 	"context"
 	"crypto"
 	"encoding/hex"
@@ -28,7 +30,9 @@ import (
 	"io"
 	"net"
 	"net/http"
+	"net/url"
 	"os"
+	"path"
 	"path/filepath"
 	"sort"
 	"strings"
@@ -66,10 +70,11 @@ type Cluster struct {
 	socket          *Socket
 	cancelKeepalive context.CancelFunc
 	downloadMux     sync.Mutex
-	downloading     map[string]chan struct{}
+	downloading     map[string]chan error
 	waitEnable      []chan struct{}
 
-	client *http.Client
+	client   *http.Client
+	bufSlots chan []byte
 
 	handlerAPIv0 http.Handler
 	handlerAPIv1 http.Handler
@@ -103,7 +108,7 @@ func NewCluster(
 		disabled: make(chan struct{}, 0),
 
-		downloading: make(map[string]chan struct{}),
+		downloading: make(map[string]chan error),
 
 		client: &http.Client{
 			Transport: transport,
@@ -111,11 +116,18 @@ func NewCluster(
 	}
 	close(cr.disabled)
 
+	cr.bufSlots = make(chan []byte, cr.maxConn)
+	for i := 0; i < cr.maxConn; i++ {
+		cr.bufSlots <- make([]byte, 1024*512)
+	}
+
 	// create folder structure
 	os.RemoveAll(cr.tmpDir)
 	os.MkdirAll(cr.cacheDir, 0755)
+	var b [1]byte
 	for i := 0; i < 0x100; i++ {
-		os.Mkdir(filepath.Join(cr.cacheDir, hex.EncodeToString([]byte{(byte)(i)})), 0755)
+		b[0] = (byte)(i)
+		os.Mkdir(filepath.Join(cr.cacheDir, hex.EncodeToString(b[:])), 0755)
 	}
 	os.MkdirAll(cr.dataDir, 0755)
 	os.MkdirAll(cr.tmpDir, 0700)
@@ -212,7 +224,7 @@ func (cr *Cluster) Enable(ctx context.Context) (err error) {
 		return
 	}
 	logInfo("Sending enable packet")
-	tctx, cancel := context.WithTimeout(ctx, time.Second * time.Duration(config.ConnectTimeout))
+	tctx, cancel := context.WithTimeout(ctx, time.Second*(time.Duration)(config.ConnectTimeout))
 	data, err := cr.socket.EmitAckContext(tctx, "enable", Map{
 		"host": cr.host,
 		"port": cr.publicPort,
@@ -235,6 +247,7 @@ func (cr *Cluster) Enable(ctx context.Context) (err error) {
 	for _, ch := range cr.waitEnable {
 		close(ch)
 	}
+	cr.waitEnable = cr.waitEnable[:0]
 
 	var keepaliveCtx context.Context
 	keepaliveCtx, cr.cancelKeepalive = context.WithCancel(ctx)
@@ -320,33 +333,37 @@ func (cr *Cluster) Disable(ctx context.Context) (ok bool) {
 	{
 		logInfo("Making keepalive before disable")
 		tctx, cancel := context.WithTimeout(ctx, time.Second*10)
-		cr.KeepAlive(tctx)
+		ok = cr.KeepAlive(tctx)
 		cancel()
+		if ok {
+			tctx, cancel := context.WithTimeout(ctx, time.Second*10)
+			data, err := cr.socket.EmitAckContext(tctx, "disable")
+			cancel()
+			if err != nil {
+				logErrorf("Disable failed: %v", err)
+				ok = false
+			} else {
+				logDebug("disable ack:", data)
+				if ero := data[0]; ero != nil {
+					logErrorf("Disable failed: %v", ero)
+					ok = false
+				} else if !data[1].(bool) {
+					logError("Disable failed: acked non-true value")
+					ok = false
+				}
+			}
+		} else {
+			logWarn("Keepalive failed, disabling without sending the disable packet")
+			ok = true
+		}
 	}
-	tctx, cancel := context.WithTimeout(ctx, time.Second*10)
-	data, err := cr.socket.EmitAckContext(tctx, "disable")
-	cancel()
-
 	cr.enabled.Store(false)
-	cr.socket.Close()
+	go cr.socket.Close()
 	cr.socket = nil
 	close(cr.disabled)
-	if err != nil {
-		logErrorf("Disable failed: %v", err)
-		return false
-	}
-	logDebug("disable ack:", data)
-	if ero := data[0]; ero != nil {
-		logErrorf("Disable failed: %v", ero)
-		return false
-	}
-	if !data[1].(bool) {
-		logError("Disable failed: ack non true value")
-		return false
-	}
 	logWarn("Cluster disabled")
-	return true
+	return
 }
 
 func (cr *Cluster) Disabled() <-chan struct{} {
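Note: both the enable and disable exchanges above read the socket.io acknowledgement as a two-element payload, where data[0] carries a remote error (or nil) and data[1] a boolean success flag. A minimal sketch of that convention, assuming EmitAckContext returns the raw []any payload — the ackOK helper is hypothetical, not part of this patch, and needs the "errors" and "fmt" imports:

// ackOK interprets an ack pair as Enable/Disable do above:
// data[0] is the remote error (or nil), data[1] the success flag.
func ackOK(data []any) (bool, error) {
	if len(data) < 2 {
		return false, errors.New("ack payload too short")
	}
	if ero := data[0]; ero != nil {
		return false, fmt.Errorf("remote error: %v", ero)
	}
	ok, _ := data[1].(bool)
	return ok, nil
}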
logDebug("disable ack:", data) + if ero := data[0]; ero != nil { + logErrorf("Disable failed: %v", ero) + ok = false + } else if !data[1].(bool) { + logError("Disable failed: acked non true value") + ok = false + } + } + } else { + logWarn("Keep alive failed, disable without send packet") + ok = true + } } - tctx, cancel := context.WithTimeout(ctx, time.Second*10) - data, err := cr.socket.EmitAckContext(tctx, "disable") - cancel() - cr.enabled.Store(false) - cr.socket.Close() + go cr.socket.Close() cr.socket = nil close(cr.disabled) - if err != nil { - logErrorf("Disable failed: %v", err) - return false - } - logDebug("disable ack:", data) - if ero := data[0]; ero != nil { - logErrorf("Disable failed: %v", ero) - return false - } - if !data[1].(bool) { - logError("Disable failed: ack non true value") - return false - } logWarn("Cluster disabled") - return true + return } func (cr *Cluster) Disabled() <-chan struct{} { @@ -399,42 +416,25 @@ func (cr *Cluster) RequestCert(ctx context.Context) (ckp *CertKeyPair, err error return } -func (cr *Cluster) queryFunc(ctx context.Context, method string, url string, call func(*http.Request)) (res *http.Response, err error) { - var req *http.Request - req, err = http.NewRequestWithContext(ctx, method, cr.prefix+url, nil) - if err != nil { +func (cr *Cluster) makeReq(ctx context.Context, method string, relpath string, query url.Values) (req *http.Request, err error) { + var target *url.URL + if target, err = url.Parse(cr.prefix); err != nil { return } - - query := req.URL.Query() - if config.NoOpen { - query.Set("noopen", "1") + target.Path = path.Join(target.Path, relpath) + if query != nil { + target.RawQuery = query.Encode() } - req.URL.RawQuery = query.Encode() + req, err = http.NewRequestWithContext(ctx, method, target.String(), nil) + if err != nil { + return + } req.SetBasicAuth(cr.username, cr.password) req.Header.Set("User-Agent", cr.useragent) - if call != nil { - call(req) - } - res, err = cr.client.Do(req) return } -func (cr *Cluster) queryURL(ctx context.Context, method string, url string) (res *http.Response, err error) { - return cr.queryFunc(ctx, method, url, nil) -} - -func (cr *Cluster) queryURLHeader(ctx context.Context, method string, url string, header map[string]string) (res *http.Response, err error) { - return cr.queryFunc(ctx, method, url, func(req *http.Request) { - if header != nil { - for k, v := range header { - req.Header.Set(k, v) - } - } - }) -} - type FileInfo struct { Path string `json:"path" avro:"path"` Hash string `json:"hash" avro:"hash"` @@ -456,7 +456,11 @@ var fileListSchema = avro.MustParse(`{ }`) func (cr *Cluster) GetFileList(ctx context.Context) (files []FileInfo, err error) { - res, err := cr.queryURL(ctx, "GET", "/openbmclapi/files") + req, err := cr.makeReq(ctx, http.MethodGet, "/openbmclapi/files", nil) + if err != nil { + return + } + res, err := cr.client.Do(req) if err != nil { return } @@ -471,7 +475,7 @@ func (cr *Cluster) GetFileList(ctx context.Context) (files []FileInfo, err error if err != nil { return } - defer zr.Close() // TODO: reuse the decoder? 
+ defer zr.Close() if err = avro.NewDecoderForSchema(fileListSchema, zr).Decode(&files); err != nil { return } @@ -481,12 +485,11 @@ func (cr *Cluster) GetFileList(ctx context.Context) (files []FileInfo, err error type syncStats struct { totalsize float64 downloaded float64 - slots chan []byte fcount atomic.Int32 fl int } -func (cr *Cluster) SyncFiles(ctx context.Context, files []FileInfo) { +func (cr *Cluster) SyncFiles(ctx context.Context, files []FileInfo, heavyCheck bool) { logInfo("Preparing to sync files...") if !cr.issync.CompareAndSwap(false, true) { logWarn("Another sync task is running!") @@ -494,9 +497,9 @@ func (cr *Cluster) SyncFiles(ctx context.Context, files []FileInfo) { } if cr.usedOSS() { - cr.ossSyncFiles(ctx, files) + cr.ossSyncFiles(ctx, files, heavyCheck) } else { - cr.syncFiles(ctx, files) + cr.syncFiles(ctx, files, heavyCheck) } cr.issync.Store(false) @@ -505,8 +508,8 @@ func (cr *Cluster) SyncFiles(ctx context.Context, files []FileInfo) { } // syncFiles download objects to the cache folder -func (cr *Cluster) syncFiles(ctx context.Context, files []FileInfo) error { - missing := cr.CheckFiles(cr.cacheDir, files) +func (cr *Cluster) syncFiles(ctx context.Context, files []FileInfo, heavyCheck bool) error { + missing := cr.CheckFiles(cr.cacheDir, files, heavyCheck) fl := len(missing) if fl == 0 { @@ -518,18 +521,16 @@ func (cr *Cluster) syncFiles(ctx context.Context, files []FileInfo) error { sort.Slice(missing, func(i, j int) bool { return missing[i].Size > missing[j].Size }) var stats syncStats - stats.slots = make(chan []byte, cr.maxConn) stats.fl = fl for _, f := range missing { stats.totalsize += (float64)(f.Size) } - for i := cap(stats.slots); i > 0; i-- { - stats.slots <- make([]byte, 1024*1024) - } logInfof("Starting sync files, count: %d, total: %s", fl, bytesToUnit(stats.totalsize)) start := time.Now() + done := make(chan struct{}, 1) + for _, f := range missing { pathRes, err := cr.fetchFile(ctx, &stats, f) if err != nil { @@ -537,10 +538,16 @@ func (cr *Cluster) syncFiles(ctx context.Context, files []FileInfo) error { return err } go func(f FileInfo) { + defer func() { + select { + case done <- struct{}{}: + case <-ctx.Done(): + } + }() select { case path := <-pathRes: if path != "" { - if err := cr.putFileToCache(path, f); err != nil { + if _, err := cr.putFileToCache(path, f); err != nil { logErrorf("Could not move file %q to cache:\n\t%v", path, err) } } @@ -549,9 +556,9 @@ func (cr *Cluster) syncFiles(ctx context.Context, files []FileInfo) error { } }(f) } - for i := cap(stats.slots); i > 0; i-- { + for i := len(missing); i > 0; i-- { select { - case <-stats.slots: + case <-done: case <-ctx.Done(): logWarn("File sync interrupted") return ctx.Err() @@ -563,30 +570,86 @@ func (cr *Cluster) syncFiles(ctx context.Context, files []FileInfo) error { return nil } -func (cr *Cluster) CheckFiles(dir string, files []FileInfo) (missing []FileInfo) { - logInfof("Start checking files at %q", dir) - usedOSS := cr.usedOSS() +func (cr *Cluster) CheckFiles(dir string, files []FileInfo, heavy bool) (missing []FileInfo) { + logInfof("Start checking files, heavy = %v", heavy) + var hashBuf [64]byte for i, f := range files { p := filepath.Join(dir, hashToFilename(f.Hash)) logDebugf("Checking file %s [%.2f%%]", p, (float32)(i+1)/(float32)(len(files))*100) - if usedOSS && f.Size == 0 { - logDebugf("Skipped empty file %s", p) - continue - } stat, err := os.Stat(p) if err == nil { if sz := stat.Size(); sz != f.Size { - logInfof("Found modified file: size of %q is %s but expect 
%s", - p, bytesToUnit((float64)(sz)), bytesToUnit((float64)(f.Size))) - missing = append(missing, f) + logInfof("Found modified file: size of %q is %d, expect %d", p, sz, f.Size) + goto MISSING } - } else { - logDebugf("Could not found file %q", p) - missing = append(missing, f) - if !errors.Is(err, os.ErrNotExist) { - os.Remove(p) + if heavy { + hashMethod, err := getHashMethod(len(f.Hash)) + if err != nil { + logErrorf("Unknown hash method for %q", f.Hash) + continue + } + hw := hashMethod.New() + + fd, err := os.Open(p) + if err != nil { + logErrorf("Could not open %q: %v", p, err) + goto MISSING + } + defer fd.Close() + if _, err = io.Copy(hw, fd); err != nil { + logErrorf("Could not calculate hash for %q: %v", p, err) + continue + } + if hs := hex.EncodeToString(hw.Sum(hashBuf[:0])); hs != f.Hash { + logInfof("Found modified file: hash of %q is %s, expect %s", p, hs, f.Hash) + goto MISSING + } } + continue } + if config.UseGzip { + p += ".gz" + if _, err := os.Stat(p); err == nil { + if heavy { + hashMethod, err := getHashMethod(len(f.Hash)) + if err != nil { + logErrorf("Unknown hash method for %q", f.Hash) + continue + } + hw := hashMethod.New() + + fd, err := os.Open(p) + if err != nil { + logErrorf("Could not open %q: %v", p, err) + goto MISSING + } + defer fd.Close() + r, err := gzip.NewReader(fd) + if err != nil { + logErrorf("Could not decompress %q: %v", p, err) + goto MISSING + } + var sz int64 + if sz, err = io.Copy(hw, r); err != nil { + logErrorf("Could not calculate hash for %q: %v", p, err) + goto MISSING + } + if sz != f.Size { + logInfof("Found modified file: size of %q is %d, expect %d", p, sz, f.Size) + goto MISSING + } + if hs := hex.EncodeToString(hw.Sum(hashBuf[:0])); hs != f.Hash { + logInfof("Found modified file: hash of %q is %s, expect %s", p, hs, f.Hash) + goto MISSING + } + } + continue + } + } + logDebugf("Could not found file %q", p) + MISSING: + os.Remove(p) + missing = append(missing, f) } logInfo("File check finished") return @@ -641,7 +704,7 @@ func (cr *Cluster) fetchFile(ctx context.Context, stats *syncStats, f FileInfo) WAIT_SLOT: for { select { - case buf = <-stats.slots: + case buf = <-cr.bufSlots: break WAIT_SLOT case <-ctx.Done(): return nil, ctx.Err() @@ -650,7 +713,7 @@ WAIT_SLOT: pathRes := make(chan string, 1) go func() { defer func() { - stats.slots <- buf + cr.bufSlots <- buf }() defer close(pathRes) @@ -684,12 +747,26 @@ WAIT_SLOT: return pathRes, nil } +var noOpenQuery = url.Values{ + "noopen": {"1"}, +} + func (cr *Cluster) fetchFileWithBuf(ctx context.Context, f FileInfo, hashMethod crypto.Hash, buf []byte) (path string, err error) { var ( - res *http.Response - fd *os.File + query url.Values = nil + req *http.Request + res *http.Response + fd *os.File + r io.Reader ) - if res, err = cr.queryURL(ctx, "GET", f.Path); err != nil { + if config.NoOpen { + query = noOpenQuery + } + if req, err = cr.makeReq(ctx, http.MethodGet, f.Path, query); err != nil { + return + } + req.Header.Set("Accept-Encoding", "gzip, deflate") + if res, err = cr.client.Do(req); err != nil { return } defer res.Body.Close() @@ -700,6 +777,18 @@ func (cr *Cluster) fetchFileWithBuf(ctx context.Context, f FileInfo, hashMethod err = fmt.Errorf("Unexpected status code: %d", res.StatusCode) return } + switch strings.ToLower(res.Header.Get("Content-Encoding")) { + case "gzip": + if r, err = gzip.NewReader(res.Body); err != nil { + return + } + case "deflate": + if r, err = zlib.NewReader(res.Body); err != nil { + return + } + default: + r = res.Body + } hw := 
@@ -713,7 +802,7 @@ func (cr *Cluster) fetchFileWithBuf(ctx context.Context, f FileInfo, hashMethod
 		}
 	}(path)
 
-	_, err = io.CopyBuffer(io.MultiWriter(hw, fd), res.Body, buf)
+	_, err = io.CopyBuffer(io.MultiWriter(hw, fd), r, buf)
 	stat, err2 := fd.Stat()
 	fd.Close()
 	if err != nil {
@@ -724,7 +813,7 @@ func (cr *Cluster) fetchFileWithBuf(ctx context.Context, f FileInfo, hashMethod
 		return
 	}
 	if t := stat.Size(); f.Size >= 0 && t != f.Size {
-		err = fmt.Errorf("File size wrong, got %s, expect %s", bytesToUnit((float64)(t)), bytesToUnit((float64)(f.Size)))
+		err = fmt.Errorf("File size wrong, got %d, expect %d", t, f.Size)
 		return
 	} else if hs := hex.EncodeToString(hw.Sum(buf[:0])); hs != f.Hash {
 		err = fmt.Errorf("File hash not match, got %s, expect %s", hs, f.Hash)
@@ -733,14 +822,36 @@ func (cr *Cluster) fetchFileWithBuf(ctx context.Context, f FileInfo, hashMethod
 	return
 }
 
-func (cr *Cluster) putFileToCache(src string, f FileInfo) (err error) {
+func (cr *Cluster) putFileToCache(src string, f FileInfo) (compressed bool, err error) {
+	defer os.Remove(src)
+
 	targetPath := filepath.Join(cr.cacheDir, hashToFilename(f.Hash))
+	if config.UseGzip && f.Size > 1024*10 {
+		var srcFd, dstFd *os.File
+		if srcFd, err = os.Open(src); err != nil {
+			return
+		}
+		defer srcFd.Close()
+		targetPath += ".gz"
+		if dstFd, err = os.OpenFile(targetPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644); err != nil {
+			return
+		}
+		defer dstFd.Close()
+		w := gzip.NewWriter(dstFd)
+		defer w.Close()
+		if _, err = io.Copy(w, srcFd); err != nil {
+			return
+		}
+		return true, nil
+	}
+
 	os.Remove(targetPath) // remove the old file if it exists
 	if err = os.Rename(src, targetPath); err != nil {
 		return
 	}
 	os.Chmod(targetPath, 0644)
 
+	// TODO: support hijack with compressed file
 	if config.Hijack.Enable {
 		if !strings.HasPrefix(f.Path, "/openbmclapi/download/") {
 			target := filepath.Join(config.Hijack.Path, filepath.FromSlash(f.Path))
@@ -751,21 +862,22 @@ func (cr *Cluster) putFileToCache(src string, f FileInfo) (err error) {
 			}
 		}
 	}
-	return
+	return false, nil
 }
 
-func (cr *Cluster) lockDownloading(target string) (chan struct{}, bool) {
+func (cr *Cluster) lockDownloading(target string) (chan error, bool) {
 	cr.downloadMux.Lock()
 	defer cr.downloadMux.Unlock()
+
 	if ch := cr.downloading[target]; ch != nil {
 		return ch, true
 	}
-	ch := make(chan struct{}, 1)
+	ch := make(chan error, 1)
 	cr.downloading[target] = ch
 	return ch, false
 }
 
-func (cr *Cluster) DownloadFile(ctx context.Context, dir string, hash string) (err error) {
+func (cr *Cluster) DownloadFile(ctx context.Context, hash string) (compressed bool, err error) {
 	hashMethod, err := getHashMethod(len(hash))
 	if err != nil {
 		return
@@ -782,19 +894,46 @@ func (cr *Cluster) DownloadFile(ctx context.Context, dir string, hash string) (e
 		Hash: hash,
 		Size: -1,
 	}
-	target := filepath.Join(dir, hashToFilename(hash))
-	done, ok := cr.lockDownloading(target)
+	done, ok := cr.lockDownloading(hash)
 	if ok {
 		select {
-		case <-done:
+		case err = <-done:
 		case <-cr.Disabled():
 		}
 		return
 	}
-	defer close(done)
+	defer func() {
+		done <- err
+	}()
 	path, err := cr.fetchFileWithBuf(ctx, f, hashMethod, buf)
 	if err != nil {
 		return
 	}
-	return copyFile(path, target, 0644)
+	if compressed, err = cr.putFileToCache(path, f); err != nil {
+		return
+	}
+	return
+}
+
+func walkCacheDir(dir string, walker func(path string) (err error)) (err error) {
+	var b [1]byte
+	for i := 0; i < 0x100; i++ {
+		b[0] = (byte)(i)
+		d := filepath.Join(dir, hex.EncodeToString(b[:]))
+		var files []os.DirEntry
+		if files, err = os.ReadDir(d); err != nil {
+			if errors.Is(err, os.ErrNotExist) {
+				continue
+			}
+			return
+		} else {
+			for _, f := range files {
+				p := filepath.Join(d, f.Name())
+				if err = walker(p); err != nil {
+					return
+				}
+			}
+		}
+	}
+	return nil
+}
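Note: the cache keeps one bucket directory per leading hash byte (00 through ff, as created in NewCluster above), and walkCacheDir visits every file across those buckets. A short usage sketch against the signature added here — the counting logic is illustrative only:

// Tally cached objects and total bytes, assuming the 00..ff
// bucket layout that NewCluster creates under the cache directory.
var count, total int64
err := walkCacheDir("cache", func(path string) error {
	stat, err := os.Stat(path)
	if err != nil {
		return err // abort the walk on the first stat failure
	}
	count++
	total += stat.Size()
	return nil
})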
diff --git a/cluster_oss.go b/cluster_oss.go
index 723feca1..45cadf62 100644
--- a/cluster_oss.go
+++ b/cluster_oss.go
@@ -21,6 +21,7 @@ package main
 
 import (
 	"context"
+	"encoding/hex"
 	"errors"
 	"fmt"
 	"io"
@@ -38,21 +39,72 @@ type fileInfoWithTargets struct {
 	targets []string
 }
 
-func (cr *Cluster) ossSyncFiles(ctx context.Context, files []FileInfo) error {
-	missingMap := make(map[string]*fileInfoWithTargets, 4)
-	for _, item := range cr.ossList {
-		dir := filepath.Join(item.FolderPath, "download")
-		need := cr.CheckFiles(dir, files)
-		for _, f := range need {
-			if info := missingMap[f.Hash]; info != nil {
-				info.targets = append(info.targets, dir)
-			} else {
-				missingMap[f.Hash] = &fileInfoWithTargets{
-					FileInfo: f,
-					targets:  []string{dir},
+func (cr *Cluster) CheckFilesOSS(dir string, files []FileInfo, heavy bool, missing map[string]*fileInfoWithTargets) {
+	addMissing := func(f FileInfo) {
+		if info := missing[f.Hash]; info != nil {
+			info.targets = append(info.targets, dir)
+		} else {
+			missing[f.Hash] = &fileInfoWithTargets{
+				FileInfo: f,
+				targets:  []string{dir},
+			}
+		}
+	}
+	logInfof("Start checking files at %q, heavy = %v", dir, heavy)
+	var hashBuf [64]byte
+	for i, f := range files {
+		p := filepath.Join(dir, hashToFilename(f.Hash))
+		logDebugf("Checking file %s [%.2f%%]", p, (float32)(i+1)/(float32)(len(files))*100)
+		if f.Size == 0 {
+			logDebugf("Skipped empty file %s", p)
+			continue
+		}
+		stat, err := os.Stat(p)
+		if err == nil {
+			if sz := stat.Size(); sz != f.Size {
+				logInfof("Found modified file: size of %q is %s, expect %s",
+					p, bytesToUnit((float64)(sz)), bytesToUnit((float64)(f.Size)))
+				goto MISSING
+			}
+			if heavy {
+				hashMethod, err := getHashMethod(len(f.Hash))
+				if err != nil {
+					logErrorf("Unknown hash method for %q", f.Hash)
+					continue
+				}
+				hw := hashMethod.New()
+
+				fd, err := os.Open(p)
+				if err != nil {
+					logErrorf("Could not open %q: %v", p, err)
+					goto MISSING
+				}
+				defer fd.Close()
+				if _, err = io.Copy(hw, fd); err != nil {
+					logErrorf("Could not calculate hash for %q: %v", p, err)
+					continue
+				}
+				if hs := hex.EncodeToString(hw.Sum(hashBuf[:0])); hs != f.Hash {
+					logInfof("Found modified file: hash of %q is %s, expect %s", p, hs, f.Hash)
+					goto MISSING
 				}
 			}
+			continue
 		}
+		logDebugf("Could not find file %q", p)
+	MISSING:
+		os.Remove(p)
+		addMissing(f)
+	}
+	logInfo("File check finished")
+	return
+}
+
+func (cr *Cluster) ossSyncFiles(ctx context.Context, files []FileInfo, heavyCheck bool) error {
+	missingMap := make(map[string]*fileInfoWithTargets, 4)
+	for _, item := range cr.ossList {
+		dir := filepath.Join(item.FolderPath, "download")
+		cr.CheckFilesOSS(dir, files, heavyCheck, missingMap)
 	}
 
 	missing := make([]*fileInfoWithTargets, 0, len(missingMap))
@@ -67,14 +119,10 @@ func (cr *Cluster) ossSyncFiles(ctx context.Context, files []FileInfo) error {
 	}
 
 	var stats syncStats
-	stats.slots = make(chan []byte, cr.maxConn)
 	stats.fl = fl
 	for _, f := range missing {
 		stats.totalsize += (float64)(f.Size)
 	}
-	for i := cap(stats.slots); i > 0; i-- {
-		stats.slots <- make([]byte, 1024*1024)
-	}
 
 	logInfof("Starting sync files, count: %d, total: %s", fl, bytesToUnit(stats.totalsize))
 	start := time.Now()
@@ -90,16 +138,19 @@ func (cr *Cluster) ossSyncFiles(ctx context.Context, files []FileInfo) error {
 		}
 		go func(f 
*fileInfoWithTargets) { defer func() { - done <- struct{}{} + select { + case done <- struct{}{}: + case <-ctx.Done(): + } }() select { case path := <-pathRes: if path != "" { defer os.Remove(path) // acquire slot here - buf := <-stats.slots - defer func(){ - stats.slots <- buf + buf := <-cr.bufSlots + defer func() { + cr.bufSlots <- buf }() var srcFd *os.File if srcFd, err = os.Open(path); err != nil { @@ -131,7 +182,7 @@ func (cr *Cluster) ossSyncFiles(ctx context.Context, files []FileInfo) error { } }(f) } - for i := cap(stats.slots); i > 0; i-- { + for i := len(missing); i > 0; i-- { select { case <-done: case <-ctx.Done(): @@ -259,3 +310,41 @@ func checkOSS(ctx context.Context, client *http.Client, item *OSSItem, size int) logInfof("Check finished for %q, used %v, %s/s; supportRange=%v", target, used, bytesToUnit((float64)(n)/used.Seconds()), supportRange) return } + +func (cr *Cluster) DownloadFileOSS(ctx context.Context, dir string, hash string) (err error) { + hashMethod, err := getHashMethod(len(hash)) + if err != nil { + return + } + + var buf []byte + { + buf0 := bufPool.Get().(*[]byte) + defer bufPool.Put(buf0) + buf = *buf0 + } + f := FileInfo{ + Path: "/openbmclapi/download/" + hash + "?noopen=1", + Hash: hash, + Size: -1, + } + target := filepath.Join(dir, hashToFilename(hash)) + done, ok := cr.lockDownloading(target) + if ok { + select { + case err = <-done: + case <-cr.Disabled(): + } + return + } + defer func() { + done <- err + }() + path, err := cr.fetchFileWithBuf(ctx, f, hashMethod, buf) + if err != nil { + return + } + defer os.Remove(path) + err = copyFile(path, target, 0644) + return +} diff --git a/config.go b/config.go index 5f45c64b..92bb1a8f 100644 --- a/config.go +++ b/config.go @@ -67,8 +67,9 @@ type Config struct { ClusterId string `yaml:"cluster_id"` ClusterSecret string `yaml:"cluster_secret"` SyncInterval int `yaml:"sync_interval"` - DownloadMaxConn int `yaml:"download_max_conn"` ConnectTimeout int `yaml:"connect_timeout"` + DownloadMaxConn int `yaml:"download_max_conn"` + UseGzip bool `yaml:"use_gzip"` ServeLimit ServeLimitConfig `yaml:"serve_limit"` Oss OSSConfig `yaml:"oss"` Hijack HijackConfig `yaml:"hijack_port"` @@ -87,8 +88,9 @@ func readConfig() (config Config) { ClusterId: "${CLUSTER_ID}", ClusterSecret: "${CLUSTER_SECRET}", SyncInterval: 10, - DownloadMaxConn: 64, ConnectTimeout: 10, + DownloadMaxConn: 64, + UseGzip: false, ServeLimit: ServeLimitConfig{ Enable: false, MaxConn: 16384, diff --git a/config.yaml b/config.yaml index 9b03ba55..8bfbb100 100644 --- a/config.yaml +++ b/config.yaml @@ -8,8 +8,9 @@ port: 4000 cluster_id: ${CLUSTER_ID} cluster_secret: ${CLUSTER_SECRET} sync_interval: 10 +connect_timeout: 10 download_max_conn: 64 -connect_timeout: 60 +use_gzip: false serve_limit: enable: false max_conn: 0 diff --git a/handler.go b/handler.go index 76e0b497..381e3be7 100644 --- a/handler.go +++ b/handler.go @@ -20,9 +20,11 @@ package main import ( + "compress/gzip" "crypto" "encoding/hex" "errors" + "fmt" "io" "net" "net/http" @@ -175,6 +177,7 @@ var emptyHashes = func() (hashes map[string]struct{}) { func (cr *Cluster) ServeHTTP(rw http.ResponseWriter, req *http.Request) { method := req.Method u := req.URL + rawpath := u.EscapedPath() switch { case strings.HasPrefix(rawpath, "/download/"): @@ -183,113 +186,34 @@ func (cr *Cluster) ServeHTTP(rw http.ResponseWriter, req *http.Request) { http.Error(rw, "405 Method Not Allowed", http.StatusMethodNotAllowed) return } + hash := rawpath[len("/download/"):] if len(hash) < 4 { - http.Error(rw, "404 
Status Not Found", http.StatusNotFound) + http.Error(rw, "404 Not Found", http.StatusNotFound) return } - name := req.Form.Get("name") if _, ok := emptyHashes[hash]; ok { + name := req.URL.Query().Get("name") rw.Header().Set("Cache-Control", "max-age=2592000") // 30 days rw.Header().Set("Content-Type", "application/octet-stream") + rw.Header().Set("Content-Length", "0") + if name != "" { + rw.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name)) + } rw.Header().Set("X-Bmclapi-Hash", hash) - http.ServeContent(rw, req, name, time.Time{}, NullReader) + rw.WriteHeader(http.StatusOK) cr.hits.Add(1) - // cr.hbts.Add(0) // empty bytes + // cr.hbts.Add(0) // no need to add zero return } - hashFilename := hashToFilename(hash) - // if use OSS redirect if cr.ossList != nil { - logDebug("[handler]: Preparing OSS redirect response") - var err error - forEachSliceFromRandomIndex(len(cr.ossList), func(i int) bool { - item := cr.ossList[i] - logDebugf("[handler]: Checking file on OSS %d at %q ...", i, item.FolderPath) - - if !item.working.Load() { - logDebugf("[handler]: OSS %d is not working", i) - err = errors.New("All OSS server is down") - return false - } - - // check if the file exists - downloadDir := filepath.Join(item.FolderPath, "download") - path := filepath.Join(downloadDir, hashFilename) - var stat os.FileInfo - if stat, err = os.Stat(path); err != nil { - logDebugf("[handler]: Cannot read file on OSS %d: %v", i, err) - if errors.Is(err, os.ErrNotExist) { - if e := cr.DownloadFile(req.Context(), downloadDir, hash); e != nil { - logDebugf("[handler]: Cound not download the file: %v", e) - return false - } - if stat, err = os.Stat(path); err != nil { - return false - } - } else { - return false - } - } - - var target string - target, err = url.JoinPath(item.RedirectBase, "download", hashFilename) - if err != nil { - return false - } - size := stat.Size() - if item.supportRange { // fix the size for Ranged request - rg := req.Header.Get("Range") - rgs, err := gosrc.ParseRange(rg, size) - if err == nil && len(rgs) > 0 { - size = 0 - for _, r := range rgs { - size += r.Length - } - } - } - http.Redirect(rw, req, target, http.StatusFound) - cr.hits.Add(1) - cr.hbts.Add(size) - return true - }) - if err != nil { - logDebugf("[handler]: OSS redirect failed: %v", err) - if errors.Is(err, os.ErrNotExist) { - http.Error(rw, "404 Status Not Found", http.StatusNotFound) - return - } - http.Error(rw, err.Error(), http.StatusInternalServerError) - return - } - logDebug("[handler]: OSS redirect successed") - return - } - - path := filepath.Join(cr.cacheDir, hashFilename) - if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) { - if err := cr.DownloadFile(req.Context(), cr.cacheDir, hash); err != nil { - http.Error(rw, "404 Status Not Found", http.StatusNotFound) - return - } - } - - rw.Header().Set("Cache-Control", "max-age=2592000") // 30 days - fd, err := os.Open(path) - if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) + cr.handleDownloadOSS(rw, req, hash) return } - defer fd.Close() - rw.Header().Set("Content-Type", "application/octet-stream") - rw.Header().Set("X-Bmclapi-Hash", hash) - counter := &countReader{ReadSeeker: fd} - http.ServeContent(rw, req, name, time.Time{}, counter) - cr.hits.Add(1) - cr.hbts.Add(counter.n) + cr.handleDownload(rw, req, hash) return case strings.HasPrefix(rawpath, "/measure/"): if req.Header.Get("x-openbmclapi-secret") != cr.password { @@ -341,3 +265,166 @@ func (cr *Cluster) ServeHTTP(rw http.ResponseWriter, 
req *http.Request) {
 	}
 	http.NotFound(rw, req)
 }
+
+func (cr *Cluster) handleDownload(rw http.ResponseWriter, req *http.Request, hash string) {
+	acceptEncoding := splitCSV(req.Header.Get("Accept-Encoding"))
+	name := req.URL.Query().Get("name")
+	hashFilename := hashToFilename(hash)
+
+	hasGzip := false
+	isGzip := false
+	path := filepath.Join(cr.cacheDir, hashFilename)
+	if config.UseGzip {
+		if _, err := os.Stat(path + ".gz"); err == nil {
+			hasGzip = true
+		}
+	}
+	if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) {
+		if !hasGzip {
+			if hasGzip, err = cr.DownloadFile(req.Context(), hash); err != nil {
+				http.Error(rw, "404 Not Found", http.StatusNotFound)
+				return
+			}
+		}
+		if hasGzip {
+			isGzip = true
+			path += ".gz"
+		}
+	}
+
+	if !isGzip && req.Header.Get("Range") != "" {
+		fd, err := os.Open(path)
+		if err != nil {
+			http.Error(rw, err.Error(), http.StatusInternalServerError)
+			return
+		}
+		defer fd.Close()
+		counter := new(countReader)
+		counter.ReadSeeker = fd
+
+		rw.Header().Set("Cache-Control", "max-age=2592000") // 30 days
+		rw.Header().Set("Content-Type", "application/octet-stream")
+		rw.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name))
+		rw.Header().Set("X-Bmclapi-Hash", hash)
+		http.ServeContent(rw, req, name, time.Time{}, counter)
+		cr.hits.Add(1)
+		cr.hbts.Add(counter.n)
+	} else {
+		var r io.Reader
+		if hasGzip && acceptEncoding["gzip"] != 0 {
+			if !isGzip {
+				path += ".gz"
+			}
+			fd, err := os.Open(path)
+			if err != nil {
+				http.Error(rw, err.Error(), http.StatusInternalServerError)
+				return
+			}
+			defer fd.Close()
+			r = fd
+			rw.Header().Set("Content-Encoding", "gzip")
+		} else {
+			fd, err := os.Open(path)
+			if err != nil {
+				http.Error(rw, err.Error(), http.StatusInternalServerError)
+				return
+			}
+			defer fd.Close()
+			r = fd
+			if isGzip {
+				if r, err = gzip.NewReader(r); err != nil {
+					logErrorf("Could not decompress %q: %v", path, err)
+					http.Error(rw, err.Error(), http.StatusInternalServerError)
+					return
+				}
+			}
+		}
+		rw.Header().Set("Cache-Control", "max-age=2592000") // 30 days
+		rw.Header().Set("Content-Type", "application/octet-stream")
+		rw.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name))
+		rw.Header().Set("X-Bmclapi-Hash", hash)
+		if leng, err := getFileSize(r); err == nil {
+			rw.Header().Set("Content-Length", strconv.FormatInt(leng, 10))
+		}
+		rw.WriteHeader(http.StatusOK)
+		if req.Method != http.MethodHead {
+			var buf []byte
+			{
+				buf0 := bufPool.Get().(*[]byte)
+				defer bufPool.Put(buf0)
+				buf = *buf0
+			}
+			n, _ := io.CopyBuffer(rw, r, buf)
+			cr.hits.Add(1)
+			cr.hbts.Add(n)
+		}
+	}
+}
+
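handleDownload above resolves a request to one of three delivery paths. The branch order reduces to the following helper — illustrative only, not part of the patch; hasGzip means a .gz sibling exists on disk, isGzip means only the compressed copy is present locally:

// servePlan mirrors handleDownload's branch selection.
func servePlan(isGzip, hasGzip, clientAcceptsGzip, ranged bool) string {
	switch {
	case ranged && !isGzip:
		return "http.ServeContent on the plain file (Range honored)"
	case hasGzip && clientAcceptsGzip:
		return "stream the .gz file with Content-Encoding: gzip"
	case isGzip:
		return "gunzip on the fly and serve identity encoding"
	default:
		return "stream the plain file"
	}
}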
+func (cr *Cluster) handleDownloadOSS(rw http.ResponseWriter, req *http.Request, hash string) {
+	logDebug("[handler]: Preparing OSS redirect response")
+
+	hashFilename := hashToFilename(hash)
+
+	var err error
+	forEachSliceFromRandomIndex(len(cr.ossList), func(i int) bool {
+		item := cr.ossList[i]
+		logDebugf("[handler]: Checking file on OSS %d at %q ...", i, item.FolderPath)
+
+		if !item.working.Load() {
+			logDebugf("[handler]: OSS %d is not working", i)
+			err = errors.New("All OSS servers are down")
+			return false
+		}
+
+		// check if the file exists
+		downloadDir := filepath.Join(item.FolderPath, "download")
+		path := filepath.Join(downloadDir, hashFilename)
+		var stat os.FileInfo
+		if stat, err = os.Stat(path); err != nil {
+			logDebugf("[handler]: Cannot read file on OSS %d: %v", i, err)
+			if errors.Is(err, os.ErrNotExist) {
+				if e := cr.DownloadFileOSS(req.Context(), downloadDir, hash); e != nil {
+					logDebugf("[handler]: Could not download the file: %v", e)
+					return false
+				}
+				if stat, err = os.Stat(path); err != nil {
+					return false
+				}
+			} else {
+				return false
+			}
+		}
+
+		var target string
+		target, err = url.JoinPath(item.RedirectBase, "download", hashFilename)
+		if err != nil {
+			return false
+		}
+		size := stat.Size()
+		if item.supportRange { // fix the size for Ranged request
+			rg := req.Header.Get("Range")
+			rgs, err := gosrc.ParseRange(rg, size)
+			if err == nil && len(rgs) > 0 {
+				size = 0
+				for _, r := range rgs {
+					size += r.Length
+				}
+			}
+		}
+		http.Redirect(rw, req, target, http.StatusFound)
+		cr.hits.Add(1)
+		cr.hbts.Add(size)
+		return true
+	})
+	if err != nil {
+		logDebugf("[handler]: OSS redirect failed: %v", err)
+		if errors.Is(err, os.ErrNotExist) {
+			http.Error(rw, "404 Not Found", http.StatusNotFound)
+			return
+		}
+		http.Error(rw, err.Error(), http.StatusInternalServerError)
+		return
+	}
+	logDebug("[handler]: OSS redirect succeeded")
+}
diff --git a/help.go b/help.go
new file mode 100644
index 00000000..6c2db26a
--- /dev/null
+++ b/help.go
@@ -0,0 +1,55 @@
+/**
+ * OpenBmclAPI (Golang Edition)
+ * Copyright (C) 2023 Kevin Z
+ * All rights reserved
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published
+ * by the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */ + +package main + +import ( + "fmt" + "os" +) + +func printHelp() { + fmt.Printf("Usage for %s\n\n", os.Args[0]) + fmt.Println("Sub commands:") + fmt.Println(" help") + fmt.Println(" \t" + "Show this message") + fmt.Println() + fmt.Println(" main | serve | ") + fmt.Println(" \t" + "Execute the main program") + fmt.Println() + fmt.Println(" license") + fmt.Println(" \t" + "Print the full program license") + fmt.Println() + fmt.Println(" zip-cache [options ...]") + fmt.Println(" \t" + "Compress the cache directory") + fmt.Println() + fmt.Println(" Options:") + fmt.Println(" " + "verbose | v : Show compressing files") + fmt.Println(" " + "all | a : Compress all files") + fmt.Println(" " + "overwrite | o : Overwrite compressed file even if it exists") + fmt.Println(" " + "keep | k : Keep uncompressed file") + fmt.Println() + fmt.Println(" unzip-cache") + fmt.Println(" \t" + "Decompress the cache directory") + fmt.Println() + fmt.Println(" Options:") + fmt.Println(" " + "verbose | v : Show decompressing files") + fmt.Println(" " + "overwrite | o : Overwrite uncompressed file even if it exists") + fmt.Println(" " + "keep | k : Keep compressed file") +} diff --git a/logger.go b/logger.go index 8b154457..1e341d04 100644 --- a/logger.go +++ b/logger.go @@ -49,7 +49,7 @@ func logX(x string, args ...any) { buf.WriteString(time.Now().Format(logTimeFormat)) buf.WriteString("]: ") buf.WriteString(c) - fmt.Fprintln(os.Stderr, buf.String()) + fmt.Fprintln(os.Stdout, buf.String()) if fd := logfile.Load(); fd != nil { fd.Write(buf.Bytes()) fd.Write([]byte{'\n'}) @@ -66,7 +66,7 @@ func logXf(x string, format string, args ...any) { buf.WriteString(time.Now().Format(logTimeFormat)) buf.WriteString("]: ") buf.WriteString(c) - fmt.Fprintln(os.Stderr, buf.String()) + fmt.Fprintln(os.Stdout, buf.String()) if fd := logfile.Load(); fd != nil { fd.Write(buf.Bytes()) fd.Write([]byte{'\n'}) diff --git a/main.go b/main.go index 43e21b16..853b7c0b 100644 --- a/main.go +++ b/main.go @@ -20,13 +20,17 @@ package main import ( + "compress/gzip" "context" + "encoding/hex" "errors" "fmt" + "io" "net" "net/http" "os" "os/signal" + "path/filepath" "strings" "sync/atomic" "syscall" @@ -43,23 +47,188 @@ var config Config const baseDir = "." 
-func main() { - printShortLicense() +func parseArgs() { if len(os.Args) > 1 { subcmd := strings.ToLower(os.Args[1]) switch subcmd { + case "main", "serve": + break case "license": printLongLicense() os.Exit(0) + case "zip-cache": + flagVerbose := false + flagAll := false + flagOverwrite := false + flagKeep := false + for _, a := range os.Args[2:] { + switch strings.ToLower(a) { + case "verbose", "v": + flagVerbose = true + case "all", "a": + flagAll = true + case "overwrite", "o": + flagOverwrite = true + case "keep", "k": + flagKeep = true + } + } + cacheDir := filepath.Join(baseDir, "cache") + fmt.Printf("Cache directory = %q\n", cacheDir) + err := walkCacheDir(cacheDir, func(path string) (_ error) { + if strings.HasSuffix(path, ".gz") { + return + } + target := path + ".gz" + if !flagOverwrite { + if _, err := os.Stat(target); err == nil { + return + } + } + srcFd, err := os.Open(path) + if err != nil { + fmt.Printf("Error: could not open file %q: %v\n", path, err) + return + } + defer srcFd.Close() + stat, err := srcFd.Stat() + if err != nil { + fmt.Printf("Error: could not get stat of %q: %v\n", path, err) + return + } + if flagAll || stat.Size() > 1024*10 { + if flagVerbose { + fmt.Printf("compressing %s\n", path) + } + tmpPath := target + ".tmp" + var dstFd *os.File + if dstFd, err = os.OpenFile(tmpPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644); err != nil { + fmt.Printf("Error: could not create %q: %v\n", tmpPath, err) + return + } + defer dstFd.Close() + w := gzip.NewWriter(dstFd) + defer w.Close() + + if _, err = io.Copy(w, srcFd); err != nil { + os.Remove(tmpPath) + fmt.Printf("Error: could not compress %q: %v\n", path, err) + return + } + os.Remove(target) + if err = os.Rename(tmpPath, target); err != nil { + os.Remove(tmpPath) + fmt.Printf("Error: could not rename %q to %q\n", tmpPath, target) + return + } + if !flagKeep { + os.Remove(path) + } + } + return + }) + if err != nil { + fmt.Printf("Could not walk cache directory: %v", err) + os.Exit(1) + } + os.Exit(0) + case "unzip-cache": + flagVerbose := false + flagOverwrite := false + flagKeep := false + for _, a := range os.Args[2:] { + switch strings.ToLower(a) { + case "verbose", "v": + flagVerbose = true + case "overwrite", "o": + flagOverwrite = true + case "keep", "k": + flagKeep = true + } + } + cacheDir := filepath.Join(baseDir, "cache") + fmt.Printf("Cache directory = %q\n", cacheDir) + var hashBuf [64]byte + err := walkCacheDir(cacheDir, func(path string) (_ error) { + target, ok := strings.CutSuffix(path, ".gz") + if !ok { + return + } + + hash := filepath.Base(target) + hashMethod, err := getHashMethod(len(hash)) + if err != nil { + return + } + hw := hashMethod.New() + + if !flagOverwrite { + if _, err := os.Stat(target); err == nil { + return + } + } + srcFd, err := os.Open(path) + if err != nil { + fmt.Printf("Error: could not open file %q: %v\n", path, err) + return + } + defer srcFd.Close() + if flagVerbose { + fmt.Printf("decompressing %s\n", path) + } + tmpPath := target + ".tmp" + var dstFd *os.File + if dstFd, err = os.OpenFile(tmpPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644); err != nil { + fmt.Printf("Error: could not create %q: %v\n", tmpPath, err) + return + } + defer dstFd.Close() + r, err := gzip.NewReader(srcFd) + if err != nil { + fmt.Printf("Error: could not decompress %q: %v\n", path, err) + return + } + if _, err = io.Copy(io.MultiWriter(dstFd, hw), r); err != nil { + os.Remove(tmpPath) + fmt.Printf("Error: could not decompress %q: %v\n", path, err) + return + } + if hs := 
hex.EncodeToString(hw.Sum(hashBuf[:0])); hs != hash { + os.Remove(tmpPath) + fmt.Printf("Error: hash (%s) incorrect for %q. Got %s, want %s\n", hashMethod, path, hs, hash) + return + } + os.Remove(target) + if err = os.Rename(tmpPath, target); err != nil { + os.Remove(tmpPath) + fmt.Printf("Error: could not rename %q to %q\n", tmpPath, target) + return + } + if !flagKeep { + os.Remove(path) + } + return + }) + if err != nil { + fmt.Printf("Could not walk cache directory: %v", err) + os.Exit(1) + } + os.Exit(0) default: fmt.Println("Unknown sub command:", subcmd) - os.Exit(-1) + printHelp() + os.Exit(0x7f) } } +} + +func main() { + printShortLicense() + parseArgs() defer func() { if err := recover(); err != nil { - logError("Panic error:", err) + logError("Panic:", err) panic(err) } }() @@ -241,8 +410,9 @@ START: } os.Exit(1) } - cluster.SyncFiles(ctx, fl) + cluster.SyncFiles(ctx, fl, true) + checkCount := 0 createInterval(ctx, func() { logInfof("Fetching file list") fl, err := cluster.GetFileList(ctx) @@ -250,7 +420,8 @@ START: logError("Cannot query cluster file list:", err) return } - cluster.SyncFiles(ctx, fl) + checkCount = (checkCount + 1) % 10 + cluster.SyncFiles(ctx, fl, checkCount == 0) }, (time.Duration)(config.SyncInterval)*time.Minute) if err := cluster.Enable(ctx); err != nil { diff --git a/util.go b/util.go index dc44e75f..344e9f84 100644 --- a/util.go +++ b/util.go @@ -24,11 +24,13 @@ import ( "crypto" "crypto/x509" "encoding/pem" + "errors" "fmt" "io" "math/rand" "os" "path/filepath" + "strconv" "strings" "sync" "time" @@ -42,6 +44,20 @@ func split(str string, b byte) (l, r string) { return str, "" } +func splitCSV(line string) (values map[string]float32) { + list := strings.Split(line, ",") + values = make(map[string]float32, len(list)) + for _, v := range list { + name, opt := split(strings.ToLower(strings.TrimSpace(v)), ';') + var q float64 = 1 + if v, ok := strings.CutPrefix(opt, "q="); ok { + q, _ = strconv.ParseFloat(v, 32) + } + values[name] = (float32)(q) + } + return +} + func hashToFilename(hash string) string { return filepath.Join(hash[0:2], hash) } @@ -210,3 +226,18 @@ var ( func (nullReader) Read([]byte) (int, error) { return 0, io.EOF } func (nullReader) ReadAt([]byte, int64) (int, error) { return 0, io.EOF } func (nullReader) Seek(int64, int) (int64, error) { return 0, nil } + +var errNotSeeker = errors.New("r is not an io.Seeker") + +func getFileSize(r io.Reader) (n int64, err error) { + if s, ok := r.(io.Seeker); ok { + if n, err = s.Seek(0, io.SeekEnd); err == nil { + if _, err = s.Seek(0, io.SeekStart); err != nil { + return + } + } + } else { + err = errNotSeeker + } + return +}
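For reference, the splitCSV helper added in util.go parses a comma-separated header such as Accept-Encoding into a lowercased name-to-q-value map, with q defaulting to 1; handleDownload then treats a non-zero "gzip" entry as permission to send the pre-compressed copy. A runnable sketch of the expected behavior, assuming splitCSV from util.go is in scope (the main wrapper is illustrative):

package main

import "fmt"

func main() {
	q := splitCSV("gzip;q=0.8, deflate, br;q=0")
	fmt.Println(q["gzip"])    // 0.8
	fmt.Println(q["deflate"]) // 1 (default quality)
	fmt.Println(q["br"])      // 0 (explicitly refused)
}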