diff --git a/src/main/java/net/objecthunter/exp4j/Expression.java b/src/main/java/net/objecthunter/exp4j/Expression.java index 4777a485..59e952ab 100644 --- a/src/main/java/net/objecthunter/exp4j/Expression.java +++ b/src/main/java/net/objecthunter/exp4j/Expression.java @@ -91,8 +91,8 @@ public ValidationResult validate(boolean checkVariablesSet) { break; case Token.TOKEN_FUNCTION: final Function func = ((FunctionToken) tok).getFunction(); - final int argsNum = func.getNumArguments(); - if (argsNum > count) { + final int argsNum = ((FunctionToken) tok).getArgumentCount(); + if (func.getMinNumArguments() > argsNum || func.getMaxNumArguments() < argsNum) { errors.add("Not enough arguments for '" + func.getName() + "'"); } if (argsNum > 1) { @@ -161,12 +161,15 @@ public double evaluate() { } } else if (t.getType() == Token.TOKEN_FUNCTION) { FunctionToken func = (FunctionToken) t; - if (output.size() < func.getFunction().getNumArguments()) { + int functionArgs = func.getArgumentCount(); + if (functionArgs < func.getFunction().getMinNumArguments() || functionArgs > func.getFunction().getMaxNumArguments() || output.isEmpty()) { throw new IllegalArgumentException("Invalid number of arguments available for '" + func.getFunction().getName() + "' function"); } /* collect the arguments from the stack */ - double[] args = new double[func.getFunction().getNumArguments()]; - for (int j = 0; j < func.getFunction().getNumArguments(); j++) { + double[] args = new double[functionArgs]; + + + for (int j = 0; j < functionArgs ; j++) { args[j] = output.pop(); } output.push(func.getFunction().apply(this.reverseInPlace(args))); diff --git a/src/main/java/net/objecthunter/exp4j/function/Function.java b/src/main/java/net/objecthunter/exp4j/function/Function.java index 79ebfded..7deefb41 100644 --- a/src/main/java/net/objecthunter/exp4j/function/Function.java +++ b/src/main/java/net/objecthunter/exp4j/function/Function.java @@ -23,24 +23,26 @@ public abstract class Function { protected final String name; - protected final int numArguments; + protected final int minArguments; + protected final int maxArguments; /** * Create a new Function with a given name and number of arguments * * @param name the name of the Function - * @param numArguments the number of arguments the function takes + * @param minArguments the number of arguments the function takes */ - public Function(String name, int numArguments) { - if (numArguments < 0) { - throw new IllegalArgumentException("The number of function arguments can not be less than 0 for '" + + public Function(String name, int minArguments, int maxArguments) { + if (minArguments < 0 || minArguments > maxArguments || maxArguments > Integer.MAX_VALUE) { + throw new IllegalArgumentException("The number of function arguments can not be less than 0 or more than " +Integer.MAX_VALUE+" for '" + name + "'"); } if (!isValidFunctionName(name)) { throw new IllegalArgumentException("The function name '" + name + "' is invalid"); } this.name = name; - this.numArguments = numArguments; + this.minArguments = minArguments; + this.maxArguments = maxArguments; } @@ -50,7 +52,26 @@ public Function(String name, int numArguments) { * @param name the name of the Function */ public Function(String name) { - this(name, 1); + this(name, 1,1); + } + + + public Function(String name, int numArguments) { + this(name, numArguments,numArguments); + } + + /** + * Get the number of arguments of a function with fixed arguments length. + * This function may be called only on functions with a fixed number of arguments and will throw an @UnsupportedOperationException otherwise. + * When using functions with variable arguments length use @getMaxNumArguments and @getMinNumArguments instead. + * + * @return the number of arguments + */ + public int getNumArguments() { + if (minArguments != maxArguments) { + throw new UnsupportedOperationException("Calling getNumArgument() is not supported for var arg functions, please use getMaxNumArguments() or getMinNumArguments()"); + } + return minArguments; } /** @@ -67,10 +88,16 @@ public String getName() { * * @return the number of arguments */ - public int getNumArguments() { - return numArguments; + public int getMinNumArguments() { + return minArguments; } + public int getMaxNumArguments() { + return maxArguments; + } + + + /** * Method that does the actual calculation of the function value given the arguments * diff --git a/src/main/java/net/objecthunter/exp4j/shuntingyard/ShuntingYard.java b/src/main/java/net/objecthunter/exp4j/shuntingyard/ShuntingYard.java index b9754819..5fbd11bb 100644 --- a/src/main/java/net/objecthunter/exp4j/shuntingyard/ShuntingYard.java +++ b/src/main/java/net/objecthunter/exp4j/shuntingyard/ShuntingYard.java @@ -19,6 +19,7 @@ import net.objecthunter.exp4j.function.Function; import net.objecthunter.exp4j.operator.Operator; +import net.objecthunter.exp4j.tokenizer.FunctionToken; import net.objecthunter.exp4j.tokenizer.OperatorToken; import net.objecthunter.exp4j.tokenizer.Token; import net.objecthunter.exp4j.tokenizer.Tokenizer; @@ -40,7 +41,7 @@ public static Token[] convertToRPN(final String expression, final Map userOperators, final Set variableNames){ final Stack stack = new Stack(); final List output = new ArrayList(); - + final Stack functionTokenStack = new Stack(); final Tokenizer tokenizer = new Tokenizer(expression, userFunctions, userOperators, variableNames); while (tokenizer.hasNext()) { Token token = tokenizer.nextToken(); @@ -50,9 +51,12 @@ public static Token[] convertToRPN(final String expression, final Map max) { max = values[i]; } @@ -488,7 +488,7 @@ public void testFunction13() throws Exception { @Override public double apply(double... values) { double max = values[0]; - for (int i = 1; i < numArguments; i++) { + for (int i = 1; i < values.length; i++) { if (values[i] > max) { max = values[i]; } @@ -633,7 +633,7 @@ public void testFunction20() throws Exception { @Override public double apply(double... values) { double max = values[0]; - for (int i = 1; i < numArguments; i++) { + for (int i = 1; i < values.length; i++) { if (values[i] > max) { max = values[i]; } @@ -644,7 +644,7 @@ public double apply(double... values) { ExpressionBuilder b = new ExpressionBuilder("max(1,2,3)") .function(maxFunction); double calculated = b.build().evaluate(); - assertTrue(maxFunction.getNumArguments() == 3); + assertEquals(maxFunction.getMaxNumArguments(),maxFunction.getMinNumArguments(), 3); assertTrue(calculated == 3); } diff --git a/src/test/java/net/objecthunter/exp4j/FunctionsWithVariableArgsTest.java b/src/test/java/net/objecthunter/exp4j/FunctionsWithVariableArgsTest.java new file mode 100644 index 00000000..bbfa0489 --- /dev/null +++ b/src/test/java/net/objecthunter/exp4j/FunctionsWithVariableArgsTest.java @@ -0,0 +1,196 @@ +package net.objecthunter.exp4j; + +import net.objecthunter.exp4j.function.Function; +import org.junit.Before; +import org.junit.Test; + +import java.util.LinkedList; +import java.util.List; + +import static org.junit.Assert.assertEquals; + + +public class FunctionsWithVariableArgsTest { + List variableArgsFunctions = new LinkedList(); + @Before + public void fillWithSomeFunctions(){ + Function avg = new Function("avg", 1,100) { + + @Override + public double apply(double... args) { + double sum = 0; + for (double arg : args) { + sum += arg; + } + return sum / args.length; + } + }; + + Function sum = new Function("sum", 1,100) { + + @Override + public double apply(double... args) { + double sum = 0; + for (double arg : args) { + sum += arg; + } + return sum; + } + }; + + Function max = new Function("max", 1,100) { + + @Override + public double apply(double... args) { + double max = args[0]; + for (double arg : args) { + if (arg > max) + max = arg; + } + return max; + } + }; + + Function min = new Function("min", 1,100) { + + @Override + public double apply(double... args) { + double min = args[0]; + for (double arg : args) { + if (arg < min) + min = arg; + } + return min; + } + }; + + Function atLeast4atMax8 = new Function("atLeast4atMax8" ,4,8) { + @Override + public double apply(double... args) { + return 0; + } + }; + + variableArgsFunctions.add(min); + variableArgsFunctions.add(max); + variableArgsFunctions.add(avg); + variableArgsFunctions.add(sum); + variableArgsFunctions.add(atLeast4atMax8); + } + + @Test + public void testFunctionsWithVariableArgs1() throws Exception { + double result,expected; + result = new ExpressionBuilder("avg(1,2,3,4)") + .functions(variableArgsFunctions) + .build() + .evaluate(); + + expected = 2.5d; + assertEquals(expected, result, 0d); + + result = new ExpressionBuilder("avg(1,1,1,1,1,1,1,1,1)") + .functions(variableArgsFunctions) + .build() + .evaluate(); + + expected = 1.0d; + assertEquals(expected, result, 0d); + + } + + @Test + public void testFunctionsWithVariableArgs2() throws Exception { + double result,expected; + result = new ExpressionBuilder("min(1,1,1,1,1,1,1,1)") + .functions(variableArgsFunctions) + .build() + .evaluate(); + + expected = 1.0d; + assertEquals(expected, result, 0d); + + result = new ExpressionBuilder("min(1,1,1,1,1,1,1)") + .functions(variableArgsFunctions) + .build() + .evaluate(); + + expected = 1.0d; + assertEquals(expected, result, 0d); + + result = new ExpressionBuilder("min(1,1,1,1)") + .functions(variableArgsFunctions) + .build() + .evaluate(); + + expected = 1.0d; + assertEquals(expected, result, 0d); + + result = new ExpressionBuilder("max(1,1,1,1,1,1,1,1)") + .functions(variableArgsFunctions) + .build() + .evaluate(); + + expected = 1.0d; + assertEquals(expected, result, 0d); + + result = new ExpressionBuilder("max(1,1,1,1,1,1)") + .functions(variableArgsFunctions) + .build() + .evaluate(); + + expected = 1.0d; + assertEquals(expected, result, 0d); + + } + + @Test + public void testFunctionsWithVariableArgs3(){ + Expression e = new ExpressionBuilder("sum(1,avg(1.11,-3.14,2.03))") + .functions(variableArgsFunctions) + .build(); + assertEquals(1.0d, e.evaluate(), 1.0d); + } + + @Test(expected = IllegalArgumentException.class) + public void testFunctionsWithVariableArgs4(){ + new ExpressionBuilder("5 + atLeast4atMax8(1)") + .functions(variableArgsFunctions) + .build() + .evaluate(); + } + + @Test(expected = IllegalArgumentException.class) + public void testFunctionsWithVariableArgs5(){ + new ExpressionBuilder("5 + atLeast4atMax8(2,3)") + .functions(variableArgsFunctions) + .build() + .evaluate(); + } + + @Test(expected = IllegalArgumentException.class) + public void testFunctionsWithVariableArgs6(){ + new ExpressionBuilder("5 + atLeast4atMax8(2,3,3)") + .functions(variableArgsFunctions) + .build() + .evaluate(); + } + + @Test + public void testFunctionsWithVariableArgs7(){ + double result = new ExpressionBuilder("5 + atLeast4atMax8(2,3,3,2,4,max(1,2,3))") + .functions(variableArgsFunctions) + .build() + .evaluate(); + assertEquals(5, result, 0.0); + } + + @Test(expected = IllegalArgumentException.class) + public void testFunctionsWithVariableArgs8(){ + new ExpressionBuilder("5 + atLeast4atMax8(1,2,3,4,5,6,7,8,9,10)") + .functions(variableArgsFunctions) + .build() + .evaluate(); + } +} + diff --git a/src/test/java/net/objecthunter/exp4j/TestUtil.java b/src/test/java/net/objecthunter/exp4j/TestUtil.java index dc36a1fe..b5645786 100644 --- a/src/test/java/net/objecthunter/exp4j/TestUtil.java +++ b/src/test/java/net/objecthunter/exp4j/TestUtil.java @@ -40,7 +40,7 @@ public static void assertCloseParenthesesToken(Token token) { public static void assertFunctionToken(Token token, String name, int i) { assertEquals(token.getType(), Token.TOKEN_FUNCTION); FunctionToken f = (FunctionToken) token; - assertEquals(i, f.getFunction().getNumArguments()); + assertEquals(i, f.getFunction().getMaxNumArguments(), f.getFunction().getMinNumArguments()); assertEquals(name, f.getFunction().getName()); }