Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Union and intersection of TimeSeries #100

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 228 additions & 7 deletions src/main/scala/com/cloudera/sparkts/DateTimeIndex.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package com.cloudera.sparkts

import java.util
import java.util.{Comparators, Comparator}

import org.threeten.extra._
Expand Down Expand Up @@ -104,6 +105,26 @@ trait DateTimeIndex extends Serializable {
*/
def locAtDateTime(dt: Long): Int

/**
* The location at which the given date-time could be inserted. It is the location of the first
* date-time that is greater than the given date-time. If the given date-time is greater than
* or equal to the last date-time in the index, the index size is returned.
*/
def insertionLoc(dt: ZonedDateTime): Int

/**
* The location at which the given date-time, as milliseconds since the epoch, could be inserted.
* It is the location of the first date-time that is greater than the given date-time. If the
* given date-time is greater than or equal to the last date-time in the index, the index size
* is returned.
*/
def insertionLoc(dt: Long): Int

/**
* Returns the contents of the DateTimeIndex as an array of nanoseconds values from the epoch.
*/
def toInstantsArray(): Array[Long]

/**
* Returns the contents of the DateTimeIndex as an array of millisecond values from the epoch.
*/
Expand All @@ -113,6 +134,21 @@ trait DateTimeIndex extends Serializable {
* Returns the contents of the DateTimeIndex as an array of ZonedDateTime
*/
def toZonedDateTimeArray(): Array[ZonedDateTime]

/**
* Returns an iterator over the contents of the DateTimeIndex as milliseconds
*/
def millisIterator(): Iterator[Long]

/**
* Returns an iterator over the contents of the DateTimeIndex as ZonedDateTime
*/
def zonedDateTimeIterator(): Iterator[ZonedDateTime]

/**
* Returns a new DateTimeIndex with instants at the specified zone
*/
def atZone(zone: ZoneId): DateTimeIndex
}

/**
Expand Down Expand Up @@ -172,6 +208,33 @@ class UniformDateTimeIndex(
locAtDateTime(longToZonedDateTime(dt, dateTimeZone))
}

override def insertionLoc(dt: ZonedDateTime): Int = {
val loc = frequency.difference(first, dt)
if (loc >= 0 && loc < size) {
if (dateTimeAtLoc(loc).compareTo(dt) <= 0) {
loc + 1
} else {
loc
}
} else if (loc < 0) {
0
} else {
size
}
}

override def insertionLoc(dt: Long): Int = {
insertionLoc(longToZonedDateTime(dt, dateTimeZone))
}

override def toInstantsArray(): Array[Long] = {
val arr = new Array[Long](periods)
for (i <- 0 until periods) {
arr(i) = zonedDateTimeToLong(dateTimeAtLoc(i))
}
arr
}

override def toMillisArray(): Array[Long] = {
val arr = new Array[Long](periods)
for (i <- 0 until periods) {
Expand All @@ -198,6 +261,34 @@ class UniformDateTimeIndex(
"uniform", dateTimeZone.toString, start.toString,
periods.toString, frequency.toString).mkString(",")
}

override def millisIterator(): Iterator[Long] = {
new Iterator[Long] {
val zdtIter = zonedDateTimeIterator

override def hasNext: Boolean = zdtIter.hasNext

override def next(): Long = zonedDateTimeToLong(zdtIter.next) / 1000000L
}
}

override def zonedDateTimeIterator(): Iterator[ZonedDateTime] = {
new Iterator[ZonedDateTime] {
var current = first

override def hasNext: Boolean = current.compareTo(last) <= 0

override def next(): ZonedDateTime = {
val ret = current
current = frequency.advance(current, 1)
ret
}
}
}

override def atZone(zone: ZoneId): UniformDateTimeIndex = {
new UniformDateTimeIndex(start.withZoneSameInstant(zone), periods, frequency, zone)
}
}

/**
Expand Down Expand Up @@ -252,6 +343,25 @@ class IrregularDateTimeIndex(
if (loc < 0) -1 else loc
}

override def insertionLoc(dt: ZonedDateTime): Int = {
insertionLoc(zonedDateTimeToLong(dt))
}

override def insertionLoc(dt: Long): Int = {
var loc = java.util.Arrays.binarySearch(instants, dt)
if (loc >= 0) {
do loc += 1
while (loc < size && instants(loc) == dt)
loc
} else {
-loc - 1
}
}

override def toInstantsArray(): Array[Long] = {
instants.clone
}

override def toMillisArray(): Array[Long] = {
instants.map(dt => dt / 1000000L)
}
Expand All @@ -269,6 +379,30 @@ class IrregularDateTimeIndex(
"irregular," + dateTimeZone.toString + "," +
instants.map(longToZonedDateTime(_, dateTimeZone).toString).mkString(",")
}

override def millisIterator(): Iterator[Long] = {
new Iterator[Long] {
val instIter = instants.iterator

override def hasNext: Boolean = instIter.hasNext

override def next(): Long = instIter.next / 1000000L
}
}

override def zonedDateTimeIterator(): Iterator[ZonedDateTime] = {
new Iterator[ZonedDateTime] {
val instIter = instants.iterator

override def hasNext: Boolean = instIter.hasNext

override def next(): ZonedDateTime = longToZonedDateTime(instIter.next, dateTimeZone)
}
}

override def atZone(zone: ZoneId): IrregularDateTimeIndex = {
new IrregularDateTimeIndex(instants, zone)
}
}

/**
Expand Down Expand Up @@ -296,8 +430,8 @@ class HybridDateTimeIndex(
override def slice(start: ZonedDateTime, end: ZonedDateTime): HybridDateTimeIndex = {
require(start.isBefore(end), s"start($start) should be less than end($end)")

val startIndex = binarySearch(0, indices.length - 1, start)
val endIndex = binarySearch(0, indices.length - 1, end)
val startIndex = binarySearch(0, indices.length - 1, start)._1
val endIndex = binarySearch(0, indices.length - 1, end)._1

val newIndices =
if (startIndex == endIndex) {
Expand Down Expand Up @@ -384,7 +518,7 @@ class HybridDateTimeIndex(
}

override def locAtDateTime(dt: ZonedDateTime): Int = {
val i = binarySearch(0, indices.length - 1, dt)
val i = binarySearch(0, indices.length - 1, dt)._1
if (i > -1) {
val loc = indices(i).locAtDateTime(dt)
if (loc > -1) sizeOnLeft(i) + loc
Expand All @@ -393,17 +527,52 @@ class HybridDateTimeIndex(
else -1
}

override def locAtDateTime(dt: Long): Int =
override def locAtDateTime(dt: Long): Int = {
locAtDateTime(longToZonedDateTime(dt, dateTimeZone))
}

override def insertionLoc(dt: ZonedDateTime): Int = {
val loc = binarySearch(0, indices.length - 1, dt)._2
if (loc >= 0) {
sizeOnLeft(loc) + indices(loc).insertionLoc(dt)
} else if (dt.isBefore(first)) {
0
} else {
size
}
}

override def insertionLoc(dt: Long): Int = {
insertionLoc(longToZonedDateTime(dt, dateTimeZone))
}

private def binarySearch(low: Int, high: Int, dt: ZonedDateTime): Int = {
/**
* Returns a tuple (a, b):
* a: is the array index of the date-time index that contains the queried date-time dt
* or -1 if dt is not found. This value is used by locAtDateTime method.
* b: is the array index of the date-time index where the queried date-time dt could
* be inserted. This value is used by insertionLoc method.
*/
private def binarySearch(low: Int, high: Int, dt: ZonedDateTime): (Int, Int) = {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you include a header comment here to indicate what the two tuple values mean? ~ @sryza

Added

if (low <= high) {
val mid = (low + high) >>> 1
val midIndex = indices(mid)
if (dt.isBefore(midIndex.first)) binarySearch(low, mid - 1, dt)
else if (dt.isAfter(midIndex.last)) binarySearch(mid + 1, high, dt)
else mid
} else -1
else (mid, mid)
} else {
// if coming from the call "binarySearch(low, mid - 1, dt)"
// on the condition "if (dt.isBefore(midIndex.first))"
if (high >= 0 && dt.isAfter(indices(high).last)) (-1, high)
// if coming from the call "binarySearch(mid + 1, high, dt)"
// on the condition "if (dt.isAfter(midIndex.last))"
else if (low < indices.length && dt.isBefore(indices(low).first)) (-1, low)
else (-1, -1)
}
}

override def toInstantsArray(): Array[Long] = {
indices.map(_.toInstantsArray).reduce(_ ++ _)
}

override def toMillisArray(): Array[Long] = {
Expand All @@ -423,6 +592,58 @@ class HybridDateTimeIndex(
"hybrid," + dateTimeZone.toString + "," +
indices.map(_.toString).mkString(";")
}

override def millisIterator(): Iterator[Long] = {
new Iterator[Long] {
val indicesIter = indices.iterator
var milIter = if (indicesIter.hasNext) indicesIter.next.millisIterator else null

override def hasNext: Boolean = {
if (milIter != null) {
if (milIter.hasNext) {
true
} else if(indicesIter.hasNext) {
milIter = indicesIter.next.millisIterator
hasNext
} else {
false
}
} else {
false
}
}

override def next(): Long = if (hasNext) milIter.next else -1
}
}

override def zonedDateTimeIterator(): Iterator[ZonedDateTime] = {
new Iterator[ZonedDateTime] {
val indicesIter = indices.iterator
var zdtIter = if (indicesIter.hasNext) indicesIter.next.zonedDateTimeIterator else null

override def hasNext: Boolean = {
if (zdtIter != null) {
if (zdtIter.hasNext) {
true
} else if(indicesIter.hasNext) {
zdtIter = indicesIter.next.zonedDateTimeIterator
hasNext
} else {
false
}
} else {
false
}
}

override def next(): ZonedDateTime = if (hasNext) zdtIter.next else null
}
}

override def atZone(zone: ZoneId): HybridDateTimeIndex = {
new HybridDateTimeIndex(indices.map(_.atZone(zone)), zone)
}
}

object DateTimeIndex {
Expand Down
Loading