-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit cac6d65
Showing
32 changed files
with
1,917 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
assistant |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
package cmd | ||
|
||
const ( | ||
// DefaultProjectConfigFilename describes the default config filename for a given project folder. | ||
DefaultProjectConfigFilename = "assistant.json" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
package cmd | ||
|
||
import ( | ||
"assistant/config" | ||
"assistant/llm" | ||
"assistant/logging/colors" | ||
"assistant/utils" | ||
"fmt" | ||
"github.com/spf13/cobra" | ||
"os" | ||
"path/filepath" | ||
"time" | ||
) | ||
|
||
var generateCmd = &cobra.Command{ | ||
Use: "generate", | ||
Short: "Generate invariants for Medusa", | ||
Long: `Generate invariants for Medusa`, | ||
Args: cmdValidateGenerateArgs, | ||
RunE: cmdRunGenerate, | ||
SilenceUsage: true, | ||
SilenceErrors: false, | ||
} | ||
|
||
func init() { | ||
// Add all the flags allowed for the generate command | ||
err := addGenerateFlags() | ||
if err != nil { | ||
cmdLogger.Panic("Failed to initialize the fuzz command", err) | ||
} | ||
|
||
// Add the generate command and its associated flags to the root command | ||
rootCmd.AddCommand(generateCmd) | ||
} | ||
|
||
// cmdValidateGenerateArgs makes sure that there are no positional arguments provided to the generate command | ||
func cmdValidateGenerateArgs(cmd *cobra.Command, args []string) error { | ||
// Make sure we have no positional args | ||
if err := cobra.NoArgs(cmd, args); err != nil { | ||
err = fmt.Errorf("generate does not accept any positional arguments, only flags and their associated values") | ||
cmdLogger.Error("Failed to validate args to the generate command", err) | ||
return err | ||
} | ||
return nil | ||
} | ||
|
||
// cmdRunGenerate runs the generate CLI command | ||
func cmdRunGenerate(cmd *cobra.Command, args []string) error { | ||
var projectConfig *config.ProjectConfig | ||
|
||
// Check to see if --config flag was used and store the value of --config flag | ||
configFlagUsed := cmd.Flags().Changed("config") | ||
configPath, err := cmd.Flags().GetString("config") | ||
if err != nil { | ||
cmdLogger.Error("Failed to run the generate command", err) | ||
return err | ||
} | ||
|
||
// If --config was not used, look for `assistant.json` in the current work directory | ||
if !configFlagUsed { | ||
workingDirectory, err := os.Getwd() | ||
if err != nil { | ||
cmdLogger.Error("Failed to run the generate command", err) | ||
return err | ||
} | ||
configPath = filepath.Join(workingDirectory, DefaultProjectConfigFilename) | ||
} | ||
|
||
// Check to see if the file exists at configPath | ||
_, existenceError := os.Stat(configPath) | ||
|
||
// Possibility #1: File was found | ||
if existenceError == nil { | ||
// Try to read the configuration file and throw an error if something goes wrong | ||
cmdLogger.Info("Reading the configuration file at: ", colors.Bold, configPath, colors.Reset) | ||
projectConfig, err = config.ReadProjectConfigFromFile(configPath) | ||
if err != nil { | ||
cmdLogger.Error("Failed to run the generate command", err) | ||
return err | ||
} | ||
} | ||
|
||
// Possibility #2: If the --config flag was used, and we couldn't find the file, we'll throw an error | ||
if configFlagUsed && existenceError != nil { | ||
cmdLogger.Error("Failed to run the generate command", err) | ||
return existenceError | ||
} | ||
|
||
// Possibility #3: --config flag was not used and assistant.json was not found, so use the default project config | ||
if !configFlagUsed && existenceError != nil { | ||
cmdLogger.Warn(fmt.Sprintf("Unable to find the config file at %v, will use the default project configuration", configPath)) | ||
|
||
projectConfig, err = config.GetDefaultProjectConfig() | ||
if err != nil { | ||
cmdLogger.Error("Failed to run the generate command", err) | ||
return err | ||
} | ||
} | ||
|
||
// Update the project configuration given whatever flags were set using the CLI | ||
err = updateProjectConfigWithGenerateFlags(cmd, projectConfig) | ||
if err != nil { | ||
cmdLogger.Error("Failed to run the generate command", err) | ||
return err | ||
} | ||
|
||
// Validate project config | ||
err = projectConfig.Validate() | ||
if err != nil { | ||
cmdLogger.Error("Failed to run the generate command", err) | ||
return err | ||
} | ||
|
||
targetContracts, err := utils.ReadDirectoryContents(projectConfig.TargetContracts.Dir, projectConfig.TargetContracts.ExcludePaths...) | ||
if err != nil { | ||
cmdLogger.Error("Failed to run the generate command", err) | ||
return err | ||
} | ||
|
||
var fuzzTests string | ||
if projectConfig.FuzzTests.Dir != "" { | ||
fuzzTests, err = utils.ReadDirectoryContents(projectConfig.FuzzTests.Dir, projectConfig.FuzzTests.ExcludePaths...) | ||
if err != nil { | ||
cmdLogger.Error("Failed to run the generate command", err) | ||
return err | ||
} | ||
} | ||
|
||
var unitTests string | ||
if projectConfig.UnitTests.Dir != "" { | ||
unitTests, err = utils.ReadDirectoryContents(projectConfig.UnitTests.Dir, projectConfig.UnitTests.ExcludePaths...) | ||
if err != nil { | ||
cmdLogger.Error("Failed to run the generate command", err) | ||
return err | ||
} | ||
} | ||
|
||
var parsedCoverageReport utils.CoverageReport | ||
if projectConfig.CoverageReportFile != "" { | ||
_, err := os.Stat(projectConfig.CoverageReportFile) | ||
if err != nil { | ||
cmdLogger.Error("Failed to run the generate command", err) | ||
return err | ||
} | ||
|
||
coverageReport, err := os.ReadFile(projectConfig.CoverageReportFile) | ||
if err != nil { | ||
cmdLogger.Error("Failed to run the generate command", err) | ||
return err | ||
} | ||
|
||
parsedCoverageReport, err = utils.ParseCoverageReportHTML(string(coverageReport)) | ||
if err != nil { | ||
cmdLogger.Error("Failed to run the generate command", err) | ||
return err | ||
} | ||
|
||
// Exclude reports of files we do not need | ||
includePaths := []string{projectConfig.TargetContracts.Dir} | ||
excludePaths := projectConfig.TargetContracts.ExcludePaths | ||
|
||
if projectConfig.FuzzTests.Dir != "" { | ||
includePaths = append(includePaths, projectConfig.FuzzTests.Dir) | ||
excludePaths = append(excludePaths, projectConfig.FuzzTests.ExcludePaths...) | ||
} | ||
|
||
utils.FilterCoverageFiles(&parsedCoverageReport, includePaths, excludePaths) | ||
} | ||
|
||
invariants, err := llm.AskGPT4Turbo(append(llm.TrainingPrompts(), llm.Message{ | ||
Role: "user", | ||
Content: llm.GenerateInvariantsPrompt(targetContracts, fuzzTests, unitTests, fmt.Sprintf("%v", parsedCoverageReport)), | ||
})) | ||
if err != nil { | ||
cmdLogger.Error("Failed to run the generate command", err) | ||
return err | ||
} | ||
|
||
invariants = fmt.Sprintf("=============== Invariants generated at %v ===============\n\n%v\n\n", time.Now().String(), invariants) | ||
|
||
// Open the out file in append mode, create it if it doesn't exist | ||
file, err := os.OpenFile(projectConfig.Out, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
defer func(file *os.File) { | ||
err := file.Close() | ||
if err != nil { | ||
cmdLogger.Error("Error closing "+projectConfig.Out, err) | ||
} | ||
}(file) | ||
|
||
// Write to the file | ||
_, err = file.WriteString(invariants) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
package cmd | ||
|
||
import ( | ||
"assistant/config" | ||
"fmt" | ||
|
||
"github.com/spf13/cobra" | ||
) | ||
|
||
// addGenerateFlags adds the various flags for the generate command | ||
func addGenerateFlags() error { | ||
// Get the default project config and throw an error if we cant | ||
defaultConfig, err := config.GetDefaultProjectConfig() | ||
if err != nil { | ||
return err | ||
} | ||
|
||
// Prevent alphabetical sorting of usage message | ||
generateCmd.Flags().SortFlags = false | ||
|
||
// Config file | ||
generateCmd.Flags().String("config", "", "path to config file") | ||
|
||
// Flags | ||
generateCmd.Flags().String("out", "", | ||
fmt.Sprintf("path to output directory (unless a config file is provided, default is %q)", defaultConfig.Out)) | ||
generateCmd.Flags().String("target-contracts-dir", "", | ||
fmt.Sprintf("directory path for target contracts (unless a config file is provided, default is %q)", defaultConfig.TargetContracts.Dir)) | ||
generateCmd.Flags().String("fuzz-tests-dir", "", fmt.Sprintf("directory path for fuzz tests (unless a config file is provided, default is %q)", defaultConfig.FuzzTests.Dir)) | ||
generateCmd.Flags().String("unit-tests-dir", "", | ||
fmt.Sprintf("directory path for unit tests (unless a config file is provided, default is %q)", defaultConfig.UnitTests.Dir)) | ||
generateCmd.Flags().String("coverage-report-file", "", | ||
fmt.Sprintf("directory path for coverage report (unless a config file is provided, default is %q)", defaultConfig.CoverageReportFile)) | ||
|
||
return nil | ||
} | ||
|
||
// updateProjectConfigWithGenerateFlags will update the given projectConfig with any CLI arguments that were provided to the generate command | ||
func updateProjectConfigWithGenerateFlags(cmd *cobra.Command, projectConfig *config.ProjectConfig) error { | ||
var err error | ||
|
||
// Update output path | ||
if cmd.Flags().Changed("out") { | ||
projectConfig.Out, err = cmd.Flags().GetString("out") | ||
if err != nil { | ||
return err | ||
} | ||
} | ||
|
||
// Update target contracts directory | ||
if cmd.Flags().Changed("target-contracts-dir") { | ||
projectConfig.TargetContracts.Dir, err = cmd.Flags().GetString("target-contracts-dir") | ||
if err != nil { | ||
return err | ||
} | ||
} | ||
|
||
// Update fuzz tests directory | ||
if cmd.Flags().Changed("fuzz-tests-dir") { | ||
projectConfig.FuzzTests.Dir, err = cmd.Flags().GetString("fuzz-tests-dir") | ||
if err != nil { | ||
return err | ||
} | ||
} | ||
|
||
// Update unit tests directory | ||
if cmd.Flags().Changed("unit-tests-dir") { | ||
projectConfig.UnitTests.Dir, err = cmd.Flags().GetString("unit-tests-dir") | ||
if err != nil { | ||
return err | ||
} | ||
} | ||
|
||
// Update coverage report file | ||
if cmd.Flags().Changed("coverage-report-file") { | ||
projectConfig.CoverageReportFile, err = cmd.Flags().GetString("coverage-report-file") | ||
if err != nil { | ||
return err | ||
} | ||
} | ||
|
||
return nil | ||
} |
Oops, something went wrong.