Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test costs #176

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
6 changes: 6 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ libraryDependencies ++= Seq(
"org.scalatest" %% "scalatest" % "2.2.4" % "test"
)

val MONTAGUE_COMMIT_SHA = "b451836235cee5d900ec5f578a54ac702587858b"
lazy val montague = RootProject(uri(s"git://github.com/Workday/upshot-montague.git#$MONTAGUE_COMMIT_SHA"))
lazy val root = Project("root", file(".")) dependsOn montague



enablePlugins(JavaAppPackaging)
enablePlugins(BuildInfoPlugin)

Expand Down
5 changes: 0 additions & 5 deletions project/Build.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
import sbt.{Build => SbtBuild, _}

object Build extends SbtBuild {
val MONTAGUE_COMMIT_SHA = "b451836235cee5d900ec5f578a54ac702587858b"

lazy val root = Project("root", file(".")) dependsOn montague

lazy val montague = RootProject(uri(s"git://github.com/Workday/upshot-montague.git#$MONTAGUE_COMMIT_SHA"))
}
2 changes: 1 addition & 1 deletion project/build.properties
Original file line number Diff line number Diff line change
@@ -1 +1 @@
sbt.version=0.13.9
sbt.version=0.13.18
176 changes: 176 additions & 0 deletions src/main/scala/wordbots/CostEstimator.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
package wordbots

import wordbots.Semantics._

/*
* scoring: good is positive, bad is negative. applying to enemy multiplies by -1.
* */

object CostEstimator {
def estimateCost(node: AstNode, mode:Option[String]): String =
{println("\n\n---new cost---");(baseCost(mode)(genericEstimate(node))).toString}

private def baseCost(mode:Option[String]) : Float=>Float ={
mode match{
case Some("Object") => (x => x)
case Some("Event") => (x => x)
case _ => (x => x)
}
}

// scalastyle:off method.length
// scalastyle:off cyclomatic.complexity
//scalastyle:off magic.number
//estimate the cost of an AST node
private def astEst(node: AstNode): Float = {
node match{
// Meta
case If(condition, action) => 1 * childCosts(node).product
case MultipleActions(actions) => 1 * actions.map(child=>genericEstimate(child)).sum
case MultipleAbilities(abilities) => 1 * abilities.map(child=>genericEstimate(child)).sum
case Until(TurnsPassed(num), action) => 1 * childCosts(node).product

// Actions: Normal
case CanAttackAgain(target) => 1 * childCosts(node).product
case CanMoveAgain(target) => 1 * childCosts(node).product
case CanMoveAndAttackAgain(target) => 1 * childCosts(node).product
case DealDamage(target, num) => -1 * childCosts(node).product
case Destroy(target) => -2 * childCosts(node).product
case Discard(target) => -1 * childCosts(node).product
case Draw(target,num) => 1 * childCosts(node).product
case EndTurn => 1
case GiveAbility(target, ability) => 1 * childCosts(node).product
case ModifyAttribute(target, attr, op)=>1 * childCosts(node).product
case ModifyEnergy(target, op) => 1 * childCosts(node).product
case MoveObject(target, dest) => 1 * childCosts(node).product
case PayEnergy(target, amount) => -0.5f * childCosts(node).product
case RemoveAllAbilities(target) => 1 * childCosts(node).product
case RestoreAttribute(target, Health, Some(num)) => 1 * childCosts(node).product
case RestoreAttribute(target, Health, None) => 1 * childCosts(node).product
case ReturnToHand(target, player) => 1 * childCosts(node).product
case SetAttribute(target, attr, num)=> 1 * childCosts(node).product
case SwapAttributes(target, attr1, attr2) => 1 * childCosts(node).product
case TakeControl(player, target) => 2 * childCosts(node).product

// Actions: Utility
case SaveTarget(target) => 1 * childCosts(node).product

// Activated and triggered abilities
case ActivatedAbility(action) => 1 * childCosts(node).product
case TriggeredAbility(trigger, Instead(action)) => 1 * childCosts(node).product
case TriggeredAbility(trigger, action) => 1 * childCosts(node).product

// Passive abilities
case ApplyEffect(target, effect) => 1 * childCosts(node).product
case AttributeAdjustment(target, attr, op) => 1 * childCosts(node).product
case FreezeAttribute(target, attr) => 1 * childCosts(node).product
case HasAbility(target, ability) => 1 * childCosts(node).product

// Effects
case CanOnlyAttack(target) => 1 * childCosts(node).product

// Triggers
case AfterAttack(targetObj, objectType) => 1 * childCosts(node).product
case AfterCardPlay(targetPlayer, cardType) => 1 * childCosts(node).product
case AfterDamageReceived(targetObj) => 1 * childCosts(node).product
case AfterDestroyed(targetObj, cause) => 0.8f * childCosts(node).product
case AfterMove(targetObj) => 2 * childCosts(node).product
case AfterPlayed(targetObj) => 1 * childCosts(node).product
case BeginningOfTurn(targetPlayer) => 2 * childCosts(node).product
case EndOfTurn(targetPlayer) => 2 * childCosts(node).product

// Target objects
case ChooseO(collection) => 1 * childCosts(node).product
case AllO(collection) => 2 * childCosts(node).product
case RandomO(num, collection) => 0.5f * childCosts(node).product
case ThisObject => 1
case ItO => 1
case ItP => 1
case That => 1
case They => 1
case SavedTargetObject => 1

// Target cards
case ChooseC(collection) => 1 * childCosts(node).product
case AllC(collection) => 2 * childCosts(node).product
case RandomC(num, collection) => 0.5f * childCosts(node).product

// Target players
case Self => 1
case Opponent => -1
case AllPlayers => 0.8f
case ControllerOf(targetObject) => 1 * childCosts(node).product

// Conditions
case AdjacentTo(obj) => 0.7f * childCosts(node).product//todo: calibrate this with withinDistanceOf
case AttributeComparison(attr, comp)=> 0.8f * childCosts(node).product
case ControlledBy(player) => 0.9f * childCosts(node).product
case HasProperty(property) => 0.8f * childCosts(node).product
case Unoccupied => 1
case WithinDistanceOf(distance, obj)=> 0.8f * childCosts(node).product//todo: scale properly

// Global conditions
case CollectionExists(coll) => 1 * childCosts(node).product
case TargetHasProperty(target, property) => 0.8f * childCosts(node).product

// Arithmetic operations
case Constant(num) => 1 * childCosts(node).product
case Plus(num) => 1 * childCosts(node).product
case Minus(num) => -1 * childCosts(node).product
case Multiply(num) => 2 * childCosts(node).product
case Divide(num, RoundedDown) => 1 * childCosts(node).product
case Divide(num, RoundedUp) => 1 * childCosts(node).product

// Comparisons
case EqualTo(num) => 0.5f * childCosts(node).product
case GreaterThan(num) => 0.9f * childCosts(node).product
case GreaterThanOrEqualTo(num) => 0.9f * childCosts(node).product
case LessThan(num) => 0.9f * childCosts(node).product
case LessThanOrEqualTo(num) => 0.9f * childCosts(node).product


// Numbers
case Scalar(int) => scala.math.pow(childCosts(node).product,1.5f).toFloat //2.0 is too steep.
case AttributeSum(collection, attr) => 1 * childCosts(node).sum
case AttributeValue(obj, attr) => 1 * childCosts(node).sum
case Count(collection) => 1 * childCosts(node).product
case EnergyAmount(player) => 1 * childCosts(node).product

// Collections
case AllTiles => 1
case CardsInHand(player, cardType, conditions) => 1 * childCosts(node).product
case ObjectsMatchingConditions(objType, conditions) => 1 * childCosts(node).product
case Other(collection) => 1 * childCosts(node).product
case TilesMatchingConditions(conditions) => 1 * childCosts(node).product

// Labels
case m: MultiLabel => {println("multilabel. what is it?");1}
case l: Label => 1

case _ => 1 * childCosts(node).product
}
}


//try to calculate a cost for anything and everything
private def genericEstimate(a:Any): Float ={
val v = genericEstimateZ(a);println(a.toString + " has estimate " + v);v
}

private def genericEstimateZ(a:Any): Float = a match{
case n:AstNode => astEst(n)//AST node, complex.
case n:Int => n
case c:Seq[Any] => c.map(child=>genericEstimate(child)).product //multiply sequences?
// scalastyle:off regex
case _=> println("error unknown type in generic estimate."); 0
// scalastyle:on regex
}

//for each child of the node, run genericEstimate()
private def childCosts(node: AstNode) :Iterator[Float]= {
//if(only one child && that child is a seq) return seq?
println("object " + node.toString + "has " + node.productArity + " children.")
node.productIterator.map[Float](child => genericEstimate(child))
}
}

11 changes: 6 additions & 5 deletions src/main/scala/wordbots/Server.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ object Server extends ServerApp {

sealed trait Response
case class ErrorResponse(error: String) extends Response
case class SuccessfulParseResponse(js: String, tokens: Seq[String], version: String = Parser.VERSION) extends Response
case class SuccessfulParseResponse(js: String, tokens: Seq[String], estCost: String, version: String = Parser.VERSION) extends Response
case class FailedParseResponse(error: String, suggestions: Seq[String], unrecognizedTokens: Seq[String]) extends Response
object FailedParseResponse {
def apply(error: ParserError = ParserError("Parse failed"), unrecognizedTokens: Seq[String] = Seq()): FailedParseResponse = {
Expand Down Expand Up @@ -59,6 +59,7 @@ object Server extends ServerApp {
)
})
}

lazy val lexiconTerms: List[String] = lexicon.keys.toList.sorted

val service: HttpService = {
Expand All @@ -75,7 +76,7 @@ object Server extends ServerApp {
format match {
case Some("js") =>
CodeGenerator.generateJS(ast) match {
case Success(js: String) => successResponse(js, parsedTokens)
case Success(js: String) => successResponse(js, parsedTokens, CostEstimator.estimateCost(ast, mode))
case Failure(ex: Throwable) => errorResponse(ParserError(s"Invalid JavaScript produced: ${ex.getMessage}. Contact the developers."))
}
case Some("svg") => Ok(parse.toSvg, headers(Some("image/svg+xml")))
Expand All @@ -92,7 +93,7 @@ object Server extends ServerApp {
val parseResponse: Response = parseMemoized((req.input, Option(req.mode))) match {
case SuccessfulParse(_, ast, parsedTokens) =>
CodeGenerator.generateJS(ast) match {
case Success(js: String) => SuccessfulParseResponse(js, parsedTokens)
case Success(js: String) => SuccessfulParseResponse(js, parsedTokens, CostEstimator.estimateCost(ast,Option(req.mode)))
case Failure(ex: Throwable) => FailedParseResponse(ParserError(s"Invalid JavaScript produced: ${ex.getMessage}. Contact the developers."))
}
case FailedParse(error, unrecognizedTokens) => FailedParseResponse(error, unrecognizedTokens)
Expand Down Expand Up @@ -144,8 +145,8 @@ object Server extends ServerApp {
}
}

def successResponse(js: String, parsedTokens: Seq[String] = Seq()): Task[H4sResponse] = {
Ok(SuccessfulParseResponse(js, parsedTokens).asJson, headers())
def successResponse(js: String, parsedTokens: Seq[String] = Seq(), estCost: String): Task[H4sResponse] = {
Ok(SuccessfulParseResponse(js, parsedTokens, estCost).asJson, headers())
}

def errorResponse(error: ParserError = ParserError("Parse failed"), unrecognizedTokens: Seq[String] = Seq()): Task[H4sResponse] = {
Expand Down