diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 84b22ac4..5324fc75 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -63,7 +63,7 @@ jobs: - uses: golangci/golangci-lint-action@971e284b6050e8a5849b72094c50ab08da042db8 # v6.1.1 if: matrix.os.name == 'linux' with: - version: v1.63.1 + version: v1.63.4 - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 with: name: regal-${{ matrix.os.name }} diff --git a/internal/lsp/cache/cache.go b/internal/lsp/cache/cache.go index 5e22b2e4..0abd48aa 100644 --- a/internal/lsp/cache/cache.go +++ b/internal/lsp/cache/cache.go @@ -106,6 +106,20 @@ func (c *Cache) SetModule(fileURI string, module *ast.Module) { c.modules.Set(fileURI, module) } +func (c *Cache) GetContentAndModule(fileURI string) (string, *ast.Module, bool) { + content, ok := c.GetFileContents(fileURI) + if !ok { + return "", nil, false + } + + module, ok := c.GetModule(fileURI) + if !ok { + return "", nil, false + } + + return content, module, true +} + func (c *Cache) Rename(oldKey, newKey string) { if content, ok := c.fileContents.Get(oldKey); ok { c.fileContents.Set(newKey, content) diff --git a/internal/lsp/handler/handler.go b/internal/lsp/handler/handler.go new file mode 100644 index 00000000..6c67142a --- /dev/null +++ b/internal/lsp/handler/handler.go @@ -0,0 +1,44 @@ +package handler + +import ( + "context" + + "github.com/anderseknert/roast/pkg/encoding" + "github.com/sourcegraph/jsonrpc2" +) + +type handlerFunc[T any] func(T) (any, error) + +type handlerContextFunc[T any] func(context.Context, T) (any, error) + +var ErrInvalidParams = &jsonrpc2.Error{Code: jsonrpc2.CodeInvalidParams} + +func Decode[T any](req *jsonrpc2.Request, params *T) error { + if req.Params == nil { + return ErrInvalidParams + } + + if err := encoding.JSON().Unmarshal(*req.Params, ¶ms); err != nil { + return &jsonrpc2.Error{Code: jsonrpc2.CodeInvalidParams, Message: err.Error()} + } + + return nil +} + +func WithParams[T any](req *jsonrpc2.Request, h handlerFunc[T]) (any, error) { + var params T + if err := Decode(req, ¶ms); err != nil { + return nil, err + } + + return h(params) +} + +func WithContextAndParams[T any](ctx context.Context, req *jsonrpc2.Request, h handlerContextFunc[T]) (any, error) { + var params T + if err := Decode(req, ¶ms); err != nil { + return nil, err + } + + return h(ctx, params) +} diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 4edac704..3663480d 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -1,4 +1,4 @@ -//nolint:nilerr,nilnil +//nolint:nilerr,nilnil,gochecknoglobals package lsp import ( @@ -36,6 +36,7 @@ import ( "github.com/styrainc/regal/internal/lsp/completions/providers" lsconfig "github.com/styrainc/regal/internal/lsp/config" "github.com/styrainc/regal/internal/lsp/examples" + "github.com/styrainc/regal/internal/lsp/handler" "github.com/styrainc/regal/internal/lsp/hover" "github.com/styrainc/regal/internal/lsp/log" "github.com/styrainc/regal/internal/lsp/opa/oracle" @@ -68,6 +69,19 @@ const ( ruleNameDirectoryPackageMismatch = "directory-package-mismatch" ) +var ( + noCodeActions = make([]types.CodeAction, 0) + noDocumentSymbols = make([]types.DocumentSymbol, 0) + noCompletionItems = make([]types.CompletionItem, 0) + noFoldingRanges = make([]types.FoldingRange, 0) + noDiagnostics = make([]types.Diagnostic, 0) + + trueValue = true + truePtr = &trueValue + + orc = oracle.New() +) + type LanguageServerOptions struct { // LogWriter is the io.Writer where all logged messages will be written. LogWriter io.Writer @@ -121,7 +135,7 @@ type LanguageServer struct { configWatcher *lsconfig.Watcher loadedConfig *config.Config // this is also used to lock the updates to the cache of enabled rules - loadedConfigLock sync.Mutex + loadedConfigLock sync.RWMutex loadedConfigEnabledNonAggregateRules []string loadedConfigEnabledAggregateRules []string loadedConfigAllRegoVersions *concurrent.Map[string, ast.RegoVersion] @@ -165,65 +179,62 @@ type lintWorkspaceJob struct { AggregateReportOnly bool } -func (l *LanguageServer) Handle( - ctx context.Context, - conn *jsonrpc2.Conn, - req *jsonrpc2.Request, -) (result any, err error) { +//nolint:wrapcheck +func (l *LanguageServer) Handle(ctx context.Context, _ *jsonrpc2.Conn, req *jsonrpc2.Request) (any, error) { l.logf(log.LevelDebug, "received request: %s", req.Method) // null params are allowed, but only for certain methods if req.Params == nil && req.Method != "shutdown" && req.Method != "exit" { - return nil, &jsonrpc2.Error{Code: jsonrpc2.CodeInvalidParams} + return nil, handler.ErrInvalidParams } switch req.Method { case "initialize": - return l.handleInitialize(ctx, conn, req) + return handler.WithContextAndParams(ctx, req, l.handleInitialize) case "initialized": - return l.handleInitialized(ctx, conn, req) + return l.handleInitialized() case "textDocument/codeAction": - return l.handleTextDocumentCodeAction(ctx, conn, req) + return handler.WithParams(req, l.handleTextDocumentCodeAction) case "textDocument/definition": - return l.handleTextDocumentDefinition(ctx, conn, req) + return handler.WithParams(req, l.handleTextDocumentDefinition) case "textDocument/diagnostic": - return l.handleTextDocumentDiagnostic(ctx, conn, req) + return l.handleTextDocumentDiagnostic() case "textDocument/didOpen": - return l.handleTextDocumentDidOpen(ctx, conn, req) + return handler.WithParams(req, l.handleTextDocumentDidOpen) case "textDocument/didClose": - return l.handleTextDocumentDidClose(ctx, conn, req) + return handler.WithParams(req, l.handleTextDocumentDidClose) case "textDocument/didSave": - return l.handleTextDocumentDidSave(ctx, conn, req) + return handler.WithContextAndParams(ctx, req, l.handleTextDocumentDidSave) case "textDocument/documentSymbol": - return l.handleTextDocumentDocumentSymbol(ctx, conn, req) + return handler.WithParams(req, l.handleTextDocumentDocumentSymbol) case "textDocument/didChange": - return l.handleTextDocumentDidChange(ctx, conn, req) + return handler.WithParams(req, l.handleTextDocumentDidChange) case "textDocument/foldingRange": - return l.handleTextDocumentFoldingRange(ctx, conn, req) + return handler.WithParams(req, l.handleTextDocumentFoldingRange) case "textDocument/formatting": - return l.handleTextDocumentFormatting(ctx, conn, req) + return handler.WithContextAndParams(ctx, req, l.handleTextDocumentFormatting) case "textDocument/hover": - return l.handleTextDocumentHover(ctx, conn, req) + return handler.WithParams(req, l.handleTextDocumentHover) case "textDocument/inlayHint": - return l.handleTextDocumentInlayHint(ctx, conn, req) + return handler.WithParams(req, l.handleTextDocumentInlayHint) case "textDocument/codeLens": - return l.handleTextDocumentCodeLens(ctx, conn, req) + return handler.WithContextAndParams(ctx, req, l.handleTextDocumentCodeLens) case "textDocument/completion": - return l.handleTextDocumentCompletion(ctx, conn, req) + return handler.WithContextAndParams(ctx, req, l.handleTextDocumentCompletion) case "workspace/didChangeWatchedFiles": - return l.handleWorkspaceDidChangeWatchedFiles(ctx, conn, req) + return handler.WithParams(req, l.handleWorkspaceDidChangeWatchedFiles) case "workspace/diagnostic": - return l.handleWorkspaceDiagnostic(ctx, conn, req) + return l.handleWorkspaceDiagnostic() case "workspace/didRenameFiles": - return l.handleWorkspaceDidRenameFiles(ctx, conn, req) + return handler.WithContextAndParams(ctx, req, l.handleWorkspaceDidRenameFiles) case "workspace/didDeleteFiles": - return l.handleWorkspaceDidDeleteFiles(ctx, conn, req) + return handler.WithContextAndParams(ctx, req, l.handleWorkspaceDidDeleteFiles) case "workspace/didCreateFiles": - return l.handleWorkspaceDidCreateFiles(ctx, conn, req) + return handler.WithParams(req, l.handleWorkspaceDidCreateFiles) case "workspace/executeCommand": - return l.handleWorkspaceExecuteCommand(ctx, conn, req) + return handler.WithParams(req, l.handleWorkspaceExecuteCommand) case "workspace/symbol": - return l.handleWorkspaceSymbol(ctx, conn, req) + return l.handleWorkspaceSymbol() case "shutdown": // no-op as we wait for the exit signal before closing channel return struct{}{}, nil @@ -466,22 +477,22 @@ func (l *LanguageServer) StartHoverWorker(ctx context.Context) { } func (l *LanguageServer) getLoadedConfig() *config.Config { - l.loadedConfigLock.Lock() - defer l.loadedConfigLock.Unlock() + l.loadedConfigLock.RLock() + defer l.loadedConfigLock.RUnlock() return l.loadedConfig } func (l *LanguageServer) getEnabledNonAggregateRules() []string { - l.loadedConfigLock.Lock() - defer l.loadedConfigLock.Unlock() + l.loadedConfigLock.RLock() + defer l.loadedConfigLock.RUnlock() return l.loadedConfigEnabledNonAggregateRules } func (l *LanguageServer) getEnabledAggregateRules() []string { - l.loadedConfigLock.Lock() - defer l.loadedConfigLock.Unlock() + l.loadedConfigLock.RLock() + defer l.loadedConfigLock.RUnlock() return l.loadedConfigEnabledAggregateRules } @@ -490,19 +501,21 @@ func (l *LanguageServer) getEnabledAggregateRules() []string { // config. These take some time to compute and only change when config changes, // so we can store them on the server to speed up diagnostic runs. func (l *LanguageServer) loadEnabledRulesFromConfig(ctx context.Context, cfg config.Config) error { - l.loadedConfigLock.Lock() - defer l.loadedConfigLock.Unlock() + lint := linter.NewLinter().WithUserConfig(cfg) - enabledRules, err := linter.NewLinter().WithUserConfig(cfg).DetermineEnabledRules(ctx) + enabledRules, err := lint.DetermineEnabledRules(ctx) if err != nil { return fmt.Errorf("failed to determine enabled rules: %w", err) } - enabledAggregateRules, err := linter.NewLinter().WithUserConfig(cfg).DetermineEnabledAggregateRules(ctx) + enabledAggregateRules, err := lint.DetermineEnabledAggregateRules(ctx) if err != nil { return fmt.Errorf("failed to determine enabled aggregate rules: %w", err) } + l.loadedConfigLock.Lock() + defer l.loadedConfigLock.Unlock() + l.loadedConfigEnabledNonAggregateRules = []string{} for _, r := range enabledRules { @@ -607,7 +620,6 @@ func (l *LanguageServer) StartConfigWorker(ctx context.Context) { contents, ok := l.cache.GetFileContents(k) if ok { l.cache.Delete(k) - l.cache.SetIgnoredFileContents(k, contents) } @@ -802,16 +814,9 @@ func (l *LanguageServer) StartCommandWorker(ctx context.Context) { //nolint:main break } - currentModule, ok := l.cache.GetModule(file) - if !ok { - l.logf(log.LevelMessage, "failed to get module for file %q", file) - - break - } - - currentContents, ok := l.cache.GetFileContents(file) + currentContents, currentModule, ok := l.cache.GetContentAndModule(file) if !ok { - l.logf(log.LevelMessage, "failed to get contents for file %q", file) + l.logf(log.LevelMessage, "failed to get content or module for file %q", file) break } @@ -1292,19 +1297,8 @@ func (l *LanguageServer) fixRenameParams( }, errors.New("failed to find fixed file's old location") } - changes := make([]any, 0) - oldURI := uri.FromPath(l.clientIdentifier, oldFile) newURI := uri.FromPath(l.clientIdentifier, fixedFile) - changes = append(changes, types.RenameFile{ - Kind: "rename", - OldURI: oldURI, - NewURI: newURI, - Options: &types.RenameFileOptions{ - Overwrite: false, - IgnoreIfExists: false, - }, - }) // are there old dirs? dirs, err := util.DirCleanUpPaths( @@ -1320,6 +1314,17 @@ func (l *LanguageServer) fixRenameParams( return types.ApplyWorkspaceAnyEditParams{}, fmt.Errorf("failed to determine empty directories post rename: %w", err) } + changes := make([]any, 0, len(dirs)+1) + changes = append(changes, types.RenameFile{ + Kind: "rename", + OldURI: oldURI, + NewURI: newURI, + Options: &types.RenameFileOptions{ + Overwrite: false, + IgnoreIfExists: false, + }, + }) + for _, dir := range dirs { changes = append( changes, @@ -1362,20 +1367,17 @@ func (l *LanguageServer) processHoverContentUpdate(ctx context.Context, fileURI bis := l.builtinsForCurrentCapabilities() - success, err := updateParse(ctx, l.cache, l.regoStore, fileURI, bis) - if err != nil { + if success, err := updateParse(ctx, l.cache, l.regoStore, fileURI, bis); err != nil { return fmt.Errorf("failed to update parse: %w", err) - } - - if !success { + } else if !success { return nil } - if err = hover.UpdateBuiltinPositions(l.cache, fileURI, bis); err != nil { + if err := hover.UpdateBuiltinPositions(l.cache, fileURI, bis); err != nil { return fmt.Errorf("failed to update builtin positions: %w", err) } - if err = hover.UpdateKeywordLocations(ctx, l.cache, fileURI); err != nil { + if err := hover.UpdateKeywordLocations(ctx, l.cache, fileURI); err != nil { return fmt.Errorf("failed to update keyword locations: %w", err) } @@ -1401,19 +1403,7 @@ type HoverResponse struct { Range types.Range `json:"range"` } -func (l *LanguageServer) handleTextDocumentHover( - _ context.Context, - _ *jsonrpc2.Conn, - req *jsonrpc2.Request, -) (result any, err error) { - var params types.TextDocumentHoverParams - - json := encoding.JSON() - - if err := json.Unmarshal(*req.Params, ¶ms); err != nil { - return nil, fmt.Errorf("failed to unmarshal params: %w", err) - } - +func (l *LanguageServer) handleTextDocumentHover(params types.TextDocumentHoverParams) (any, error) { if l.ignoreURI(params.TextDocument.URI) { return nil, nil } @@ -1521,23 +1511,13 @@ func (l *LanguageServer) handleTextDocumentHover( return nil, nil } -func (l *LanguageServer) handleTextDocumentCodeAction( - _ context.Context, - _ *jsonrpc2.Conn, - req *jsonrpc2.Request, -) (result any, err error) { - var params types.CodeActionParams - if err := encoding.JSON().Unmarshal(*req.Params, ¶ms); err != nil { - return nil, fmt.Errorf("failed to unmarshal params: %w", err) - } - - yes := true - actions := make([]types.CodeAction, 0) - +func (l *LanguageServer) handleTextDocumentCodeAction(params types.CodeActionParams) (any, error) { if l.ignoreURI(params.TextDocument.URI) { - return actions, nil + return noCodeActions, nil } + actions := []types.CodeAction{} + // only VS Code has the capability to open a provided URL, as far as we know // if we learn about others with this capability later, we should add them! if l.clientIdentifier == clients.IdentifierVSCode { @@ -1562,7 +1542,7 @@ func (l *LanguageServer) handleTextDocumentCodeAction( Title: "Format using opa fmt", Kind: "quickfix", Diagnostics: []types.Diagnostic{diag}, - IsPreferred: &yes, + IsPreferred: truePtr, Command: FmtCommand([]string{params.TextDocument.URI}), }) case ruleNameUseRegoV1: @@ -1570,7 +1550,7 @@ func (l *LanguageServer) handleTextDocumentCodeAction( Title: "Format for Rego v1 using opa fmt", Kind: "quickfix", Diagnostics: []types.Diagnostic{diag}, - IsPreferred: &yes, + IsPreferred: truePtr, Command: FmtV1Command([]string{params.TextDocument.URI}), }) case ruleNameUseAssignmentOperator: @@ -1578,7 +1558,7 @@ func (l *LanguageServer) handleTextDocumentCodeAction( Title: "Replace = with := in assignment", Kind: "quickfix", Diagnostics: []types.Diagnostic{diag}, - IsPreferred: &yes, + IsPreferred: truePtr, Command: UseAssignmentOperatorCommand([]string{ params.TextDocument.URI, strconv.FormatUint(uint64(diag.Range.Start.Line+1), 10), @@ -1590,7 +1570,7 @@ func (l *LanguageServer) handleTextDocumentCodeAction( Title: "Format comment to have leading whitespace", Kind: "quickfix", Diagnostics: []types.Diagnostic{diag}, - IsPreferred: &yes, + IsPreferred: truePtr, Command: NoWhiteSpaceCommentCommand([]string{ params.TextDocument.URI, strconv.FormatUint(uint64(diag.Range.Start.Line+1), 10), @@ -1602,7 +1582,7 @@ func (l *LanguageServer) handleTextDocumentCodeAction( Title: "Move file so that directory structure mirrors package path", Kind: "quickfix", Diagnostics: []types.Diagnostic{diag}, - IsPreferred: &yes, + IsPreferred: truePtr, Command: DirectoryStructureMismatchCommand([]string{ params.TextDocument.URI, }), @@ -1616,7 +1596,7 @@ func (l *LanguageServer) handleTextDocumentCodeAction( Title: txt, Kind: "quickfix", Diagnostics: []types.Diagnostic{diag}, - IsPreferred: &yes, + IsPreferred: truePtr, Command: types.Command{ Title: txt, Command: "vscode.open", @@ -1629,16 +1609,7 @@ func (l *LanguageServer) handleTextDocumentCodeAction( return actions, nil } -func (l *LanguageServer) handleWorkspaceExecuteCommand( - _ context.Context, - _ *jsonrpc2.Conn, - req *jsonrpc2.Request, -) (result any, err error) { - var params types.ExecuteCommandParams - if err := encoding.JSON().Unmarshal(*req.Params, ¶ms); err != nil { - return nil, fmt.Errorf("failed to unmarshal params: %w", err) - } - +func (l *LanguageServer) handleWorkspaceExecuteCommand(params types.ExecuteCommandParams) (any, error) { // this must not block, so we send the request to the worker on a buffered channel. // the response to the workspace/executeCommand request must be sent before the command is executed // so that the client can complete the request and be ready to receive the follow-on request for @@ -1649,17 +1620,7 @@ func (l *LanguageServer) handleWorkspaceExecuteCommand( return struct{}{}, nil } -func (l *LanguageServer) handleTextDocumentInlayHint( - _ context.Context, - _ *jsonrpc2.Conn, - req *jsonrpc2.Request, -) (result any, err error) { - var params types.TextDocumentInlayHintParams - - if err := encoding.JSON().Unmarshal(*req.Params, ¶ms); err != nil { - return nil, fmt.Errorf("failed to unmarshal params: %w", err) - } - +func (l *LanguageServer) handleTextDocumentInlayHint(params types.TextDocumentInlayHintParams) (any, error) { if l.ignoreURI(params.TextDocument.URI) { return []types.InlayHint{}, nil } @@ -1679,16 +1640,12 @@ func (l *LanguageServer) handleTextDocumentInlayHint( return partialInlayHints(parseErrors, contents, params.TextDocument.URI, bis), nil } + // TODO: use GetContentAndModule here, or do we need to handle the cases separately? // file is blank, nothing to do if contents, ok := l.cache.GetFileContents(params.TextDocument.URI); ok && contents == "" { return []types.InlayHint{}, nil } - // file could not be parsed, nothing to do - if errors, ok := l.cache.GetParseErrors(params.TextDocument.URI); ok && len(errors) > 0 { - return []types.InlayHint{}, nil - } - module, ok := l.cache.GetModule(params.TextDocument.URI) if !ok { l.logf(log.LevelMessage, "failed to get inlay hint: no parsed module for uri %q", params.TextDocument.URI) @@ -1701,22 +1658,8 @@ func (l *LanguageServer) handleTextDocumentInlayHint( return inlayHints, nil } -func (l *LanguageServer) handleTextDocumentCodeLens( - ctx context.Context, - _ *jsonrpc2.Conn, - req *jsonrpc2.Request, -) (result any, err error) { - var params types.CodeLensParams - if err := encoding.JSON().Unmarshal(*req.Params, ¶ms); err != nil { - return nil, fmt.Errorf("failed to unmarshal params: %w", err) - } - - module, ok := l.cache.GetModule(params.TextDocument.URI) - if !ok { - return nil, nil // return a null response, as per the spec - } - - contents, ok := l.cache.GetFileContents(params.TextDocument.URI) +func (l *LanguageServer) handleTextDocumentCodeLens(ctx context.Context, params types.CodeLensParams) (any, error) { + contents, module, ok := l.cache.GetContentAndModule(params.TextDocument.URI) if !ok { return nil, nil // return a null response, as per the spec } @@ -1743,18 +1686,8 @@ func (l *LanguageServer) handleTextDocumentCodeLens( return filteredLenses, nil } -func (l *LanguageServer) handleTextDocumentCompletion( - ctx context.Context, - _ *jsonrpc2.Conn, - req *jsonrpc2.Request, -) (any, error) { - var params types.CompletionParams - if err := encoding.JSON().Unmarshal(*req.Params, ¶ms); err != nil { - return nil, fmt.Errorf("failed to unmarshal params: %w", err) - } - - // when config ignores a file, then we return an empty completion list - // as a no-op. +func (l *LanguageServer) handleTextDocumentCompletion(ctx context.Context, params types.CompletionParams) (any, error) { + // when config ignores a file, then we return an empty completion list as a no-op. if l.ignoreURI(params.TextDocument.URI) { return types.CompletionList{ IsIncomplete: false, @@ -1777,20 +1710,19 @@ func (l *LanguageServer) handleTextDocumentCompletion( return nil, fmt.Errorf("failed to find completions: %w", err) } - // make sure the items is always [] instead of null as is required by the spec if items == nil { - return types.CompletionList{ - IsIncomplete: false, - Items: make([]types.CompletionItem, 0), - }, nil + // make sure the items is always [] instead of null as is required by the spec + items = noCompletionItems } return types.CompletionList{ - IsIncomplete: true, + IsIncomplete: items != nil, Items: items, }, nil } +var noInlayHints = make([]types.InlayHint, 0) + func partialInlayHints( parseErrors []types.Diagnostic, contents, @@ -1804,14 +1736,10 @@ func partialInlayHints( } } - if firstErrorLine == 0 { + if firstErrorLine == 0 || firstErrorLine > uint(len(strings.Split(contents, "\n"))) { // if there are parse errors from line 0, we skip doing anything - return []types.InlayHint{} - } - - if firstErrorLine > uint(len(strings.Split(contents, "\n"))) { // if the last valid line is beyond the end of the file, we exit as something is up - return []types.InlayHint{} + return noInlayHints } // select the lines from the contents up to the first parse error @@ -1821,30 +1749,19 @@ func partialInlayHints( module, err := rparse.Module(fileURI, lines) if err != nil { // if we still can't parse the bit we hoped was valid, we exit as this is 'too hard' - return []types.InlayHint{} + return noInlayHints } return getInlayHints(module, builtins) } -func (l *LanguageServer) handleWorkspaceSymbol( - _ context.Context, - _ *jsonrpc2.Conn, - req *jsonrpc2.Request, -) (result any, err error) { - var params types.WorkspaceSymbolParams - if err := encoding.JSON().Unmarshal(*req.Params, ¶ms); err != nil { - return nil, fmt.Errorf("failed to unmarshal params: %w", err) - } - +// Note: currently ignoring params.Query, as the client seems to do a good +// job of filtering anyway, and that would merely be an optimization here. +// But perhaps a good one to do at some point, and I'm not sure all clients +// do this filtering. +func (l *LanguageServer) handleWorkspaceSymbol() (any, error) { symbols := make([]types.WorkspaceSymbol, 0) contents := l.cache.GetAllFiles() - - // Note: currently ignoring params.Query, as the client seems to do a good - // job of filtering anyway, and that would merely be an optimization here. - // But perhaps a good one to do at some point, and I'm not sure all clients - // do this filtering. - bis := l.builtinsForCurrentCapabilities() for moduleURL, module := range l.cache.GetAllModules() { @@ -1860,16 +1777,7 @@ func (l *LanguageServer) handleWorkspaceSymbol( return symbols, nil } -func (l *LanguageServer) handleTextDocumentDefinition( - _ context.Context, - _ *jsonrpc2.Conn, - req *jsonrpc2.Request, -) (result any, err error) { - var params types.DefinitionParams - if err := encoding.JSON().Unmarshal(*req.Params, ¶ms); err != nil { - return nil, fmt.Errorf("failed to unmarshal params: %w", err) - } - +func (l *LanguageServer) handleTextDocumentDefinition(params types.DefinitionParams) (any, error) { if l.ignoreURI(params.TextDocument.URI) { return nil, nil } @@ -1884,7 +1792,6 @@ func (l *LanguageServer) handleTextDocumentDefinition( return nil, fmt.Errorf("failed to filter ignored paths: %w", err) } - orc := oracle.New() query := oracle.DefinitionQuery{ Filename: uri.ToPath(l.clientIdentifier, params.TextDocument.URI), Pos: positionToOffset(contents, params.Position), @@ -1917,16 +1824,7 @@ func (l *LanguageServer) handleTextDocumentDefinition( return loc, nil } -func (l *LanguageServer) handleTextDocumentDidOpen( - _ context.Context, - _ *jsonrpc2.Conn, - req *jsonrpc2.Request, -) (result any, err error) { - var params types.TextDocumentDidOpenParams - if err := encoding.JSON().Unmarshal(*req.Params, ¶ms); err != nil { - return nil, fmt.Errorf("failed to unmarshal params: %w", err) - } - +func (l *LanguageServer) handleTextDocumentDidOpen(params types.TextDocumentDidOpenParams) (any, error) { // if the opened file is ignored in config, then we only store the // contents for file level operations like formatting. if l.ignoreURI(params.TextDocument.URI) { @@ -1946,22 +1844,12 @@ func (l *LanguageServer) handleTextDocumentDidOpen( } l.lintFileJobs <- job - l.builtinsPositionJobs <- job return struct{}{}, nil } -func (l *LanguageServer) handleTextDocumentDidClose( - _ context.Context, - _ *jsonrpc2.Conn, - req *jsonrpc2.Request, -) (result any, err error) { - var params types.TextDocumentDidCloseParams - if err := encoding.JSON().Unmarshal(*req.Params, ¶ms); err != nil { - return nil, fmt.Errorf("failed to unmarshal params: %w", err) - } - +func (l *LanguageServer) handleTextDocumentDidClose(params types.TextDocumentDidCloseParams) (any, error) { // if the file being closed is ignored in config, then we // need to clear it from the ignored state in the cache. if l.ignoreURI(params.TextDocument.URI) { @@ -1971,16 +1859,7 @@ func (l *LanguageServer) handleTextDocumentDidClose( return struct{}{}, nil } -func (l *LanguageServer) handleTextDocumentDidChange( - _ context.Context, - _ *jsonrpc2.Conn, - req *jsonrpc2.Request, -) (result any, err error) { - var params types.TextDocumentDidChangeParams - if err := encoding.JSON().Unmarshal(*req.Params, ¶ms); err != nil { - return nil, fmt.Errorf("failed to unmarshal params: %w", err) - } - +func (l *LanguageServer) handleTextDocumentDidChange(params types.TextDocumentDidChangeParams) (any, error) { if len(params.ContentChanges) == 0 { return struct{}{}, nil } @@ -1996,10 +1875,6 @@ func (l *LanguageServer) handleTextDocumentDidChange( return struct{}{}, nil } - if len(params.ContentChanges) < 1 { - return struct{}{}, nil - } - l.cache.SetFileContents(params.TextDocument.URI, params.ContentChanges[0].Text) job := lintFileJob{ @@ -2015,14 +1890,8 @@ func (l *LanguageServer) handleTextDocumentDidChange( func (l *LanguageServer) handleTextDocumentDidSave( ctx context.Context, - _ *jsonrpc2.Conn, - req *jsonrpc2.Request, -) (result any, err error) { - var params types.TextDocumentDidSaveParams - if err := encoding.JSON().Unmarshal(*req.Params, ¶ms); err != nil { - return nil, fmt.Errorf("failed to unmarshal params: %w", err) - } - + params types.TextDocumentDidSaveParams, +) (any, error) { if params.Text != nil && l.getLoadedConfig() == nil { if !strings.Contains(*params.Text, "\r\n") { return struct{}{}, nil @@ -2037,15 +1906,9 @@ func (l *LanguageServer) handleTextDocumentDidSave( return struct{}{}, nil } - formattingEnabled := false - - for _, rule := range enabled { - if rule == "opa-fmt" || rule == "use-rego-v1" { - formattingEnabled = true - - break - } - } + formattingEnabled := slices.ContainsFunc(enabled, func(rule string) bool { + return rule == ruleNameOPAFmt || rule == ruleNameUseRegoV1 + }) if formattingEnabled { resp := types.ShowMessageParams{ @@ -2064,30 +1927,16 @@ func (l *LanguageServer) handleTextDocumentDidSave( return struct{}{}, nil } -func (l *LanguageServer) handleTextDocumentDocumentSymbol( - _ context.Context, - _ *jsonrpc2.Conn, - req *jsonrpc2.Request, -) (result any, err error) { - var params types.DocumentSymbolParams - if err := encoding.JSON().Unmarshal(*req.Params, ¶ms); err != nil { - return nil, fmt.Errorf("failed to unmarshal params: %w", err) - } - +func (l *LanguageServer) handleTextDocumentDocumentSymbol(params types.DocumentSymbolParams) (any, error) { if l.ignoreURI(params.TextDocument.URI) { - return []types.DocumentSymbol{}, nil + return noDocumentSymbols, nil } - contents, ok := l.cache.GetFileContents(params.TextDocument.URI) + contents, module, ok := l.cache.GetContentAndModule(params.TextDocument.URI) if !ok { l.logf(log.LevelMessage, "failed to get file contents for uri %q", params.TextDocument.URI) - return []types.DocumentSymbol{}, nil - } - - module, ok := l.cache.GetModule(params.TextDocument.URI) - if !ok { - return []types.DocumentSymbol{}, nil + return noDocumentSymbols, nil } bis := l.builtinsForCurrentCapabilities() @@ -2095,24 +1944,10 @@ func (l *LanguageServer) handleTextDocumentDocumentSymbol( return documentSymbols(contents, module, bis), nil } -func (l *LanguageServer) handleTextDocumentFoldingRange( - _ context.Context, - _ *jsonrpc2.Conn, - req *jsonrpc2.Request, -) (result any, err error) { - var params types.FoldingRangeParams - if err := encoding.JSON().Unmarshal(*req.Params, ¶ms); err != nil { - return nil, fmt.Errorf("failed to unmarshal params: %w", err) - } - - module, ok := l.cache.GetModule(params.TextDocument.URI) - if !ok { - return []types.FoldingRange{}, nil - } - - text, ok := l.cache.GetFileContents(params.TextDocument.URI) +func (l *LanguageServer) handleTextDocumentFoldingRange(params types.FoldingRangeParams) (any, error) { + text, module, ok := l.cache.GetContentAndModule(params.TextDocument.URI) if !ok { - return []types.FoldingRange{}, nil + return noFoldingRanges, nil } return findFoldingRanges(text, module), nil @@ -2120,14 +1955,8 @@ func (l *LanguageServer) handleTextDocumentFoldingRange( func (l *LanguageServer) handleTextDocumentFormatting( ctx context.Context, - _ *jsonrpc2.Conn, - req *jsonrpc2.Request, -) (result any, err error) { - var params types.DocumentFormattingParams - if err := encoding.JSON().Unmarshal(*req.Params, ¶ms); err != nil { - return nil, fmt.Errorf("failed to unmarshal params: %w", err) - } - + params types.DocumentFormattingParams, +) (any, error) { var oldContent string // Fetch the contents used for formatting from the appropriate cache location. @@ -2258,22 +2087,13 @@ func (l *LanguageServer) handleTextDocumentFormatting( return ComputeEdits(oldContent, newContent), nil } -func (l *LanguageServer) handleWorkspaceDidCreateFiles( - _ context.Context, - _ *jsonrpc2.Conn, - req *jsonrpc2.Request, -) (result any, err error) { - var params types.WorkspaceDidCreateFilesParams - if err := encoding.JSON().Unmarshal(*req.Params, ¶ms); err != nil { - return nil, fmt.Errorf("failed to unmarshal params: %w", err) - } - +func (l *LanguageServer) handleWorkspaceDidCreateFiles(params types.WorkspaceDidCreateFilesParams) (any, error) { if l.ignoreURI(params.Files[0].URI) { return struct{}{}, nil } for _, createOp := range params.Files { - if _, _, err = cache.UpdateCacheForURIFromDisk( + if _, _, err := cache.UpdateCacheForURIFromDisk( l.cache, uri.FromPath(l.clientIdentifier, createOp.URI), uri.ToPath(l.clientIdentifier, createOp.URI), @@ -2296,14 +2116,8 @@ func (l *LanguageServer) handleWorkspaceDidCreateFiles( func (l *LanguageServer) handleWorkspaceDidDeleteFiles( ctx context.Context, - _ *jsonrpc2.Conn, - req *jsonrpc2.Request, -) (result any, err error) { - var params types.WorkspaceDidDeleteFilesParams - if err := encoding.JSON().Unmarshal(*req.Params, ¶ms); err != nil { - return nil, fmt.Errorf("failed to unmarshal params: %w", err) - } - + params types.WorkspaceDidDeleteFilesParams, +) (any, error) { if l.ignoreURI(params.Files[0].URI) { return struct{}{}, nil } @@ -2321,19 +2135,15 @@ func (l *LanguageServer) handleWorkspaceDidDeleteFiles( func (l *LanguageServer) handleWorkspaceDidRenameFiles( ctx context.Context, - _ *jsonrpc2.Conn, - req *jsonrpc2.Request, -) (result any, err error) { - var params types.WorkspaceDidRenameFilesParams - if err := encoding.JSON().Unmarshal(*req.Params, ¶ms); err != nil { - return nil, fmt.Errorf("failed to unmarshal params: %w", err) - } - + params types.WorkspaceDidRenameFilesParams, +) (any, error) { for _, renameOp := range params.Files { if l.ignoreURI(renameOp.OldURI) && l.ignoreURI(renameOp.NewURI) { continue } + var err error + content, ok := l.cache.GetFileContents(renameOp.OldURI) // if the content is not in the cache then we can attempt to load from // the disk instead. @@ -2351,7 +2161,7 @@ func (l *LanguageServer) handleWorkspaceDidRenameFiles( // clear the cache and send diagnostics for the old URI to clear the client l.cache.Delete(renameOp.OldURI) - if err := l.sendFileDiagnostics(ctx, renameOp.OldURI); err != nil { + if err = l.sendFileDiagnostics(ctx, renameOp.OldURI); err != nil { l.logf(log.LevelMessage, "failed to send diagnostic: %s", err) } @@ -2375,11 +2185,7 @@ func (l *LanguageServer) handleWorkspaceDidRenameFiles( return struct{}{}, nil } -func (l *LanguageServer) handleWorkspaceDiagnostic( - _ context.Context, - _ *jsonrpc2.Conn, - _ *jsonrpc2.Request, -) (result any, err error) { +func (l *LanguageServer) handleWorkspaceDiagnostic() (any, error) { workspaceReport := types.WorkspaceDiagnosticReport{ Items: make([]types.WorkspaceFullDocumentDiagnosticReport, 0), } @@ -2393,7 +2199,7 @@ func (l *LanguageServer) handleWorkspaceDiagnostic( wkspceDiags, ok := l.cache.GetFileDiagnostics(l.workspaceRootURI) if !ok { - wkspceDiags = []types.Diagnostic{} + wkspceDiags = noDiagnostics } workspaceReport.Items = append(workspaceReport.Items, types.WorkspaceFullDocumentDiagnosticReport{ @@ -2406,20 +2212,10 @@ func (l *LanguageServer) handleWorkspaceDiagnostic( return workspaceReport, nil } -func (l *LanguageServer) handleInitialize( - ctx context.Context, - _ *jsonrpc2.Conn, - req *jsonrpc2.Request, -) (any, error) { - var params types.InitializeParams - if err := encoding.JSON().Unmarshal(*req.Params, ¶ms); err != nil { - return nil, fmt.Errorf("failed to unmarshal params: %w", err) - } - +func (l *LanguageServer) handleInitialize(ctx context.Context, params types.InitializeParams) (any, error) { // params.RootURI is not expected to have a trailing slash, but if one is // present it will be removed for consistency. l.workspaceRootURI = strings.TrimSuffix(params.RootURI, rio.PathSeparator) - l.clientIdentifier = clients.DetermineClientIdentifier(params.ClientInfo.Name) if l.clientIdentifier == clients.IdentifierGeneric { @@ -2615,11 +2411,7 @@ func (l *LanguageServer) loadWorkspaceContents(ctx context.Context, newOnly bool return changedOrNewURIs, nil } -func (l *LanguageServer) handleInitialized( - _ context.Context, - _ *jsonrpc2.Conn, - _ *jsonrpc2.Request, -) (result any, err error) { +func (l *LanguageServer) handleInitialized() (any, error) { // if running without config, then we should send the diagnostic request now // otherwise it'll happen when the config is loaded if !l.configWatcher.IsWatching() { @@ -2629,11 +2421,7 @@ func (l *LanguageServer) handleInitialized( return struct{}{}, nil } -func (*LanguageServer) handleTextDocumentDiagnostic( - _ context.Context, - _ *jsonrpc2.Conn, - _ *jsonrpc2.Request, -) (result any, err error) { +func (*LanguageServer) handleTextDocumentDiagnostic() (any, error) { // this is a no-op. Because we accept the textDocument/didChange event, which contains the new content, // we don't need to do anything here as once the new content has been parsed, the diagnostics will be sent // on the channel regardless of this request. @@ -2641,19 +2429,8 @@ func (*LanguageServer) handleTextDocumentDiagnostic( } func (l *LanguageServer) handleWorkspaceDidChangeWatchedFiles( - _ context.Context, - _ *jsonrpc2.Conn, - req *jsonrpc2.Request, -) (result any, err error) { - if req.Params == nil { - return nil, &jsonrpc2.Error{Code: jsonrpc2.CodeInvalidParams} - } - - var params types.WorkspaceDidChangeWatchedFilesParams - if err := encoding.JSON().Unmarshal(*req.Params, ¶ms); err != nil { - return nil, fmt.Errorf("failed to unmarshal params: %w", err) - } - + params types.WorkspaceDidChangeWatchedFilesParams, +) (any, error) { // this handles the case of a new config file being created when one did // not exist before if len(params.Changes) > 0 && strings.HasSuffix(params.Changes[0].URI, ".regal/config.yaml") { @@ -2664,7 +2441,7 @@ func (l *LanguageServer) handleWorkspaceDidChangeWatchedFiles( } // when a file is changed (saved), then we send trigger a full workspace lint - regoFiles := make([]string, 0) + regoFiles := make([]string, 0, len(params.Changes)) for _, change := range params.Changes { if change.URI == "" || l.ignoreURI(change.URI) { @@ -2702,7 +2479,7 @@ func (l *LanguageServer) sendFileDiagnostics(ctx context.Context, fileURI string } func (l *LanguageServer) getFilteredModules() (map[string]*ast.Module, error) { - ignore := make([]string, 0) + var ignore []string if cfg := l.getLoadedConfig(); cfg != nil && cfg.Ignore.Files != nil { ignore = cfg.Ignore.Files @@ -2741,11 +2518,8 @@ func (l *LanguageServer) ignoreURI(fileURI string) bool { false, l.workspacePath(), ) - if err != nil || len(paths) == 0 { - return true - } - return false + return err != nil || len(paths) == 0 } func (l *LanguageServer) workspacePath() string { diff --git a/internal/lsp/server_formatting_test.go b/internal/lsp/server_formatting_test.go index 607aadf4..ae9928a2 100644 --- a/internal/lsp/server_formatting_test.go +++ b/internal/lsp/server_formatting_test.go @@ -2,7 +2,6 @@ package lsp import ( "context" - "encoding/json" "path/filepath" "testing" @@ -26,7 +25,7 @@ func TestFormatting(t *testing.T) { return struct{}{}, nil } - ls, connClient, err := createAndInitServer(ctx, newTestLogger(t), tempDir, map[string]string{}, clientHandler) + ls, _, err := createAndInitServer(ctx, newTestLogger(t), tempDir, map[string]string{}, clientHandler) if err != nil { t.Fatalf("failed to create and init language server: %s", err) } @@ -39,19 +38,12 @@ func TestFormatting(t *testing.T) { ` ls.cache.SetFileContents(mainRegoURI, content) - bs, err := json.Marshal(&types.DocumentFormattingParams{ + params := types.DocumentFormattingParams{ TextDocument: types.TextDocumentIdentifier{URI: mainRegoURI}, Options: types.FormattingOptions{}, - }) - if err != nil { - t.Fatalf("failed to marshal document formatting params: %v", err) } - var msg json.RawMessage = bs - - req := &jsonrpc2.Request{Params: &msg} - - res, err := ls.handleTextDocumentFormatting(ctx, connClient, req) + res, err := ls.handleTextDocumentFormatting(ctx, params) if err != nil { t.Fatalf("failed to format document: %s", err) }