Skip to content

Commit

Permalink
Enhance trampoline, fix function signature, optimize lambda
Browse files Browse the repository at this point in the history
  • Loading branch information
mmhelloworld committed Dec 27, 2021
1 parent 41951e1 commit f133107
Show file tree
Hide file tree
Showing 18 changed files with 381 additions and 150 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,6 @@ jobs:
with:
repo_token: "${{ secrets.GITHUB_TOKEN }}"
prerelease: false
title: "Release 0.4.0-rc.2"
title: "Release 0.4.0-rc.3"
files: |
idris-jvm-compiler/target/idris2-0.4.0-SNAPSHOT.zip
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@
import java.io.IOException;
import java.io.OutputStream;
import java.lang.reflect.InvocationTargetException;
import java.util.Collection;
import java.util.HashSet;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Predicate;
import java.util.regex.Pattern;
import java.util.stream.Stream;

import static java.util.Collections.emptyList;
import static java.util.Collections.synchronizedSet;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;

public final class AsmGlobalState {
Expand All @@ -23,16 +28,25 @@ public final class AsmGlobalState {
private final Set<String> untypedFunctions;
private final Set<String> constructors;
private final String programName;
private final Collection<Predicate<String>> trampolinePredicates;
private final Map<String, Assembler> assemblers;

public AsmGlobalState(String programName) {
public AsmGlobalState(String programName, Collection<String> trampolinePatterns) {
this.programName = programName;
this.trampolinePredicates = trampolinePatterns.stream()
.map(Pattern::compile)
.map(Pattern::asPredicate)
.collect(toList());
functions = new ConcurrentHashMap<>();
untypedFunctions = synchronizedSet(new HashSet<>());
constructors = synchronizedSet(new HashSet<>());
assemblers = new ConcurrentHashMap<>();
}

public AsmGlobalState(String programName) {
this(programName, emptyList());
}

public synchronized void addFunction(String name, Object value) {
functions.put(name, value);
}
Expand Down Expand Up @@ -105,4 +119,8 @@ public void writeClass(String className, ClassWriter classWriter, String outputC
}
}

public boolean shouldTrampoline(String name) {
return trampolinePredicates.stream()
.anyMatch(trampolinePredicate -> trampolinePredicate.test(name));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1132,11 +1132,13 @@ public void lineNumber(int lineNumber, String label) {
public void localVariable(String name, String typeDescriptor, String signature, String lineNumberStartLabel,
String lineNumberEndLabel, int index) {
Label start = (Label) env.get(lineNumberStartLabel);
requireNonNull(start, format("Line number start label '%s' for variable %s at index %d must not be null",
lineNumberStartLabel, name, index));
requireNonNull(start,
format("Line number start label '%s' for variable %s at index %d must not be null for method %s/%s",
lineNumberStartLabel, name, index, className, methodName));
Label end = (Label) env.get(lineNumberEndLabel);
requireNonNull(end, format("Line number end label '%s' for variable %s at index %d must not be null",
lineNumberEndLabel, name, index));
requireNonNull(end,
format("Line number end label '%s' for variable %s at index %d must not be null for method %s/%s",
lineNumberEndLabel, name, index, className, methodName));
mv.visitLocalVariable(name, typeDescriptor, signature, start, end, index);
}

Expand Down
4 changes: 2 additions & 2 deletions idris-jvm-compiler/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@
<repositoryName>exec/idris2_app</repositoryName>
<repositoryLayout>flat</repositoryLayout>
<assembleDirectory>${project.parent.basedir}/build</assembleDirectory>
<extraJvmArguments>-Xss36m -Xms3g -Xmx3g</extraJvmArguments>
<extraJvmArguments>-Xss70m -Xms3g -Xmx3g</extraJvmArguments>
<programs>
<program>
<mainClass>idris2.Main</mainClass>
Expand All @@ -194,7 +194,7 @@
<repositoryName>lib</repositoryName>
<repositoryLayout>flat</repositoryLayout>
<assembleDirectory>${project.build.directory}/assembly</assembleDirectory>
<extraJvmArguments>-Xss36m -Xms3g -Xmx3g</extraJvmArguments>
<extraJvmArguments>-Xss70m -Xms3g -Xmx3g</extraJvmArguments>
<programs>
<program>
<mainClass>idris2.Main</mainClass>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package io.github.mmhelloworld.idrisjvm.runtime;

public interface Function3<T1, T2, T3, R> {
R apply(T1 t1, T2 t2, T3 t3);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package io.github.mmhelloworld.idrisjvm.runtime;

public interface Function4<T1, T2, T3, T4, R> {
R apply(T1 t1, T2 t2, T3 t3, T4 t4);
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package io.github.mmhelloworld.idrisjvm.runtime;

public interface Function5<T1, T2, T3, T4, T5, R> {
R apply(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5);
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package io.github.mmhelloworld.idrisjvm.runtime;

import java.util.function.BiFunction;
import java.util.function.Function;

public final class Functions {
private Functions() {
}

public static final Function<?, ?> IDENTITY = a -> a;

public static final Function<?, Function<?, ?>> IDENTITY_1 = c -> IDENTITY;

public static final Function<?, Function<?, Function<?, ?>>> IDENTITY_2 = d -> IDENTITY_1;

public static final Function<?, Function<?, ?>> CONSTANT = a -> b -> a;

public static final Function<?, Function<?, Function<?, ?>>> CONSTANT_1 = c -> CONSTANT;

public static <T1, T2, R> Function<T1, Function<T2, R>> curry(BiFunction<T1, T2, R> f) {
return t1 -> t2 -> f.apply(t1, t2);
}

public static <T1, T2, T3, R> Function<T1, Function<T2, Function<T3, R>>> curry(Function3<T1, T2, T3, R> f) {
return t1 -> t2 -> t3 -> f.apply(t1, t2, t3);
}

public static <T1, T2, T3, T4, R> Function<T1, Function<T2, Function<T3, Function<T4, R>>>> curry(
Function4<T1, T2, T3, T4, R> f) {
return t1 -> t2 -> t3 -> t4 -> f.apply(t1, t2, t3, t4);
}

public static <T1, T2, T3, T4, T5, R> Function<T1, Function<T2, Function<T3, Function<T4, Function<T5, R>>>>> curry(
Function5<T1, T2, T3, T4, T5, R> f) {
return t1 -> t2 -> t3 -> t4 -> t5 -> f.apply(t1, t2, t3, t4, t5);
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.github.mmhelloworld.idrisjvm.runtime;

import static io.github.mmhelloworld.idrisjvm.runtime.Runtime.unwrap;

public final class MemoizedDelayed implements Delayed {
private boolean initialized;
private Delayed delayed;
Expand All @@ -8,7 +10,7 @@ public MemoizedDelayed(Delayed delayed) {
this.delayed = () -> {
synchronized(this) {
if(!initialized) {
Object value = delayed.evaluate();
Object value = unwrap(delayed.evaluate());
this.delayed = () -> value;
initialized = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,11 @@ public static Thunk createThunk(Object value) {

public static Object unwrap(Object possibleThunk) {
if (possibleThunk instanceof Thunk) {
return ((Thunk) possibleThunk).getObject();
Thunk thunk = (Thunk) possibleThunk;
while (thunk != null && thunk.isRedex()) {
thunk = thunk.evaluate();
}
return thunk == null ? null : thunk.getObject();
} else {
return possibleThunk;
}
Expand Down Expand Up @@ -165,17 +169,17 @@ public static double unwrapDoubleThunk(Object possibleThunk) {
}

public static ForkJoinTask<?> fork(Function<Object, Object> action) {
return commonPool().submit((Runnable) () -> {
return commonPool().submit(() -> {
try {
action.apply(0);
unwrap(action.apply(0));
} catch (Exception e) {
e.printStackTrace();
}
});
}

public static ForkJoinTask<?> fork(Delayed action) {
return commonPool().submit(action::evaluate);
return commonPool().submit(() -> unwrap(action.evaluate()));
}

public static void await(ForkJoinTask<?> task) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.github.mmhelloworld.idrisjvm.runtime;

import java.util.NoSuchElementException;

import static java.util.Objects.requireNonNull;

@FunctionalInterface
Expand All @@ -11,11 +13,7 @@ default boolean isRedex() {
}

default Object getObject() {
Thunk thunk = this;
while (thunk != null && thunk.isRedex()) {
thunk = thunk.evaluate();
}
return thunk == null ? null : thunk.getObject();
throw new NoSuchElementException("Unevaluated thunk");
}

default int getInt() {
Expand Down
33 changes: 25 additions & 8 deletions src/Compiler/Jvm/Asm.idr
Original file line number Diff line number Diff line change
Expand Up @@ -361,12 +361,12 @@ namespace AsmGlobalState

public export
%foreign
"jvm:<init>(String io/github/mmhelloworld/idrisjvm/assembler/AsmGlobalState),io/github/mmhelloworld/idrisjvm/assembler/AsmGlobalState"
prim_newAsmGlobalState : String -> PrimIO AsmGlobalState
"jvm:<init>(String java/util/Collection io/github/mmhelloworld/idrisjvm/assembler/AsmGlobalState),io/github/mmhelloworld/idrisjvm/assembler/AsmGlobalState"
prim_newAsmGlobalState : String -> List String -> PrimIO AsmGlobalState

public export
newAsmGlobalState : HasIO io => String -> io AsmGlobalState
newAsmGlobalState programName = primIO $ prim_newAsmGlobalState programName
newAsmGlobalState : HasIO io => String -> List String -> io AsmGlobalState
newAsmGlobalState programName trampolinePatterns = primIO $ prim_newAsmGlobalState programName trampolinePatterns

public export
%foreign jvm' "io/github/mmhelloworld/idrisjvm/assembler/AsmGlobalState" ".getAssembler"
Expand Down Expand Up @@ -466,6 +466,15 @@ namespace AsmGlobalState
classCodeEnd state outputDirectory outputFile mainClass =
primIO $ prim_classCodeEnd state outputDirectory outputFile mainClass

public export
%foreign jvm' "io/github/mmhelloworld/idrisjvm/assembler/AsmGlobalState" ".shouldTrampoline"
"io/github/mmhelloworld/idrisjvm/assembler/AsmGlobalState String" "boolean"
prim_shouldTrampoline : AsmGlobalState -> String -> PrimIO Bool

public export
shouldTrampoline : HasIO io => AsmGlobalState -> String -> io Bool
shouldTrampoline state name = primIO $ prim_shouldTrampoline state name

public export
record AsmState where
constructor MkAsmState
Expand Down Expand Up @@ -1162,14 +1171,21 @@ export
getVariableType : String -> Asm InferredType
getVariableType name = getVariableTypeAtScope !getCurrentScopeIndex name

updateArgumentsForUntyped : Map Int InferredType -> Nat -> IO ()
updateArgumentsForUntyped _ Z = pure ()
updateArgumentsForUntyped types (S n) = do
ignore $ Map.put types (cast n) inferredObjectType
updateArgumentsForUntyped types n

export
updateScopeVariableTypes : Asm ()
updateScopeVariableTypes = go (scopeCounter !GetState - 1) where
updateScopeVariableTypes : Nat -> Asm ()
updateScopeVariableTypes arity = go (scopeCounter !GetState - 1) where
go : Int -> Asm ()
go scopeIndex =
if scopeIndex < 0 then Pure ()
else do
variableTypes <- retrieveVariableTypesAtScope scopeIndex
when (scopeIndex == 0) $ LiftIo $ updateArgumentsForUntyped variableTypes arity
variableIndices <- retrieveVariableIndicesByName scopeIndex
scope <- getScope scopeIndex
saveScope $ record {allVariableTypes = variableTypes, allVariableIndices = variableIndices} scope
Expand Down Expand Up @@ -1201,7 +1217,7 @@ addVariableType var ty = do
%inline
export
lambdaMaxCountPerMethod: Int
lambdaMaxCountPerMethod = 25
lambdaMaxCountPerMethod = 50

export
getLambdaImplementationMethodName : String -> Asm Jname
Expand Down Expand Up @@ -1634,7 +1650,8 @@ runAsm state (ClassCodeStart version access className sig parent intf anns) = as
the (JList String) $ believe_me intf, the (JList JAnnotation) $ believe_me janns]

runAsm state (CreateClass opts) =
assemble state $ jvmInstance () "io/github/mmhelloworld/idrisjvm/assembler/Assembler.createClass" [toJClassOpts opts]
assemble state $ jvmInstance () "io/github/mmhelloworld/idrisjvm/assembler/Assembler.createClass"
[assembler state, toJClassOpts opts]
runAsm state (CreateField accs sourceFileName className fieldName desc sig fieldInitialValue) = assemble state $ do
let jaccs = sum $ accessNum <$> accs
jvmInstance () "io/github/mmhelloworld/idrisjvm/assembler/Assembler.createField"
Expand Down
Loading

0 comments on commit f133107

Please sign in to comment.