Files.lines() is just about as fast, cleaner code.

Initial implementation, using BufferedReader, parallel processing, combining everything in a single go, sorting afterwards (unoptimized)

Files.lines() is just about as fast, cleaner code.

Split the file into parts and process in parallel, runs in +/- 15 sec on M2 Pro with Java 23 EA preview

Removing old version

Added some branchless parsing to the mix.

Improved the branchless code even further, and made it a little bit more readable.

Oops.

Moved to memory mapped files as well, thanks for the inspiration bjhara!

Not me, writing a custom HashMap....
royvanrijn committed Jan 2, 2024
1 parent 3e25828 commit ad3b443
Showing 2 changed files with 279 additions and 43 deletions.
8 changes: 7 additions & 1 deletion calculate_average_royvanrijn.sh
@@ -16,5 +16,11 @@
#


JAVA_OPTS=""
# Added for fun, doesn't seem to be making a difference...
if [ -f "target/calculate_average_royvanrijn.jsa" ]; then
JAVA_OPTS="-XX:SharedArchiveFile=target/calculate_average_royvanrijn.jsa -Xshare:on"
else
# First run, create the archive:
JAVA_OPTS="-XX:ArchiveClassesAtExit=target/calculate_average_royvanrijn.jsa"
fi
time java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_royvanrijn
314 changes: 272 additions & 42 deletions src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java
@@ -15,70 +15,300 @@
*/
package dev.morling.onebrc;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

/**
* Changelog:
*
* Initial submission: 62000 ms
* Chunked reader: 16000 ms
* Optimized parser: 13000 ms
* Branchless methods: 11000 ms
* Adding memory mapped files: 6500 ms (based on bjhara's submission)
* Skipping string creation: 4700 ms
* Custom hashmap... 4200 ms
*
* Best performing JVM on MacBook M2 Pro: 21.0.1-graal
* `sdk use java 21.0.1-graal`
*/
public class CalculateAverage_royvanrijn {

private static final String FILE = "./measurements.txt";

// mutable state now instead of records, ugh, less instantiation.
static final class Measurement {
int min, max, count;
long sum;

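        // Sentinels outside any plausible temperature in tenths, so the first update always wins.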
public Measurement() {
this.min = 10000;
this.max = -10000;
}

public Measurement updateWith(int measurement) {
min = min(min, measurement);
max = max(max, measurement);
sum += measurement;
count++;
return this;
}

public Measurement updateWith(Measurement measurement) {
min = min(min, measurement.min);
max = max(max, measurement.max);
sum += measurement.sum;
count += measurement.count;
return this;
}

public String toString() {
return round(min) + "/" + round(sum / count) + "/" + round(max);
return round(min) + "/" + round((1.0 * sum) / count) + "/" + round(max);
}

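        // Values are stored as tenths (e.g. -12 == -1.2), so the rounded value is divided by 10.0 to restore one decimal.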
private double round(double value) {
            return Math.round(value) / 10.0;
}
}

private static Map<String, Measurement> merge(Map<String, Measurement> map1,
Map<String, Measurement> map2) {
for (var entry : map2.entrySet()) {
map1.merge(entry.getKey(), entry.getValue(), (e1, e2) -> e1.updateWith(e2));
}
return map1;
}

System.out.println("Took: " + (System.currentTimeMillis() - before));
public static final void main(String[] args) throws Exception {

new CalculateAverage_royvanrijn().run();
}

private static BitTwiddledMap merge(BitTwiddledMap map1, BitTwiddledMap map2) {
for (var entry : map2.values) {
map1.getOrCreate(entry.key).updateWith(entry.measurement);
}
return map1;
}

public void run() throws Exception {

try (FileChannel fileChannel = (FileChannel) Files.newByteChannel(Path.of(FILE), EnumSet.of(StandardOpenOption.READ))) {

var customMap = splitFileChannel(fileChannel)
.parallel()
.map(this::processBuffer)
.collect(Collectors.reducing(CalculateAverage_royvanrijn::merge));

// Seems to perform better than actually using a TreeMap:
System.out.println("{" + customMap.orElseThrow().values
.stream()
.sorted(Comparator.comparing(e -> e.key))
.map(Object::toString)
.collect(Collectors.joining(", ")) + "}");
}
}

private BitTwiddledMap processBuffer(ByteBuffer bb) {

BitTwiddledMap measurements = new BitTwiddledMap();

final int limit = bb.limit();
final byte[] buffer = new byte[64];
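        // Single scratch buffer, reused for the station name bytes and the temperature bytes of every record.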

while (bb.position() < limit) {

// Find the correct positions in the bytebuffer:

// Start:
final int startPointer = bb.position();

// Separator:
int separatorPointer = startPointer + 3; // key is at least 3 long
while (separatorPointer != limit && bb.get(separatorPointer) != ';') {
separatorPointer++;
}

// EOL:
int endPointer = separatorPointer + 3; // temperature is at least 3 long
while (endPointer != limit && bb.get(endPointer) != '\n')
endPointer++;

// Extract the name of the key and move the bytebuffer:
final int nameLength = separatorPointer - startPointer;
bb.get(buffer, 0, nameLength);
final String key = new String(buffer, 0, nameLength);

bb.get(); // skip the separator

// Extract the measurement value (10x), skip making a String altogether:
final int valueLength = endPointer - separatorPointer - 1;
bb.get(buffer, 0, valueLength);

// and get rid of the new line (handle both kinds)
byte newline = bb.get();
if (newline == '\r')
bb.get();

int measured = branchlessParseInt(buffer, valueLength);

            // Update the map; a single getOrCreate() has fewer branches than get()/put(), merge() or compute():
measurements.getOrCreate(key).updateWith(measured);
}

return measurements;
}

/**
* Thanks to bjhara for the idea of using memory mapped files, TIL.
* @param fileChannel
* @return
* @throws IOException
*/
private static Stream<ByteBuffer> splitFileChannel(final FileChannel fileChannel) throws IOException {
return StreamSupport.stream(Spliterators.spliteratorUnknownSize(new Iterator<>() {
private static final long CHUNK_SIZE = (long) Math.pow(2, 19);
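            // 2^19 bytes (512 KiB) per chunk; each mapped chunk is trimmed back to the last complete line below.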

private final long size = fileChannel.size();
private long bytesRead = 0;

@Override
public ByteBuffer next() {
try {
final MappedByteBuffer mappedByteBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, bytesRead, Math.min(CHUNK_SIZE, size - bytesRead));

// Adjust end to start of a line:
int realEnd = mappedByteBuffer.limit() - 1;
while (mappedByteBuffer.get(realEnd) != '\n') {
realEnd--;
}
mappedByteBuffer.limit(++realEnd);
bytesRead += realEnd;
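                    // The next chunk starts right after the last full line of this one.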

return mappedByteBuffer;
}
catch (IOException e) {
throw new RuntimeException(e);
}
}

@Override
public boolean hasNext() {
return bytesRead < size;
}
}, Spliterator.IMMUTABLE), false);
}

/**
* Branchless parser, goes from String to int (10x):
* "-1.2" to -12
* "40.1" to 401
* etc.
*
* @param input
* @return int value x10
*/
private static int branchlessParseInt(final byte[] input, int length) {
// 0 if positive, 1 if negative
final int negative = ~(input[0] >> 4) & 1;
// 0 if nr length is 3, 1 if length is 4
final int has4 = ((length - negative) >> 2) & 1;

final int digit1 = input[negative] - '0';
final int digit2 = input[negative + has4] - '0';
final int digit3 = input[2 + negative + has4] - '0';

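        // If negative == 1 this is (-1 ^ (value - 1)) == -value (two's complement negation); if negative == 0 it is just value.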
return (-negative ^ (has4 * (digit1 * 100) + digit2 * 10 + digit3) - negative);
}

    // branchless max (imprecise for large numbers, but good enough)
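    // diff >> 31 is -1 when a < b and 0 otherwise, so a - (diff & dsgn) yields b when a < b and a otherwise.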
static int max(final int a, final int b) {
final int diff = a - b;
final int dsgn = diff >> 31;
return a - (diff & dsgn);
}

    // branchless min (imprecise for large numbers, but good enough)
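    // Mirror of max: b + (diff & dsgn) yields a when a < b and b otherwise.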
static int min(final int a, final int b) {
final int diff = a - b;
final int dsgn = diff >> 31;
return b + (diff & dsgn);
}

/**
* A normal Java HashMap does all these safety things like boundary checks... we don't need that, we need speeeed.
*
* So I've written an extremely simple linear probing hashmap that should work well enough.
*/
class BitTwiddledMap {
        private static final int SIZE = 2048; // A bit larger than the number of keys; needs to be a power of two
private int[] indices = new int[SIZE]; // Hashtable is just an int[]
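        // Each slot holds an index into the values list, or -1 while the slot is still empty.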

BitTwiddledMap() {
// Optimized fill with -1, fastest method:
int len = indices.length;
if (len > 0) {
indices[0] = -1;
}
// Value of i will be [1, 2, 4, 8, 16, 32, ..., len]
for (int i = 1; i < len; i += i) {
System.arraycopy(indices, 0, indices, i, i);
}
}

private List<Entry> values = new ArrayList<>(1024);

record Entry(int hash, String key, Measurement measurement) {
@Override
public String toString() {
return key + "=" + measurement;
}
}

/**
* Who needs methods like add(), merge(), compute() etc, we need one, getOrCreate.
* @param key
* @return
*/
public Measurement getOrCreate(String key) {
int inHash;
int index = (SIZE - 1) & (inHash = hash(key));
int valueIndex;
Entry retrievedEntry = null;
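            // Linear probing: advance until an empty slot (-1) or an entry with the same hash; full key equality is not checked.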
while ((valueIndex = indices[index]) != -1 && (retrievedEntry = values.get(valueIndex)).hash != inHash) {
index = (index + 1) % SIZE;
}
if (valueIndex >= 0) {
return retrievedEntry.measurement;
}
// New entry, insert into table and return.
indices[index] = values.size();
Entry toAdd = new Entry(inHash, key, new Measurement());
values.add(toAdd);
return toAdd.measurement;
}

private int hash(String key) {
// Implement your custom hash function here
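            // Same spreader java.util.HashMap uses: XOR the high 16 bits of hashCode() into the low 16 before masking.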
int h;
return (h = key.hashCode()) ^ (h >>> 16);
}
}

}
