Skip to content

Commit

Permalink
feat: Implement contract whitelist functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
damilolaedwards committed Sep 30, 2024
1 parent 6f14ca7 commit ddc9c39
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 64 deletions.
6 changes: 5 additions & 1 deletion cmd/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,11 @@ func cmdRunStart(cmd *cobra.Command, args []string) error {
if projectConfig.OnChainConfig.Enabled {
cmdLogger.Info("Running Slither on contract at address: ", colors.Green, projectConfig.OnChainConfig.Address, colors.Reset)
} else {
cmdLogger.Info("Running Slither on the target contracts directory: ", colors.Green, projectConfig.TargetContracts.Dir, colors.Reset, ", Excluding paths: ", colors.Red, projectConfig.TargetContracts.ExcludePaths, colors.Reset, "\n")
if len(projectConfig.ContractWhitelist) > 0 {
cmdLogger.Info("Running Slither on the target contracts directory: ", colors.Green, projectConfig.TargetContracts.Dir, colors.Reset, ", Excluding paths: ", colors.Red, projectConfig.TargetContracts.ExcludePaths, colors.Reset, ", Selecting contracts: ", colors.Green, projectConfig.ContractWhitelist, colors.Reset, "\n")
} else {
cmdLogger.Info("Running Slither on the target contracts directory: ", colors.Green, projectConfig.TargetContracts.Dir, colors.Reset, ", Excluding paths: ", colors.Red, projectConfig.TargetContracts.ExcludePaths, colors.Reset, "\n")
}
}

// Parse contracts
Expand Down
3 changes: 3 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ type ProjectConfig struct {
// Name describes the project name.
Name string `json:"name" description:"The project name"`

// ContractWhitelist describes the only contracts that should be included
ContractWhitelist []string `json:"contractWhitelist"`

// TargetContracts describes the directory that holds the contracts to be fuzzed.
TargetContracts DirectoryConfig `json:"targetContracts"`

Expand Down
1 change: 1 addition & 0 deletions config/config_defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ func GetDefaultProjectConfig() (*ProjectConfig, error) {
Dir: "",
ExcludePaths: []string{},
},
ContractWhitelist: []string{},
TestContracts: DirectoryConfig{
Dir: "",
ExcludePaths: []string{},
Expand Down
154 changes: 91 additions & 63 deletions internal/slither/slither.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,39 @@ import (
//go:embed parse_contracts.py
var parseContractsScript string

func runSlitherOnLocal(targetDir string, targetExcludePaths []string, testDir string, testExcludePaths []string, slitherArgs map[string]any, outputFile *os.File) error {
// Run the command
func runSlitherOnLocal(targetDir string, targetExcludePaths []string, testDir string, testExcludePaths []string, slitherArgs map[string]any) (*types.SlitherOutput, error) {
// Create a temporary file to hold the Python script
tmpfile, err := os.CreateTemp("", "script-*.py")
scriptFile, err := os.CreateTemp("", "script-*.py")
if err != nil {
log.Println("Error creating temporary file:", err)
return nil
return nil, fmt.Errorf("error creating temporary file: %v", err)
}
defer func(name string) {
err := os.Remove(name)
if err != nil {
log.Println("Error removing temporary file:", err)
}
}(tmpfile.Name()) // Clean up
}(scriptFile.Name()) // Clean up

// Write the Python script to the temporary file
if _, err := tmpfile.Write([]byte(parseContractsScript)); err != nil {
log.Println("Error writing to temporary file:", err)
return nil
if _, err := scriptFile.Write([]byte(parseContractsScript)); err != nil {
return nil, fmt.Errorf("error writing to temporary file: %v", err)
}
if err := tmpfile.Close(); err != nil {
log.Println("Error closing temporary file:", err)
return nil
if err := scriptFile.Close(); err != nil {
return nil, fmt.Errorf("error closing temporary file: %v", err)
}

// Create a temporary file to hold the slither output
outputFile, err := os.CreateTemp("", "slither-output-*.json")
if err != nil {
return nil, fmt.Errorf("error creating temporary file: %v", err)
}
defer func(name string) {
err := os.Remove(name)
if err != nil {
log.Println("Error removing temporary file:", err)
}
}(outputFile.Name()) // Clean up

args := []string{"--target", ".", "--out", outputFile.Name(),
"--contracts-dir", targetDir}

Expand All @@ -61,46 +69,65 @@ func runSlitherOnLocal(targetDir string, targetExcludePaths []string, testDir st
}

// Prepare the command
cmd := exec.Command("python3", append([]string{tmpfile.Name()}, args...)...)
cmd := exec.Command("python3", append([]string{scriptFile.Name()}, args...)...)

output, err := cmd.CombinedOutput()
if err != nil {
fmt.Printf("error running slither: %v\n", err)
fmt.Printf("stderr: %s\n", output)
return err
return nil, fmt.Errorf("error running slither: %v\n", err)
}

// Read file contents
fileContents, err := os.ReadFile(outputFile.Name())
if err != nil {
return nil, fmt.Errorf("error reading slither output file: %v", err)
}

// Print out slither output
fmt.Println(string(output))

return nil
var slitherOutput types.SlitherOutput
err = json.Unmarshal(fileContents, &slitherOutput)
if err != nil {
return nil, err
}

return &slitherOutput, nil
}

func runSlitherOnchain(address string, networkPrefix string, apiKey string, slitherArgs map[string]any, outputFile *os.File) error {
// Run the command
func runSlitherOnchain(address string, networkPrefix string, apiKey string, slitherArgs map[string]any) (*types.SlitherOutput, error) {
// Create a temporary file to hold the Python script
tmpfile, err := os.CreateTemp("", "script-*.py")
scriptFile, err := os.CreateTemp("", "script-*.py")
if err != nil {
log.Println("Error creating temporary file:", err)
return nil
return nil, fmt.Errorf("error creating temporary file: %v", err)
}
defer func(name string) {
err := os.Remove(name)
if err != nil {
log.Println("Error removing temporary file:", err)
}
}(tmpfile.Name()) // Clean up
}(scriptFile.Name()) // Clean up

// Write the Python script to the temporary file
if _, err := tmpfile.Write([]byte(parseContractsScript)); err != nil {
log.Println("Error writing to temporary file:", err)
return nil
if _, err := scriptFile.Write([]byte(parseContractsScript)); err != nil {
return nil, fmt.Errorf("error writing to temporary file: %v", err)
}
if err := tmpfile.Close(); err != nil {
log.Println("Error closing temporary file:", err)
return nil
if err := scriptFile.Close(); err != nil {
return nil, fmt.Errorf("error closing temporary file: %v", err)
}

// Create a temporary file to hold the slither output
outputFile, err := os.CreateTemp("", "slither-output-*.json")
if err != nil {
return nil, fmt.Errorf("error creating temporary file: %v", err)
}
defer func(name string) {
err := os.Remove(name)
if err != nil {
log.Println("Error removing temporary file:", err)
}
}(outputFile.Name()) // Clean up

args := []string{"--target", address, "--out", outputFile.Name(), "--onchain",
"--network-prefix", networkPrefix, "--api-key", apiKey}

Expand All @@ -109,38 +136,38 @@ func runSlitherOnchain(address string, networkPrefix string, apiKey string, slit
}

// Prepare the command
cmd := exec.Command("python3", append([]string{tmpfile.Name()}, args...)...)
cmd := exec.Command("python3", append([]string{scriptFile.Name()}, args...)...)

output, err := cmd.CombinedOutput()
if err != nil {
fmt.Printf("error running slither: %v\n", err)
fmt.Printf("stderr: %s\n", output)
return err
return nil, fmt.Errorf("error running slither: %v\n", err)
}

// Print out slither output
// Print out output
fmt.Println(string(output))

return nil
}
// Read file contents
fileContents, err := os.ReadFile(outputFile.Name())
if err != nil {
return nil, fmt.Errorf("error reading slither output file: %v", err)
}

func ParseContracts(projectConfig *config.ProjectConfig) ([]types.Contract, string, error) {
var slitherOutput types.SlitherOutput

// Create a temporary file to hold the slither output
tmpfile, err := os.CreateTemp("", "slither-output-*.json")
err = json.Unmarshal(fileContents, &slitherOutput)
if err != nil {
return nil, "", fmt.Errorf("error creating temporary file: %v", err)
return nil, err
}
defer func(name string) {
err := os.Remove(name)
if err != nil {
log.Println("Error removing temporary file:", err)
}
}(tmpfile.Name()) // Clean up

return &slitherOutput, nil
}

func ParseContracts(projectConfig *config.ProjectConfig) ([]types.Contract, string, error) {
var slitherOutput *types.SlitherOutput
var err error

if projectConfig.OnChainConfig.Enabled {
err = runSlitherOnchain(projectConfig.OnChainConfig.Address, projectConfig.OnChainConfig.NetworkPrefix, projectConfig.OnChainConfig.ApiKey, projectConfig.SlitherArgs, tmpfile)
slitherOutput, err = runSlitherOnchain(projectConfig.OnChainConfig.Address, projectConfig.OnChainConfig.NetworkPrefix, projectConfig.OnChainConfig.ApiKey, projectConfig.SlitherArgs)
if err != nil {
return nil, "", err
}
Expand All @@ -151,29 +178,17 @@ func ParseContracts(projectConfig *config.ProjectConfig) ([]types.Contract, stri
return nil, "", fmt.Errorf("unable to read directory")
}

err = runSlitherOnLocal(projectConfig.TargetContracts.Dir, projectConfig.TargetContracts.ExcludePaths, projectConfig.TestContracts.Dir, projectConfig.TestContracts.ExcludePaths, projectConfig.SlitherArgs, tmpfile)
slitherOutput, err = runSlitherOnLocal(projectConfig.TargetContracts.Dir, projectConfig.TargetContracts.ExcludePaths, projectConfig.TestContracts.Dir, projectConfig.TestContracts.ExcludePaths, projectConfig.SlitherArgs)
if err != nil {
return nil, "", err
}
}

// Read the contents of the temporary file
file, err := os.ReadFile(tmpfile.Name())
if err != nil {
return nil, "", err
}

// Parse the slither output
err = json.Unmarshal(file, &slitherOutput)
if err != nil {
return nil, "", err
}

var filteredContracts []types.Contract
if projectConfig.OnChainConfig.Enabled {
filteredContracts = filterSlitherOutput(slitherOutput.Contracts, !projectConfig.OnChainConfig.ExcludeInterfaces, true, true)
filteredContracts = filterSlitherOutput(slitherOutput.Contracts, projectConfig.ContractWhitelist, !projectConfig.OnChainConfig.ExcludeInterfaces, true, true)
} else {
filteredContracts = filterSlitherOutput(slitherOutput.Contracts, projectConfig.IncludeInterfaces, projectConfig.IncludeAbstract, projectConfig.IncludeLibraries)
filteredContracts = filterSlitherOutput(slitherOutput.Contracts, projectConfig.ContractWhitelist, projectConfig.IncludeInterfaces, projectConfig.IncludeAbstract, projectConfig.IncludeLibraries)
}

contractCodes := getContractCodes(slitherOutput.Contracts)
Expand All @@ -192,10 +207,23 @@ func getContractCodes(contracts []types.SlitherContract) string {
return contractCodes.String()
}

func filterSlitherOutput(slitherContracts []types.SlitherContract, includeInterfaces bool, includeAbstract bool, includeLibraries bool) []types.Contract {
func filterSlitherOutput(slitherContracts []types.SlitherContract, whitelist []string, includeInterfaces bool, includeAbstract bool, includeLibraries bool) []types.Contract {
var filteredContracts []types.Contract
var whitelistMap map[string]bool

if len(whitelist) > 0 {
whitelistMap = make(map[string]bool)
for _, s := range whitelist {
whitelistMap[s] = true
}
}

for _, slitherContract := range slitherContracts {
// Skip contracts not in whitelist
if len(whitelist) > 0 && !whitelistMap[slitherContract.Name] {
continue
}

if !includeInterfaces && slitherContract.IsInterface {
continue
}
Expand All @@ -210,7 +238,7 @@ func filterSlitherOutput(slitherContracts []types.SlitherContract, includeInterf
ID: slitherContract.ID,
Name: slitherContract.Name,
Functions: slitherContract.Functions,
InheritedContracts: filterSlitherOutput(slitherContract.InheritedContracts, includeInterfaces, includeAbstract, includeLibraries),
InheritedContracts: filterSlitherOutput(slitherContract.InheritedContracts, []string{}, includeInterfaces, includeAbstract, includeLibraries),
IsAbstract: slitherContract.IsAbstract,
IsInterface: slitherContract.IsAbstract,
IsLibrary: slitherContract.IsAbstract,
Expand Down
9 changes: 9 additions & 0 deletions utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,12 @@ func MapToDictString(inputMap map[string]any) string {
sb.WriteString("}")
return sb.String()
}

// SliceContains returns whether
func SliceContains(arr []string, target string) bool {
m := make(map[string]bool)
for _, s := range arr {
m[s] = true
}
return m[target]
}

0 comments on commit ddc9c39

Please sign in to comment.