From caa32d645f862d26fe1a451ad57a1c28f8d25f97 Mon Sep 17 00:00:00 2001 From: denis_savitsky Date: Fri, 2 Feb 2024 23:51:27 +0100 Subject: [PATCH] Add support for $addToSet --- .../src/main/scala/oolong/AstParser.scala | 24 +++++++++ oolong-core/src/main/scala/oolong/UExpr.scala | 8 +++ .../src/main/scala/oolong/dsl/Dsl.scala | 8 +++ .../oolong/mongo/OolongMongoUpdateSpec.scala | 39 +++++++++++++++ .../oolong/mongo/MongoUpdateCompiler.scala | 49 +++++++++++++++--- .../scala/oolong/mongo/MongoUpdateNode.scala | 6 +++ .../test/scala/oolong/mongo/UpdateSpec.scala | 50 ++++++++++++++++++- 7 files changed, 175 insertions(+), 9 deletions(-) diff --git a/oolong-core/src/main/scala/oolong/AstParser.scala b/oolong-core/src/main/scala/oolong/AstParser.scala index de80d0d..e94078c 100644 --- a/oolong-core/src/main/scala/oolong/AstParser.scala +++ b/oolong-core/src/main/scala/oolong/AstParser.scala @@ -304,6 +304,16 @@ private[oolong] class DefaultAstParser(using quotes: Quotes) extends AstParser { val value = getValue(valueExpr) parseUpdater(updater, FieldUpdateExpr.SetOnInsert(UExpr.Prop(prop), value) :: acc) + case '{type t; ($updater: Updater[Doc]).addToSetAll[`t`, `t`]($selectProp, ($valueExpr: Iterable[`t`]))} => + val prop = parsePropSelector(selectProp) + val value = getValueOrIterable(valueExpr) + parseUpdater(updater, FieldUpdateExpr.AddToSet(UExpr.Prop(prop), value, multipleValues = true) :: acc) + + case '{type t; ($updater: Updater[Doc]).addToSet[`t`, `t`]($selectProp, ($valueExpr: `t` ))} => + val prop = parsePropSelector(selectProp) + val value = getValueOrIterable(valueExpr) + parseUpdater(updater, FieldUpdateExpr.AddToSet(UExpr.Prop(prop), value, multipleValues = false) :: acc) + case '{ $updater: Updater[Doc] } => updater match { case AsTerm(Ident(name)) if name == paramName => @@ -346,6 +356,20 @@ private[oolong] class DefaultAstParser(using quotes: Quotes) extends AstParser { } } + private def getValueOrIterable(expr: Expr[Any]): UExpr = + expr match + case '{ $iter: Iterable[t] } => getIterable(iter) + case base => getValue(base) + + def getIterable[T: Type](expr: Expr[Iterable[T]]): UExpr = + expr match { + // AsIterable can ignore lift e.g. in following case: lift(List(List(Random.nextInt())) + case '{ type t; lift($x: Iterable[`t`]) } => UExpr.ScalaCodeIterable(x) + case AsIterable(elems) => UExpr.UIterable(elems.map(getConstant).toList) + case _ => + report.errorAndAbort("Unexpected expr while parsing AST: " + expr.asTerm.show(using Printer.TreeStructure)) + } + private def getValue(expr: Expr[Any]): UExpr = expr match case '{ lift($x) } => UExpr.ScalaCode(x) diff --git a/oolong-core/src/main/scala/oolong/UExpr.scala b/oolong-core/src/main/scala/oolong/UExpr.scala index ac37491..791aa35 100644 --- a/oolong-core/src/main/scala/oolong/UExpr.scala +++ b/oolong-core/src/main/scala/oolong/UExpr.scala @@ -12,14 +12,18 @@ private[oolong] object UExpr { case class Prop(path: String) extends UExpr case class Constant[T](t: T) extends UExpr + case class UIterable[T](t: List[UExpr]) extends UExpr case class ScalaCode(code: Expr[Any]) extends UExpr + case class ScalaCodeIterable(code: Expr[Iterable[Any]]) extends UExpr + @nowarn("msg=unused explicit parameter") // used in macro sealed abstract class FieldUpdateExpr(prop: Prop) object FieldUpdateExpr { + // field update operators case class Set(prop: Prop, expr: UExpr) extends FieldUpdateExpr(prop: Prop) case class Inc(prop: Prop, expr: UExpr) extends FieldUpdateExpr(prop) @@ -35,6 +39,10 @@ private[oolong] object UExpr { case class Rename(prop: Prop, expr: UExpr) extends FieldUpdateExpr(prop) case class SetOnInsert(prop: Prop, expr: UExpr) extends FieldUpdateExpr(prop) + + // array update operators + case class AddToSet(prop: Prop, expr: UExpr, multipleValues: Boolean) extends FieldUpdateExpr(prop) + } } diff --git a/oolong-core/src/main/scala/oolong/dsl/Dsl.scala b/oolong-core/src/main/scala/oolong/dsl/Dsl.scala index 108d95c..168ad0a 100644 --- a/oolong-core/src/main/scala/oolong/dsl/Dsl.scala +++ b/oolong-core/src/main/scala/oolong/dsl/Dsl.scala @@ -47,4 +47,12 @@ sealed trait Updater[DocT] { def setOnInsert[PropT, ValueT](selectProp: DocT => PropT, value: ValueT)(using PropT =:= ValueT, ): Updater[DocT] = useWithinMacro("setOnInsert") + + def addToSet[PropT, ValueT](selectProp: DocT => Iterable[PropT], value: ValueT)(using + PropT =:= ValueT + ): Updater[DocT] = useWithinMacro("addToSet") + + def addToSetAll[PropT, ValueT](selectProp: DocT => Iterable[PropT], value: Iterable[ValueT])(using + PropT =:= ValueT + ): Updater[DocT] = useWithinMacro("addToSet") } diff --git a/oolong-mongo-it/src/test/scala/oolong/mongo/OolongMongoUpdateSpec.scala b/oolong-mongo-it/src/test/scala/oolong/mongo/OolongMongoUpdateSpec.scala index 5eb3189..711a592 100644 --- a/oolong-mongo-it/src/test/scala/oolong/mongo/OolongMongoUpdateSpec.scala +++ b/oolong-mongo-it/src/test/scala/oolong/mongo/OolongMongoUpdateSpec.scala @@ -6,6 +6,7 @@ import scala.concurrent.ExecutionContext import com.dimafeng.testcontainers.ForAllTestContainer import com.dimafeng.testcontainers.MongoDBContainer import concurrent.duration.DurationInt +import oolong.bson.BsonDecoder import oolong.dsl.* import org.mongodb.scala.MongoClient import org.mongodb.scala.bson.BsonDocument @@ -129,4 +130,42 @@ class OolongMongoUpdateSpec extends AsyncFlatSpec with ForAllTestContainer with } yield assert(res.wasAcknowledged() && res.getMatchedCount == 1 && res.getModifiedCount == 1) } + it should "update with $addToSet" in { + for { + res <- collection + .updateOne( + query[TestClass](_.field1 == "0"), + update[TestClass]( + _.addToSet(_.field4, 3) + ) + ) + .head() + upd <- collection + .find(query[TestClass](_.field1 == "0")) + .head() + .map(BsonDecoder[TestClass].fromBson(_).get) + } yield assert( + res.wasAcknowledged() && res.getModifiedCount == 1 && upd.field4.size == 3 + ) + } + + it should "update with $addToSet using $each" in { + for { + res <- collection + .updateOne( + query[TestClass](_.field1 == "0"), + update[TestClass]( + _.addToSetAll(_.field4, List(4, 5)) + ) + ) + .head() + upd <- collection + .find(query[TestClass](_.field1 == "0")) + .head() + .map(BsonDecoder[TestClass].fromBson(_).get) + } yield assert( + res.wasAcknowledged() && res.getModifiedCount == 1 && upd.field4.size == 4 + ) + } + } diff --git a/oolong-mongo/src/main/scala/oolong/mongo/MongoUpdateCompiler.scala b/oolong-mongo/src/main/scala/oolong/mongo/MongoUpdateCompiler.scala index 5b006cc..a32ca57 100644 --- a/oolong-mongo/src/main/scala/oolong/mongo/MongoUpdateCompiler.scala +++ b/oolong-mongo/src/main/scala/oolong/mongo/MongoUpdateCompiler.scala @@ -7,7 +7,9 @@ import scala.quoted.Type import oolong.* import oolong.UExpr.FieldUpdateExpr import oolong.bson.meta.QueryMeta +import oolong.mongo.MongoUpdateNode.MongoUpdateOp import oolong.mongo.MongoUpdateNode as MU +import org.mongodb.scala.bson.BsonArray import org.mongodb.scala.bson.BsonBoolean import org.mongodb.scala.bson.BsonDocument import org.mongodb.scala.bson.BsonDouble @@ -44,10 +46,14 @@ object MongoUpdateCompiler extends Backend[UExpr, MU, BsonDocument] { case FieldUpdateExpr.SetOnInsert(prop, expr) => MU.MongoUpdateOp.SetOnInsert(MU.Prop(renames.getOrElse(prop.path, prop.path)), rec(expr)) case FieldUpdateExpr.Unset(prop) => MU.MongoUpdateOp.Unset(MU.Prop(renames.getOrElse(prop.path, prop.path))) + case FieldUpdateExpr.AddToSet(prop, expr, each) => + MU.MongoUpdateOp.AddToSet(MU.Prop(renames.getOrElse(prop.path, prop.path)), rec(expr), each) }) - case UExpr.ScalaCode(code) => MU.ScalaCode(code) - case UExpr.Constant(t) => MU.Constant(t) - case _ => report.errorAndAbort("Unexpected expr " + pprint(ast)) + case UExpr.ScalaCode(code) => MU.ScalaCode(code) + case UExpr.Constant(t) => MU.Constant(t) + case UExpr.UIterable(t) => MU.UIterable(t.map(rec(_))) + case UExpr.ScalaCodeIterable(t) => MU.ScalaCodeIterable(t) + case _ => report.errorAndAbort("Unexpected expr " + pprint(ast)) } rec(ast, meta) @@ -82,7 +88,10 @@ object MongoUpdateCompiler extends Backend[UExpr, MU, BsonDocument] { )("$rename"), renderOps( ops.collect { case s: MU.MongoUpdateOp.SetOnInsert => s }.map(op => render(op.prop) + ": " + render(op.value)) - )("$setOnInsert") + )("$setOnInsert"), + renderOps( + ops.collect { case s: MU.MongoUpdateOp.AddToSet => s }.map(renderAddToSet) + )("$addToSet") ).flatten .mkString("{\n", ",\n", "\n}") @@ -100,10 +109,20 @@ object MongoUpdateCompiler extends Backend[UExpr, MU, BsonDocument] { case '{ ${ x }: t } => RenderUtils.renderCaseClass[t](x) case _ => "?" + case MU.UIterable(iterable) => iterable.map(render).mkString("[", ",", "]") + case MU.ScalaCodeIterable(_) => "[ ? ]" + case _ => report.errorAndAbort(s"Wrong term: $query") } - def renderOps(ops: List[String])(op: String) = + private def renderAddToSet(op: MU.MongoUpdateOp.AddToSet)(using Quotes): String = + val renderOfValue = render(op.value) + val finalRenderOfValue = + if op.each then s"""{ "$$each" : $renderOfValue }""" + else renderOfValue + render(op.prop) + ": " + finalRenderOfValue + + private def renderOps(ops: List[String])(op: String) = ops match case Nil => None case list => Some(s"\t \"$op\": { " + list.mkString(", ") + " }") @@ -112,10 +131,15 @@ object MongoUpdateCompiler extends Backend[UExpr, MU, BsonDocument] { import quotes.reflect.* def targetOps(setters: List[MU.MongoUpdateOp]): List[Expr[(String, BsonValue)]] = - setters.map { case op: MU.MongoUpdateOp => + setters.map { op => val key = op.prop.path val valueExpr = handleValues(op.value) - '{ ${ Expr(key) } -> $valueExpr } + val finalValueExpr = op match + case addToSet: MongoUpdateOp.AddToSet => + if addToSet.each then '{ BsonDocument("$each" -> $valueExpr) } + else valueExpr + case _ => valueExpr + '{ ${ Expr(key) } -> $finalValueExpr } } optRepr match { @@ -128,6 +152,7 @@ object MongoUpdateCompiler extends Backend[UExpr, MU, BsonDocument] { val tMuls = targetOps(ops.collect { case s: MU.MongoUpdateOp.Mul => s }) val tRenames = targetOps(ops.collect { case s: MU.MongoUpdateOp.Rename => s }) val tSetOnInserts = targetOps(ops.collect { case s: MU.MongoUpdateOp.SetOnInsert => s }) + val tAddToSets = targetOps(ops.collect { case s: MU.MongoUpdateOp.AddToSet => s }) // format: off def updaterGroup(groupName: String, updaters: List[Expr[(String, BsonValue)]]): Option[Expr[(String, BsonDocument)]] = @@ -147,6 +172,7 @@ object MongoUpdateCompiler extends Backend[UExpr, MU, BsonDocument] { updaterGroup("$mul", tMuls), updaterGroup("$rename", tRenames), updaterGroup("$setOnInsert", tSetOnInserts), + updaterGroup("$addToSet", tAddToSets), ).flatten '{ @@ -181,6 +207,13 @@ object MongoUpdateCompiler extends Backend[UExpr, MU, BsonDocument] { case MU.Constant(b: Boolean) => '{ BsonBoolean.apply(${ Expr(b: Boolean) }) } case MU.ScalaCode(code) => BsonUtils.extractLifted(code) - case _ => report.errorAndAbort(s"Given type is not literal constant") + case MU.UIterable(list) => + '{ + BsonArray.fromIterable(${ + Expr.ofList(list.map(handleValues)) + }) + } + case MU.ScalaCodeIterable(exprList) => BsonUtils.extractLifted(exprList) + case _ => report.errorAndAbort(s"Given type is not literal constant") } } diff --git a/oolong-mongo/src/main/scala/oolong/mongo/MongoUpdateNode.scala b/oolong-mongo/src/main/scala/oolong/mongo/MongoUpdateNode.scala index 73571ff..7cea788 100644 --- a/oolong-mongo/src/main/scala/oolong/mongo/MongoUpdateNode.scala +++ b/oolong-mongo/src/main/scala/oolong/mongo/MongoUpdateNode.scala @@ -13,8 +13,12 @@ case object MongoUpdateNode { case class Constant[T](t: T) extends MU + case class UIterable[T](t: List[MU]) extends MU + case class ScalaCode(code: Expr[Any]) extends MU + case class ScalaCodeIterable(code: Expr[Iterable[Any]]) extends MU + sealed abstract class MongoUpdateOp(val prop: Prop, val value: MU) extends MU object MongoUpdateOp { case class Set(override val prop: Prop, override val value: MU) extends MongoUpdateOp(prop, value) @@ -25,5 +29,7 @@ case object MongoUpdateNode { case class Mul(override val prop: Prop, override val value: MU) extends MongoUpdateOp(prop, value) case class Rename(override val prop: Prop, override val value: MU) extends MongoUpdateOp(prop, value) case class SetOnInsert(override val prop: Prop, override val value: MU) extends MongoUpdateOp(prop, value) + + case class AddToSet(override val prop: Prop, override val value: MU, each: Boolean) extends MongoUpdateOp(prop, value) } } diff --git a/oolong-mongo/src/test/scala/oolong/mongo/UpdateSpec.scala b/oolong-mongo/src/test/scala/oolong/mongo/UpdateSpec.scala index 40d2294..e1e860b 100644 --- a/oolong-mongo/src/test/scala/oolong/mongo/UpdateSpec.scala +++ b/oolong-mongo/src/test/scala/oolong/mongo/UpdateSpec.scala @@ -23,7 +23,10 @@ class UpdateSpec extends AnyFunSuite { dateField: LocalDate, innerClassField: InnerClass, optionField: Option[Long], - optionInnerClassField: Option[InnerClass] + optionInnerClassField: Option[InnerClass], + listField: List[Int], + classInnerClassField: List[InnerClass], + nestedListField: List[List[Int]] ) case class InnerClass( @@ -170,6 +173,51 @@ class UpdateSpec extends AnyFunSuite { ) } + test("$addToSet") { + val q = update[TestClass](_.addToSet(_.listField, 1)) + val repr = renderUpdate[TestClass](_.addToSet(_.listField, 1)) + test( + q, + repr, + BsonDocument("$addToSet" -> BsonDocument("listField" -> BsonInt32(1))) + ) + } + + test("$addToSet nested") { + val q = update[TestClass](_.addToSet(_.nestedListField, List(1, 2, 3))) + val repr = renderUpdate[TestClass](_.addToSet(_.nestedListField, List(1, 2, 3))) + test( + q, + repr, + BsonDocument("$addToSet" -> BsonDocument("nestedListField" -> BsonArray(BsonInt32(1), BsonInt32(2), BsonInt32(3)))) + ) + } + + test("$addToSet with $each") { + val q = update[TestClass](_.addToSetAll(_.listField, List(1))) + val repr = renderUpdate[TestClass](_.addToSetAll(_.listField, List(1))) + test( + q, + repr, + BsonDocument("$addToSet" -> BsonDocument("listField" -> BsonDocument("$each" -> BsonArray(BsonInt32(1))))) + ) + } + + test("$addToSet with $each nested") { + val q = update[TestClass](_.addToSetAll(_.nestedListField, lift(List(List(1, 2, 3))))) + val repr = renderUpdate[TestClass](_.addToSetAll(_.nestedListField, lift(List(List(1, 2, 3))))) + test( + q, + repr, + BsonDocument( + "$addToSet" -> BsonDocument( + "nestedListField" -> BsonDocument("$each" -> BsonArray(BsonArray(BsonInt32(1), BsonInt32(2), BsonInt32(3)))) + ) + ), + ignoreRender = true + ) + } + test("several update operators combined") { val q = update[TestClass]( _.unset(_.dateField)