Skip to content

Commit

Permalink
Avoid reading decompressed data into memory (trufflesecurity#2196)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahrav authored Dec 14, 2023
1 parent e72fdb6 commit d8cb658
Showing 1 changed file with 7 additions and 23 deletions.
30 changes: 7 additions & 23 deletions pkg/handlers/archive.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,6 @@ func (a *Archive) FromFile(originalCtx logContext.Context, data io.Reader) chan
return archiveChan
}

type decompressorInfo struct {
depth int
reader io.Reader
archiveChan chan []byte
archiver archiver.Decompressor
}

// openArchive takes a reader and extracts the contents up to the maximum depth.
func (a *Archive) openArchive(ctx logContext.Context, depth int, reader io.Reader, archiveChan chan []byte) error {
if depth >= maxDepth {
Expand All @@ -108,11 +101,14 @@ func (a *Archive) openArchive(ctx logContext.Context, depth int, reader io.Reade

switch archive := format.(type) {
case archiver.Decompressor:
info := decompressorInfo{depth: depth, reader: arReader, archiveChan: archiveChan, archiver: archive}

return a.handleDecompressor(ctx, info)
// Decompress tha archive and feed the decompressed data back into the archive handler to extract any nested archives.
compReader, err := archive.OpenReader(arReader)
if err != nil {
return err
}
return a.openArchive(ctx, depth+1, compReader, archiveChan)
case archiver.Extractor:
return archive.Extract(logContext.WithValue(ctx, depthKey, depth+1), reader, nil, a.extractorHandler(archiveChan))
return archive.Extract(logContext.WithValue(ctx, depthKey, depth+1), arReader, nil, a.extractorHandler(archiveChan))
default:
return fmt.Errorf("unknown archive type: %s", format.Name())
}
Expand All @@ -133,18 +129,6 @@ func (a *Archive) handleNonArchiveContent(ctx logContext.Context, reader io.Read
return nil
}

func (a *Archive) handleDecompressor(ctx logContext.Context, info decompressorInfo) error {
compReader, err := info.archiver.OpenReader(info.reader)
if err != nil {
return err
}
fileBytes, err := a.ReadToMax(ctx, compReader)
if err != nil {
return err
}
return a.openArchive(ctx, info.depth+1, bytes.NewReader(fileBytes), info.archiveChan)
}

// IsFiletype returns true if the provided reader is an archive.
func (a *Archive) IsFiletype(_ logContext.Context, reader io.Reader) (io.Reader, bool) {
format, readerB, err := archiver.Identify("", reader)
Expand Down

0 comments on commit d8cb658

Please sign in to comment.