From 39e349e42019b7b79ecedb0b1e9e1e512ac0bada Mon Sep 17 00:00:00 2001 From: Katarzyna Marek Date: Thu, 20 Jul 2023 17:04:52 +0200 Subject: [PATCH] bugfix: adjust parsing `ScalafmtConfig` for `StandardConvention` layout --- .../meta/internal/metals/ScalafmtConfig.scala | 88 +++++++++++++++---- .../scala/tests/ScalafmtConfigSuite.scala | 30 +++++++ 2 files changed, 103 insertions(+), 15 deletions(-) diff --git a/metals/src/main/scala/scala/meta/internal/metals/ScalafmtConfig.scala b/metals/src/main/scala/scala/meta/internal/metals/ScalafmtConfig.scala index 60f504a45d3..e984df242c3 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/ScalafmtConfig.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/ScalafmtConfig.scala @@ -8,6 +8,7 @@ import scala.meta.internal.semver.SemVer import scala.meta.io.AbsolutePath import com.typesafe.config.Config +import com.typesafe.config.ConfigException import com.typesafe.config.ConfigFactory import com.typesafe.config.ConfigParseOptions import com.typesafe.config.ConfigSyntax @@ -15,7 +16,7 @@ import com.typesafe.config.parser.ConfigDocument import com.typesafe.config.parser.ConfigDocumentFactory /** - * A partial repersentation of scalafmt config format. + * A partial representation of scalafmt config format. * Includes only settings that affect dialect. */ case class ScalafmtConfig( @@ -150,24 +151,42 @@ object ScalafmtConfig { def getFileOverrides( conf: Config - ): Try[List[(PathMatcher, ScalafmtDialect)]] = { + ): Try[ + (List[(PathMatcher, ScalafmtDialect)], Map[String, ScalafmtDialect]) + ] = { Try { if (conf.hasPath("fileOverride")) { val obj = conf.getObject("fileOverride") val asConfig = obj.toConfig() val keys = obj.keySet().asScala - keys.toList - .map { key => - val quotedKey = '"' + key + '"' - val innerCfg = asConfig.getConfig(quotedKey) - val dialect = getRunnerDialectRaw(innerCfg) - key -> dialect - } - .collect { case (glob, Some(dialect)) => - val matcher = PathMatcher.Nio(glob) - matcher -> dialect + val (langs, globs) = + keys.toList + .flatMap { key => + val quotedKey = '"' + key + '"' + val dialect = + try { + val innerCfg = asConfig.getConfig(quotedKey) + getRunnerDialectRaw(innerCfg) + } catch { + case _: ConfigException.WrongType => + val dialect = asConfig.getString(quotedKey) + ScalafmtDialect.fromString(dialect) + } + dialect.map(key -> _) + } + .partition(_._1.startsWith("lang:")) + + val overrides = + globs.map { case (key, dialect) => + val glob = if (key.startsWith(".")) s"glob:**$key" else key + PathMatcher.Nio(glob) -> dialect } - } else List.empty + val langOverrides = + langs.map { case (lang, dialect) => + lang.stripPrefix("lang:") -> dialect + }.toMap + (overrides, langOverrides) + } else (List.empty, Map.empty) } } @@ -197,10 +216,25 @@ object ScalafmtConfig { ): Try[List[PathMatcher]] = readMatchers(s"project.$key")(v => PathMatcher.Nio(v)) + def getLangOverrides( + config: Config, + langMap: Map[String, ScalafmtDialect], + ): Try[List[(PathMatcher, ScalafmtDialect)]] = + Try { + if (config.hasPath("project.layout")) { + config.getString("project.layout") match { + case "StandardConvention" => + StandardConvention.langOverrides(langMap) + case _ => Nil + } + } else Nil + } + for { version <- getVersion(config) runnerDialect <- getRunnerDialect(config) - overrides <- getFileOverrides(config) + (overrides, langOverrides) <- getFileOverrides(config) + langOverridesWithDefaults <- getLangOverrides(config, langOverrides) includeFilters <- filters("includeFilters") excludeFilters <- filters("excludeFilters") includePaths <- paths("includePaths") @@ -208,7 +242,31 @@ object ScalafmtConfig { } yield { val include = includePaths ++ includeFilters val exclude = excludeFilters ++ excludePaths - ScalafmtConfig(version, runnerDialect, overrides, include, exclude) + val allOverrides = overrides ++ langOverridesWithDefaults + ScalafmtConfig(version, runnerDialect, allOverrides, include, exclude) + } + } +} + +object StandardConvention { + def langOverrides( + langMap: Map[String, ScalafmtDialect] + ): List[(PathMatcher, ScalafmtDialect)] = { + val modules = List("main", "test", "it") + val defaults = + List( + ("scala-2.11", ScalafmtDialect.Scala211), + ("scala-2.12", ScalafmtDialect.Scala212), + ("scala-2.13", ScalafmtDialect.Scala213), + ("scala-3", ScalafmtDialect.Scala3), + ("scala-2", ScalafmtDialect.Scala3), + ) + + defaults.flatMap { case (language, default) => + val dialect = langMap.get(language).getOrElse(default) + modules.map { module => + (PathMatcher.Nio(s"glob:**/src/$module/$language/**"), dialect) + } } } } diff --git a/tests/unit/src/test/scala/tests/ScalafmtConfigSuite.scala b/tests/unit/src/test/scala/tests/ScalafmtConfigSuite.scala index bd3848031e0..a64022f9d52 100644 --- a/tests/unit/src/test/scala/tests/ScalafmtConfigSuite.scala +++ b/tests/unit/src/test/scala/tests/ScalafmtConfigSuite.scala @@ -7,9 +7,11 @@ import scala.meta.internal.metals.PathMatcher import scala.meta.internal.metals.ScalafmtConfig import scala.meta.internal.metals.ScalafmtDialect import scala.meta.internal.semver.SemVer +import scala.meta.io.AbsolutePath import com.typesafe.config.ConfigFactory import munit.TestOptions +import munit.internal.io.PlatformIO class ScalafmtConfigSuite extends BaseSuite { @@ -133,6 +135,34 @@ class ScalafmtConfigSuite extends BaseSuite { ), ) + test("v3.2.0") { + val cfg = ConfigFactory.parseString( + s"""|version = "3.2.0" + |runner.dialect = scala3 + |project.layout = StandardConvention + |fileOverride { + | "lang:scala-2" = scala211 + |} + |""".stripMargin + ) + val config = ScalafmtConfig.parse(cfg).get + val root = AbsolutePath(PlatformIO.Paths.get(".")) + + def assertDialectFor(path: String, dialect: ScalafmtDialect) = + assertEquals( + config.overrideFor(root.resolve(path)).orElse(config.runnerDialect), + Some(dialect), + ) + + assertDialectFor("src/main/scala-2.12/Main.scala", ScalafmtDialect.Scala212) + assertDialectFor("src/test/scala-3/SomeTest.scala", ScalafmtDialect.Scala3) + assertDialectFor( + "src/main/scala-2/src/Main.scala", + ScalafmtDialect.Scala211, + ) + assertDialectFor("src/main/scala/dir/Main.scala", ScalafmtDialect.Scala3) + } + def checkUpdate( options: TestOptions, config: String,