Skip to content

Commit

Permalink
Reduce boxing by using Idris function types
Browse files Browse the repository at this point in the history
  • Loading branch information
mmhelloworld committed Nov 1, 2024
1 parent 6f2c732 commit 7c70a26
Show file tree
Hide file tree
Showing 13 changed files with 652 additions and 620 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;

import java.io.BufferedInputStream;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
Expand All @@ -21,10 +20,6 @@
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.jar.Attributes;
import java.util.jar.JarEntry;
import java.util.jar.JarOutputStream;
import java.util.jar.Manifest;
import java.util.stream.IntStream;
import java.util.stream.Stream;

Expand All @@ -33,12 +28,8 @@
import static java.lang.String.format;
import static java.lang.System.lineSeparator;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.nio.file.Files.newInputStream;
import static java.nio.file.Files.newOutputStream;
import static java.nio.file.Files.setPosixFilePermissions;
import static java.util.Objects.requireNonNull;
import static java.util.jar.Attributes.Name.MAIN_CLASS;
import static java.util.jar.Attributes.Name.MANIFEST_VERSION;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;
import static org.objectweb.asm.ClassWriter.COMPUTE_FRAMES;
Expand Down Expand Up @@ -206,17 +197,6 @@ public Assembler() {
this.env = new HashMap<>();
}

public static void createJar(String directory, String fileName, String mainClass) throws IOException {
String jarFileName = fileName + ".jar";
File jarFile = new File(directory, jarFileName);
jarFile.delete();
try (JarOutputStream target =
new JarOutputStream(newOutputStream(jarFile.toPath()), createManifest(mainClass))) {
File sourceDirectory = new File(directory);
add(sourceDirectory, target, jarFile, sourceDirectory);
}
}

public static void createExecutable(String directoryName, String fileName, String mainClass) throws IOException {
String javaOptsProp = System.getProperty("JAVA_OPTS", System.getenv("JAVA_OPTS"));
String javaOpts = javaOptsProp == null ? "-Xss8m -Xms2g -Xmx3g" : javaOptsProp;
Expand Down Expand Up @@ -257,74 +237,6 @@ private static byte[] createExecutableFileContent(String... lines) {
return String.join(lineSeparator(), lines).getBytes(UTF_8);
}

private static Manifest createManifest(String mainClass) {
Manifest manifest = new Manifest();
Attributes manifestAttributes = manifest.getMainAttributes();
manifestAttributes.put(MANIFEST_VERSION, "1.0");
manifestAttributes.put(MAIN_CLASS, mainClass);
return manifest;
}

private static void add(File source, JarOutputStream target, File jarFile, File rootDirectory) throws IOException {
if (source.isDirectory()) {
addDirectory(source, target, jarFile, rootDirectory);
} else {
addFile(source, target, jarFile, rootDirectory);
}
if (source.isDirectory() || !source.getName().endsWith(".jar")) {
source.delete();
}
}

private static void addFile(File source, JarOutputStream jarOutputStream, File jarFile, File rootDirectory)
throws IOException {
if (source.equals(jarFile)) {
return;
}
JarEntry entry = new JarEntry(getJarEntryName(source, rootDirectory));
entry.setTime(source.lastModified());
jarOutputStream.putNextEntry(entry);
try (BufferedInputStream in = new BufferedInputStream(newInputStream(source.toPath()))) {
byte[] buffer = new byte[BUFFER_SIZE];
while (true) {
int count = in.read(buffer);
if (count == -1) {
break;
}
jarOutputStream.write(buffer, 0, count);
}
jarOutputStream.closeEntry();
}
}

private static void addDirectory(File source, JarOutputStream jarOutputStream, File jarFile, File rootDirectory)
throws IOException {
String name = getJarEntryName(source, rootDirectory);
if (!name.isEmpty()) {
createDirectory(name, source.lastModified(), jarOutputStream);
}
File[] files = requireNonNull(source.listFiles(), "Unable to get files from directory " + source);
for (File file : files) {
add(file, jarOutputStream, jarFile, rootDirectory);
}
}

private static String getJarEntryName(File source, File rootDirectory) {
String name = source.getPath().replace(rootDirectory.getPath(), "").replace("\\", "/");
return !name.startsWith("/") ? name : name.substring(1);
}

private static void createDirectory(String name, long lastModified, JarOutputStream jarOutputStream)
throws IOException {
if (!name.endsWith("/")) {
name += "/";
}
JarEntry entry = new JarEntry(name);
entry.setTime(lastModified);
jarOutputStream.putNextEntry(entry);
jarOutputStream.closeEntry();
}

private static Type getType(String typeDescriptor) {
switch (typeDescriptor) {
case "boolean":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ public static int toInt(Object that) {
return 0;
} else if (that instanceof Integer) {
return (int) that;
} else if (that instanceof Thunk) {
return ((Thunk) that).getInt();
} else if (that instanceof BigInteger) {
return ((BigInteger) that).intValue();
} else if (that instanceof Long) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.github.mmhelloworld.idrisjvm.runtime;

import java.io.PrintWriter;
import java.io.StringWriter;
import java.lang.management.ManagementFactory;
import java.nio.channels.Channels;
import java.util.List;
Expand Down Expand Up @@ -118,6 +120,18 @@ public static String getErrorMessage(int errorNumber) {
return "Error code: " + errorNumber;
}

public static String getStackTraceString() {
StackTraceElement[] trace = new Throwable().getStackTrace();
StringWriter stringWriter = new StringWriter();
PrintWriter printWriter = new PrintWriter(stringWriter, true);
for (int index = 1; index < trace.length; index++) {
StackTraceElement traceElement = trace[index];
printWriter.println("\tat " + traceElement);
}
printWriter.flush();
return stringWriter.toString();
}

public static void free(Object object) {
}

Expand Down
2 changes: 1 addition & 1 deletion libs/base/System.idr
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ namespace Escaped

%foreign supportC "idris2_time"
"javascript:lambda:() => Math.floor(new Date().getTime() / 1000)"
"jvm:time(java/lang/Object int),io/github/mmhelloworld/idrisjvm/runtime/IdrisSystem"
"jvm:time(int),io/github/mmhelloworld/idrisjvm/runtime/IdrisSystem"
prim__time : PrimIO Int

||| Return the number of seconds since epoch.
Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
</developers>

<properties>
<asm.version>9.4</asm.version>
<asm.version>9.7.1</asm.version>
<assertj.version>3.16.1</assertj.version>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<project.scm.id>github</project.scm.id>
Expand Down
115 changes: 50 additions & 65 deletions src/Compiler/Jvm/Asm.idr
Original file line number Diff line number Diff line change
Expand Up @@ -451,31 +451,6 @@ namespace AsmGlobalState
addFunction : HasIO io => AsmGlobalState -> Jname -> Function -> io ()
addFunction globalState name function = jaddFunction globalState (getSimpleName name) function

export
%foreign jvm' "io/github/mmhelloworld/idrisjvm/assembler/AsmGlobalState" ".isUntypedFunction"
"io/github/mmhelloworld/idrisjvm/assembler/AsmGlobalState String" "boolean"
prim_jisUntypedFunction : AsmGlobalState -> String -> PrimIO Bool

jisUntypedFunction : HasIO io => AsmGlobalState -> String -> io Bool
jisUntypedFunction state name = primIO $ prim_jisUntypedFunction state name

public export
isUntypedFunction : HasIO io => AsmGlobalState -> Jname -> io Bool
isUntypedFunction globalState name = jisUntypedFunction globalState (getSimpleName name)

public export
%foreign jvm' "io/github/mmhelloworld/idrisjvm/assembler/AsmGlobalState" ".addUntypedFunction"
"io/github/mmhelloworld/idrisjvm/assembler/AsmGlobalState String" "void"
prim_jaddUntypedFunction : AsmGlobalState -> String -> PrimIO ()

public export
jaddUntypedFunction : HasIO io => AsmGlobalState -> String -> io ()
jaddUntypedFunction state name = primIO $ prim_jaddUntypedFunction state name

public export
addUntypedFunction : HasIO io => AsmGlobalState -> Jname -> io ()
addUntypedFunction globalState name = jaddUntypedFunction globalState (getSimpleName name)

public export
%foreign jvm' "io/github/mmhelloworld/idrisjvm/assembler/AsmGlobalState" ".classCodeEnd"
"io/github/mmhelloworld/idrisjvm/assembler/AsmGlobalState String String String" "void"
Expand Down Expand Up @@ -638,6 +613,7 @@ Show Scope where
("index", show $ index scope),
("parentIndex", show $ parentIndex scope),
("nextVariableIndex", show $ nextVariableIndex scope),
("variableTypes", show $ unsafePerformIO $ Map.toList $ variableTypes scope),
("lineNumbers", show $ lineNumbers scope),
("variableIndices", toString $ variableIndices scope),
("returnType", show $ returnType scope),
Expand Down Expand Up @@ -680,10 +656,17 @@ public export
%foreign "jvm:crash(String java/lang/Object),io/github/mmhelloworld/idrisjvm/runtime/Runtime"
crash : String -> Object

public export
%foreign "jvm:getStackTraceString(String),io/github/mmhelloworld/idrisjvm/runtime/Runtime"
getStackTraceString : PrimIO String

export
asmCrash : String -> Core a
asmCrash message = throw (InternalError message)
asmCrash message = do
stackTrace <- coreLift $ primIO getStackTraceString
throw (InternalError $ message ++ "\n" ++ stackTrace)

export
isBoolTySpec : Name -> Bool
isBoolTySpec name = name == basics "Bool" || name == (NS preludeNS (UN $ Basic "Bool"))

Expand Down Expand Up @@ -730,14 +713,6 @@ export
arrayName : Name
arrayName = NS (mkNamespace "Java.Lang") (UN $ Basic "Array")

getIdrisConstructorType : ConInfo -> (tag: Maybe Int) -> Nat -> Name -> InferredType
getIdrisConstructorType conInfo tag arity name =
if isBoolTySpec name then IBool
else if name == basics "List" then idrisListType
else if name == preludetypes "Maybe" then idrisMaybeType
else if name == preludetypes "Nat" then inferredBigIntegerType
else inferredObjectType

parseName : String -> Maybe InferredType
parseName name =
case words name of
Expand Down Expand Up @@ -2320,14 +2295,6 @@ export
getFcAndDefinition : {auto stateRef: Ref AsmState AsmState} -> String -> Core (FC, NamedDef)
getFcAndDefinition name = coreLift $ AsmGlobalState.getFcAndDefinition !getGlobalState name

export
isUntypedFunction : {auto stateRef: Ref AsmState AsmState} -> Jname -> Core Bool
isUntypedFunction name = coreLift $ AsmGlobalState.isUntypedFunction !getGlobalState name

export
addUntypedFunction : {auto stateRef: Ref AsmState AsmState} -> Jname -> Core ()
addUntypedFunction name = coreLift $ AsmGlobalState.addUntypedFunction !getGlobalState name

export
setCurrentFunction : {auto stateRef: Ref AsmState AsmState} -> Function -> Core ()
setCurrentFunction function = updateState $ { currentIdrisFunction := function }
Expand Down Expand Up @@ -2546,7 +2513,19 @@ retrieveVariableIndicesByName scopeIndex = do
go1 scopeIndex = do
scope <- getScope scopeIndex
coreLift $ updateVariableIndices acc (variableIndices scope)
maybe (pure ()) go1 (parentIndex scope)
let Just nextScopeIndex = parentIndex scope
| Nothing => pure ()
go1 nextScopeIndex

isParameter : {auto stateRef: Ref AsmState AsmState} -> String -> Core Bool
isParameter name = do
scope <- getScope 0
optIndex <- coreLift $ Map.get {value=Int} (variableIndices scope) name
case nullableToMaybe optIndex of
Nothing => pure False
Just index => do
function <- getCurrentFunction
pure (index < cast (length (parameterTypes (inferredFunctionType function))))

export
retrieveVariables : {auto stateRef: Ref AsmState AsmState} -> Int -> Core (List String)
Expand Down Expand Up @@ -2649,21 +2628,14 @@ export
getVariableType : {auto stateRef: Ref AsmState AsmState} -> String -> Core InferredType
getVariableType name = getVariableTypeAtScope !getCurrentScopeIndex name

updateArgumentsForUntyped : Map Int InferredType -> Nat -> IO ()
updateArgumentsForUntyped _ Z = pure ()
updateArgumentsForUntyped types (S n) = do
ignore $ Map.put types (cast {to=Int} n) inferredObjectType
updateArgumentsForUntyped types n

export
updateScopeVariableTypes : {auto stateRef: Ref AsmState AsmState} -> Nat -> Core ()
updateScopeVariableTypes arity = go (scopeCounter !getState - 1) where
updateScopeVariableTypes : {auto stateRef: Ref AsmState AsmState} -> Core ()
updateScopeVariableTypes = go (scopeCounter !getState - 1) where
go : Int -> Core ()
go scopeIndex =
if scopeIndex < 0 then pure ()
else do
variableTypes <- retrieveVariableTypesAtScope scopeIndex
when (scopeIndex == 0) $ coreLift $ updateArgumentsForUntyped variableTypes arity
variableIndices <- retrieveVariableIndicesByName scopeIndex
scope <- getScope scopeIndex
saveScope $ {allVariableTypes := variableTypes, allVariableIndices := variableIndices} scope
Expand All @@ -2679,18 +2651,23 @@ getVariableScope name = go !getCurrentScopeIndex where
Just _ => pure scope
Nothing => case parentIndex scope of
Just parentScopeIndex => go parentScopeIndex
Nothing => asmCrash ("Unknown variable " ++ name)
Nothing => do
let functionName = idrisName !getCurrentFunction
asmCrash ("Unknown variable \{name} in function \{show functionName}")

export
addVariableType : {auto stateRef: Ref AsmState AsmState} -> String -> InferredType -> Core InferredType
addVariableType var IUnknown = pure IUnknown
addVariableType var ty = do
addVariableType : {auto stateRef: Ref AsmState AsmState} -> String -> InferredType -> Core ()
addVariableType _ IUnknown = pure ()
addVariableType var ty = when (not !(isParameter var)) $ do
scope <- getVariableScope var
let scopeIndex = index scope
existingTy <- retrieveVariableTypeAtScope scopeIndex var
let newTy = existingTy <+> ty
_ <- coreLift $ Map.put (variableTypes scope) var newTy
pure newTy
ignore $ coreLift $ Map.put (variableTypes scope) var ty

export
retrieveVariableType : {auto stateRef: Ref AsmState AsmState} -> String -> Core InferredType
retrieveVariableType var = do
scope <- getVariableScope var
let scopeIndex = index scope
retrieveVariableTypeAtScope scopeIndex var

%inline
export
Expand Down Expand Up @@ -2733,9 +2710,18 @@ mutual
then pure $ parseName namePartsStr
else pure Nothing
parseJvmReferenceType (NmCon _ name conInfo tag args) =
if name == primio "IORes" then
maybe (asmCrash "Expected an argument for IORes") (\res => pure $ Just !(tySpec res)) (head' args)
else pure $ Just $ getIdrisConstructorType conInfo tag (length args) name
if name == primio "IORes" then
maybe (asmCrash "Expected an argument for IORes") (\res => pure $ Just !(tySpec res)) (head' args)
else pure $ Just $ getIdrisConstructorType name
where
getIdrisConstructorType : Name -> InferredType
getIdrisConstructorType name =
if isBoolTySpec name then IBool
else if name == basics "List" then idrisListType
else if name == preludetypes "Maybe" then idrisMaybeType
else if name == preludetypes "Nat" then inferredBigIntegerType
else inferredObjectType

parseJvmReferenceType (NmApp fc (NmRef _ name) _) = do
(_, MkNmFun _ def) <- getFcAndDefinition (jvmSimpleName name)
| _ => asmCrash ("Expected a function returning a tuple containing interface type and method type at " ++
Expand Down Expand Up @@ -2768,7 +2754,6 @@ mutual
ty <- tryParse expr
pure $ fromMaybe inferredObjectType ty


export
asmReturn : {auto stateRef: Ref AsmState AsmState} -> InferredType -> Core ()
asmReturn IVoid = return
Expand Down
Loading

0 comments on commit 7c70a26

Please sign in to comment.