diff --git a/README.md b/README.md
index 3ab95245..89c5f829 100644
--- a/README.md
+++ b/README.md
@@ -305,8 +305,28 @@ You can use `org.apache.spark.sql.pulsar.JsonUtils.topicOffsets(Map[String, Mess
This may cause a false alarm. You can set it to `false` when it doesn't work as expected.
A batch query always fails if it fails to read any data from the provided offsets due to data loss.
-
+
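+`maxEntriesPerTrigger`
+
+The maximum number of ledger entries a streaming micro-batch may read, with the
+budget divided evenly between the matched topics. Defaults to `-1`, which keeps
+the original behavior of reading everything up to the latest available offsets
+on every trigger. A minimal sketch of its use (service URL, admin URL, and
+topic name are placeholders):
+
+```scala
+val stream = spark
+  .readStream
+  .format("pulsar")
+  .option("service.url", "pulsar://localhost:6650")
+  .option("admin.url", "http://localhost:8080")
+  .option("topic", "some-topic")
+  .option("maxEntriesPerTrigger", "1000")
+  .load()
+```
+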
`allowDifferentTopicSchemas`
diff --git a/src/main/scala/org/apache/spark/sql/pulsar/PulsarMetadataReader.scala b/src/main/scala/org/apache/spark/sql/pulsar/PulsarMetadataReader.scala
index 33bbb822..f560ad04 100644
--- a/src/main/scala/org/apache/spark/sql/pulsar/PulsarMetadataReader.scala
+++ b/src/main/scala/org/apache/spark/sql/pulsar/PulsarMetadataReader.scala
@@ -22,11 +22,13 @@ import java.util.regex.Pattern
import org.apache.pulsar.client.admin.{PulsarAdmin, PulsarAdminException}
import org.apache.pulsar.client.api.{Message, MessageId, PulsarClient}
import org.apache.pulsar.client.impl.schema.BytesSchema
+import org.apache.pulsar.client.internal.DefaultImplementation
import org.apache.pulsar.common.naming.TopicName
import org.apache.pulsar.common.schema.SchemaInfo
import org.apache.spark.internal.Logging
import org.apache.spark.sql.pulsar.PulsarOptions._
+import org.apache.spark.sql.pulsar.topicinternalstats.forward._
import org.apache.spark.sql.types.StructType
/**
@@ -259,6 +261,68 @@ private[pulsar] case class PulsarMetadataReader(
}.toMap)
}
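+
+  /**
+   * Computes the next offset for every topic partition such that the coming
+   * micro-batch reads at most `numberOfEntries` ledger entries in total,
+   * dividing the budget evenly between topics. Falls back to the earliest
+   * offset for topics whose internal stats cannot be fetched.
+   */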
+ def fetchNextOffsetWithMaxEntries(actualOffset: Map[String, MessageId],
+ numberOfEntries: Long): SpecificPulsarOffset = {
+    // Refresh the known topic partitions before computing offsets
+    getTopicPartitions()
+
+    // Collect internal stats for every topic partition
+    val topicStats = topicPartitions.map { topic =>
+      topic -> admin.topics().getInternalStats(topic)
+    }.toMap.asJava
+
+ SpecificPulsarOffset(topicPartitions.map { topic =>
+ topic -> PulsarSourceUtils.seekableLatestMid {
+ // Fetch actual offset for topic
+ val topicActualMessageId = actualOffset.getOrElse(topic, MessageId.earliest)
+ try {
+          // Get the actual ledger ID
+          val actualLedgerId = PulsarSourceUtils.getLedgerId(topicActualMessageId)
+ // Get the actual entry ID
+ val actualEntryId = PulsarSourceUtils.getEntryId(topicActualMessageId)
+ // Get the partition index
+ val partitionIndex = PulsarSourceUtils.getPartitionIndex(topicActualMessageId)
+ // Cache topic internal stats
+ val internalStats = topicStats.get(topic)
+          // Calculate the per-topic share of the entry budget
+          val numberOfEntriesPerTopic = numberOfEntries / topics.size
+          // Compute the next message ID, respecting the maximum
+          // number of entries to read
+          val (nextLedgerId, nextEntryId) = TopicInternalStatsUtils.forwardMessageId(
+            internalStats,
+            actualLedgerId,
+            actualEntryId,
+            numberOfEntriesPerTopic)
+ // Build the next message ID
+ val nextMessageId =
+ DefaultImplementation
+ .getDefaultImplementation
+ .newMessageId(nextLedgerId, nextEntryId, partitionIndex)
+ // Log state
+ val entryCountUntilNextMessageId = TopicInternalStatsUtils.numOfEntriesUntil(
+ internalStats, nextLedgerId, nextEntryId)
+ val entryCount = internalStats.numberOfEntries
+ val progress = f"${entryCountUntilNextMessageId.toFloat / entryCount.toFloat}%1.3f"
+          val logMessage = s"Pulsar Connector offset step forward. " +
+            s"[$numberOfEntriesPerTopic/$numberOfEntries] " +
+            s"${topic.takeRight(30)} $topicActualMessageId -> " +
+            s"$nextMessageId ($entryCountUntilNextMessageId/$entryCount) [$progress]"
+ log.debug(logMessage)
+ // Return the message ID
+ nextMessageId
+ } catch {
+ case e: PulsarAdminException if e.getStatusCode == 404 =>
+ MessageId.earliest
+ case e: Throwable =>
+ throw new RuntimeException(
+ s"Failed to get forwarded messageId for ${TopicName.get(topic).toString} " +
+ s"(tried to forward ${numberOfEntries} messages " +
+ s"starting from `$topicActualMessageId`)", e)
+ }
+
+ }
+ }.toMap)
+ }
+
def fetchLatestOffsetForTopic(topic: String): MessageId = {
val messageId =
try {
diff --git a/src/main/scala/org/apache/spark/sql/pulsar/PulsarOptions.scala b/src/main/scala/org/apache/spark/sql/pulsar/PulsarOptions.scala
index db1a06ba..2c2425cb 100644
--- a/src/main/scala/org/apache/spark/sql/pulsar/PulsarOptions.scala
+++ b/src/main/scala/org/apache/spark/sql/pulsar/PulsarOptions.scala
@@ -37,6 +37,8 @@ private[pulsar] object PulsarOptions {
val TopicOptionKeys: Set[String] = Set(TopicSingle, TopicMulti, TopicPattern)
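+  // Upper bound on the number of ledger entries read per micro-batch;
+  // -1 (the default) disables the limit.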
+  val MaxEntriesPerTrigger: String = "maxEntriesPerTrigger".toLowerCase(Locale.ROOT)
+
val ServiceUrlOptionKey: String = "service.url"
val AdminUrlOptionKey: String = "admin.url"
val StartingOffsetsOptionKey: String = "startingOffsets".toLowerCase(Locale.ROOT)
diff --git a/src/main/scala/org/apache/spark/sql/pulsar/PulsarProvider.scala b/src/main/scala/org/apache/spark/sql/pulsar/PulsarProvider.scala
index 58658739..b29abc37 100644
--- a/src/main/scala/org/apache/spark/sql/pulsar/PulsarProvider.scala
+++ b/src/main/scala/org/apache/spark/sql/pulsar/PulsarProvider.scala
@@ -113,7 +113,9 @@ private[pulsar] class PulsarProvider
pollTimeoutMs(caseInsensitiveParams),
failOnDataLoss(caseInsensitiveParams),
subscriptionNamePrefix,
- jsonOptions)
+ jsonOptions,
+ maxEntriesPerTrigger(caseInsensitiveParams)
+ )
}
override def createRelation(
@@ -395,6 +397,9 @@ private[pulsar] object PulsarProvider extends Logging {
(SparkEnv.get.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000).toString)
.toInt
+ private def maxEntriesPerTrigger(caseInsensitiveParams: Map[String, String]): Long =
+ caseInsensitiveParams.getOrElse(MaxEntriesPerTrigger, "-1").toLong
+
private def validateGeneralOptions(
caseInsensitiveParams: Map[String, String]): Map[String, String] = {
if (!caseInsensitiveParams.contains(ServiceUrlOptionKey)) {
diff --git a/src/main/scala/org/apache/spark/sql/pulsar/PulsarSource.scala b/src/main/scala/org/apache/spark/sql/pulsar/PulsarSource.scala
index ee71a685..ee0a25df 100644
--- a/src/main/scala/org/apache/spark/sql/pulsar/PulsarSource.scala
+++ b/src/main/scala/org/apache/spark/sql/pulsar/PulsarSource.scala
@@ -36,7 +36,8 @@ private[pulsar] class PulsarSource(
pollTimeoutMs: Int,
failOnDataLoss: Boolean,
subscriptionNamePrefix: String,
- jsonOptions: JSONOptionsInRead)
+ jsonOptions: JSONOptionsInRead,
+ maxEntriesPerTrigger: Long)
extends Source
with Logging {
@@ -59,12 +60,21 @@ private[pulsar] class PulsarSource(
override def schema(): StructType = SchemaUtils.pulsarSourceSchema(pulsarSchema)
override def getOffset: Option[Offset] = {
    // Make sure initialTopicOffsets is initialized
    initialTopicOffsets
- val latest = metadataReader.fetchLatestOffsets()
- currentTopicOffsets = Some(latest.topicOffsets)
- logDebug(s"GetOffset: ${latest.topicOffsets.toSeq.map(_.toString).sorted}")
- Some(latest.asInstanceOf[Offset])
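+    // Without a rate limit, read up to the latest available offsets;
+    // otherwise forward from the end of the previous batch (or from the
+    // initial offsets on the very first batch) by at most
+    // maxEntriesPerTrigger entries.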
+ val nextOffsets = if (maxEntriesPerTrigger == -1) {
+ metadataReader.fetchLatestOffsets()
+ } else {
+ currentTopicOffsets match {
+ case Some(value) =>
+ metadataReader.fetchNextOffsetWithMaxEntries(value,
+ maxEntriesPerTrigger)
+ case _ =>
+ metadataReader.fetchNextOffsetWithMaxEntries(initialTopicOffsets.topicOffsets,
+ maxEntriesPerTrigger)
+ }
+ }
+ logDebug(s"GetOffset: ${nextOffsets.topicOffsets.toSeq.map(_.toString).sorted}")
+ Some(nextOffsets.asInstanceOf[Offset])
}
override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
@@ -74,9 +84,7 @@ private[pulsar] class PulsarSource(
logInfo(s"getBatch called with start = $start, end = $end")
val endTopicOffsets = SpecificPulsarOffset.getTopicOffsets(end)
- if (currentTopicOffsets.isEmpty) {
- currentTopicOffsets = Some(endTopicOffsets)
- }
+ currentTopicOffsets = Some(endTopicOffsets)
if (start.isDefined && start.get == end) {
return sqlContext.internalCreateDataFrame(
diff --git a/src/main/scala/org/apache/spark/sql/pulsar/PulsarSources.scala b/src/main/scala/org/apache/spark/sql/pulsar/PulsarSources.scala
index 84d44ebd..5247cd79 100644
--- a/src/main/scala/org/apache/spark/sql/pulsar/PulsarSources.scala
+++ b/src/main/scala/org/apache/spark/sql/pulsar/PulsarSources.scala
@@ -120,6 +120,36 @@ private[pulsar] object PulsarSourceUtils extends Logging {
}
}
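+
+  // The helpers below unwrap the concrete MessageId implementations
+  // (batch, user-provided, plain and topic-scoped) to expose the
+  // underlying BookKeeper ledger ID, entry ID and partition index.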
+  def getLedgerId(mid: MessageId): Long = {
+    mid match {
+      case bmid: BatchMessageIdImpl => bmid.getLedgerId
+      case up: UserProvidedMessageId => up.getLedgerId
+      case midi: MessageIdImpl => midi.getLedgerId
+      case t: TopicMessageIdImpl => getLedgerId(t.getInnerMessageId)
+    }
+  }
+
+  def getEntryId(mid: MessageId): Long = {
+    mid match {
+      case bmid: BatchMessageIdImpl => bmid.getEntryId
+      case up: UserProvidedMessageId => up.getEntryId
+      case midi: MessageIdImpl => midi.getEntryId
+      case t: TopicMessageIdImpl => getEntryId(t.getInnerMessageId)
+    }
+  }
+
+  def getPartitionIndex(mid: MessageId): Int = {
+    mid match {
+      case bmid: BatchMessageIdImpl => bmid.getPartitionIndex
+      case up: UserProvidedMessageId => up.getPartitionIndex
+      case midi: MessageIdImpl => midi.getPartitionIndex
+      case t: TopicMessageIdImpl => getPartitionIndex(t.getInnerMessageId)
+    }
+  }
+
def seekableLatestMid(mid: MessageId): MessageId = {
if (messageExists(mid)) mid else MessageId.earliest
}
diff --git a/src/main/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/TopicInternalStatsUtils.scala b/src/main/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/TopicInternalStatsUtils.scala
new file mode 100644
index 00000000..624a7023
--- /dev/null
+++ b/src/main/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/TopicInternalStatsUtils.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.pulsar.topicinternalstats.forward
+
+import scala.collection.JavaConverters.asScalaBufferConverter
+
+import org.apache.pulsar.common.policies.data.PersistentTopicInternalStats
+
+import org.apache.spark.sql.pulsar.topicinternalstats.forward.TopicInternalStatsUtils._
+
+object TopicInternalStatsUtils {
+
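+  /**
+   * Moves a (ledger ID, entry ID) position forward by `forwardByEntryCount`
+   * entries within the ledgers described by `stats` and returns the new
+   * position. A start ledger ID of -1 means the beginning of the topic; the
+   * result never moves past the last entry of the last ledger.
+   */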
+ def forwardMessageId(stats: PersistentTopicInternalStats,
+ startLedgerId: Long,
+ startEntryId: Long,
+ forwardByEntryCount: Long): (Long, Long) = {
+    val ledgers = fixLastLedgerInInternalStat(stats).ledgers.asScala.toList
+    if (ledgers.isEmpty || (forwardByEntryCount < 1)) {
+      // If there is no ledger info, or there is nothing to forward, stay at the current ID
+      (startLedgerId, startEntryId)
+    } else {
+      // Find the start ledger's index in the list by its ledger ID
+      val startLedgerIndex: Int = ledgers.indexWhere(_.ledgerId == startLedgerId) match {
+        // If found, start from there
+        case i if i >= 0 => i
+        // If not found and the start ledger ID is -1, start from the beginning
+        case _ if startLedgerId == -1 => 0
+        // In any other case, start from the end
+        case _ => ledgers.size - 1
+      }
+
+      // Clip the start entry ID within the start ledger if needed
+ val startEntryIndex = Math.min(Math.max(startEntryId, 0), ledgers(startLedgerIndex).entries)
+
+ // Create an iterator over the ledgers list
+ val statsIterator =
+ new PersistentTopicInternalStatsIterator(stats, startLedgerIndex, startEntryIndex)
+
+      // Advance it by the required number of steps
+      val (forwardedLedgerId, forwardedEntryId) = (1L to forwardByEntryCount)
+        .map(_ => statsIterator.next()).last
+
+ (forwardedLedgerId, forwardedEntryId)
+ }
+ }
+
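+  /**
+   * Counts the entries located before the given (ledger ID, entry ID)
+   * position, clipping IDs that fall outside the topic's boundaries.
+   */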
+ def numOfEntriesUntil(stats: PersistentTopicInternalStats,
+ ledgerId: Long,
+ entryId: Long): Long = {
+ val ledgers = fixLastLedgerInInternalStat(stats).ledgers.asScala
+ if (ledgers.isEmpty) {
+ 0
+ } else {
+      val ledgersBeforeStartLedger = ledgers.filter(_.ledgerId < ledgerId)
+      val entriesInCurrentLedger = if (ledgersBeforeStartLedger.isEmpty) {
+        Math.max(entryId, 0)
+      } else {
+        Math.min(Math.max(entryId, 0), ledgersBeforeStartLedger.last.entries)
+      }
+      entriesInCurrentLedger + ledgersBeforeStartLedger.map(_.entries).sum
+ }
+ }
+
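+  /**
+   * Counts the entries located at or after the given (ledger ID, entry ID)
+   * position, clipping IDs that fall outside the topic's boundaries.
+   */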
+ def numOfEntriesAfter(stats: PersistentTopicInternalStats,
+ ledgerId: Long,
+ entryId: Long): Long = {
+ val ledgers = fixLastLedgerInInternalStat(stats).ledgers.asScala
+ if (ledgers.isEmpty) {
+ 0
+ } else {
+      val ledgersFromCurrentLedger = ledgers.filter(_.ledgerId >= ledgerId)
+      val entriesBeforeEntryId = if (ledgersFromCurrentLedger.isEmpty) {
+        Math.max(entryId, 0)
+      } else {
+        Math.min(Math.max(entryId, 0), ledgersFromCurrentLedger.last.entries)
+      }
+      ledgersFromCurrentLedger.map(_.entries).sum - entriesBeforeEntryId
+ }
+ }
+
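+  /**
+   * The admin interface reports 0 as the entry count of the last (still
+   * open) ledger and exposes the real count in `currentLedgerEntries`
+   * instead. Patch the last ledger info in place so the calculations in
+   * this file can treat every ledger uniformly. Note that this mutates
+   * the given stats object.
+   */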
+ def fixLastLedgerInInternalStat(
+ stats: PersistentTopicInternalStats): PersistentTopicInternalStats = {
+ if (stats.ledgers.isEmpty) {
+ stats
+ } else {
+ val lastLedgerInfo = stats.ledgers.get(stats.ledgers.size() - 1)
+ lastLedgerInfo.entries = stats.currentLedgerEntries
+ stats.ledgers.set(stats.ledgers.size() - 1, lastLedgerInfo)
+ stats
+ }
+ }
+
+}
+
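+/**
+ * Iterates over the (ledger ID, entry ID) positions of a topic, starting
+ * from the given index in the ledger list and the given entry index.
+ * `next()` never advances past the last entry of the last ledger; once
+ * there, it keeps returning that position.
+ */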
+class PersistentTopicInternalStatsIterator(stats: PersistentTopicInternalStats,
+ startLedgerIndex: Int,
+ startEntryIndex: Long)
+ extends Iterator[(Long, Long)] {
+  private val ledgers = fixLastLedgerInInternalStat(stats).ledgers.asScala.toList
+ private var currentLedgerIndex = startLedgerIndex
+ private var currentEntryIndex = startEntryIndex
+
+  override def hasNext: Boolean = !isLast
+
+  // True when the iterator points at the last entry of the last ledger
+  private def isLast: Boolean = currentLedgerIndex == (ledgers.size - 1) &&
+    currentEntryIndex == (ledgers.last.entries - 1)
+
+ override def next(): (Long, Long) = {
+ // Do not move past last element
+ if (hasNext) {
+ if (currentEntryIndex < (ledgers(currentLedgerIndex).entries - 1)) {
+ // Staying in the current ledger
+ currentEntryIndex += 1
+ } else {
+ // Advancing to the next ledger
+ currentLedgerIndex += 1
+ currentEntryIndex = 0
+ }
+ }
+ (ledgers(currentLedgerIndex).ledgerId, currentEntryIndex)
+ }
+}
diff --git a/src/main/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/TopicState.scala b/src/main/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/TopicState.scala
new file mode 100644
index 00000000..fa970ab0
--- /dev/null
+++ b/src/main/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/TopicState.scala
@@ -0,0 +1,20 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.pulsar.topicinternalstats.forward
+
+import org.apache.pulsar.common.policies.data.PersistentTopicInternalStats
+
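+/**
+ * Snapshot of a topic's internal stats together with a position
+ * (ledger ID and entry ID) within it.
+ */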
+case class TopicState(internalStat: PersistentTopicInternalStats,
+ ledgerId: Long,
+ entryId: Long)
diff --git a/src/test/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/TopicInternalStatsUtilsSuite.scala b/src/test/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/TopicInternalStatsUtilsSuite.scala
new file mode 100644
index 00000000..60987b6b
--- /dev/null
+++ b/src/test/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/TopicInternalStatsUtilsSuite.scala
@@ -0,0 +1,383 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.pulsar.topicinternalstats.forward
+
+import TopicStateFixture._
+
+import org.apache.spark.SparkFunSuite
+
+class TopicInternalStatsUtilsSuite extends SparkFunSuite {
+
+ test("forward a single entry") {
+ val fakeStats = createPersistentTopicInternalStat(createLedgerInfo(1000, 500))
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, 0, 0, 1)
+
+ assert(nextLedgerId == 1000)
+ assert(nextEntryId == 1)
+ }
+
+ test("forward empty ledger") {
+ val fakeStats = createPersistentTopicInternalStat()
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, 0, 0, 10)
+
+ assert(nextLedgerId == 0)
+ assert(nextEntryId == 0)
+ }
+
+ test("forward within a single ledger") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 500)
+ )
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, 1000, 0, 10)
+
+ assert(nextLedgerId == 1000)
+ assert(nextEntryId == 10)
+ }
+
+ test("forward within a single ledger starting from the middle") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 500)
+ )
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, 1000, 25, 10)
+
+ assert(nextLedgerId == 1000)
+ assert(nextEntryId == 35)
+ }
+
+ test("forward to the next ledger") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50)
+ )
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, 1000, 25, 50)
+
+ assert(nextLedgerId == 2000)
+ assert(nextEntryId == 25)
+ }
+
+ test("skip over a ledger if needed") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, 1000, 25, 100)
+
+ assert(nextLedgerId == 3000)
+ assert(nextEntryId == 25)
+ }
+
+ test("forward to the end of the topic if too many entries need " +
+ "to be skipped with a single ledger") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ )
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, 1000, 25, 600)
+
+ assert(nextLedgerId == 1000)
+ assert(nextEntryId == 49)
+ }
+
+ test("forward to the end of the topic if too many entries need " +
+ "to be skipped with multiple ledgers") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, 1000, 25, 600)
+
+ assert(nextLedgerId == 3000)
+ assert(nextEntryId == 49)
+ }
+
+ test("forward with zero elements shall give you back what was given") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, 2000, 25, 0)
+
+ assert(nextLedgerId == 2000)
+ assert(nextEntryId == 25)
+ }
+
+ test("forward from beginning of the topic") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, -1, -1, 125)
+
+ assert(nextLedgerId == 3000)
+ assert(nextEntryId == 25)
+ }
+
+ test("forward from non-existent ledger id shall forward from the last ledger instead") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, 6000, 0, 25)
+
+ assert(nextLedgerId == 3000)
+ assert(nextEntryId == 25)
+ }
+
+ test("forward from non-existent entry id shall forward from end of ledger instead") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, 1000, 250, 25)
+
+ assert(nextLedgerId == 2000)
+ assert(nextEntryId == 24)
+ }
+
+ test("forwarded entry id shall never be less than current entry id") {
+ val startEntryID = 200
+ val ledgerID = 1000
+ val entriesInLedger = 205
+ val forwardByEntries = 50
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(ledgerID, entriesInLedger)
+ )
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, ledgerID, startEntryID, forwardByEntries)
+ assert(nextLedgerId == ledgerID)
+ assert(nextEntryId > startEntryID)
+ }
+
+ test("number of entries until shall work with empty input") {
+ val fakeStats = createPersistentTopicInternalStat()
+ val result =
+ TopicInternalStatsUtils.numOfEntriesUntil(fakeStats, -1, -1)
+
+ assert(result == 0)
+ }
+
+ test("number of entries until with single ledger") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesUntil(fakeStats, 1000, 25)
+
+ assert(result == 25)
+ }
+
+ test("number of entries until with multiple ledgers") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesUntil(fakeStats, 2000, 25)
+
+ assert(result == 75)
+ }
+
+ test("number of entries until beginning of topic") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesUntil(fakeStats, -1, -1)
+
+ assert(result == 0)
+ }
+
+ test("number of entries until end of topic") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesUntil(fakeStats, 3000, 50)
+
+ assert(result == 150)
+ }
+
+ test("number of entries until with ledger id below boundary") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesUntil(fakeStats, -2, 0)
+
+ assert(result == 0)
+ }
+
+ test("number of entries until with entry id below boundary") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesUntil(fakeStats, 2000, -2)
+
+ assert(result == 50)
+
+ }
+
+ test("number of entries until with ledger id above boundary") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesUntil(fakeStats, 6000, 0)
+
+ assert(result == 150)
+ }
+
+ test("number of entries until with entry id above boundary") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesUntil(fakeStats, 2000, 200)
+
+ assert(result == 100)
+ }
+
+ test("number of entries after shall work with empty input") {
+ val fakeStats = createPersistentTopicInternalStat()
+ val result =
+ TopicInternalStatsUtils.numOfEntriesAfter(fakeStats, -1, -1)
+
+ assert(result == 0)
+ }
+
+ test("number of entries after with single ledger") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesAfter(fakeStats, 1000, 20)
+
+ assert(result == 30)
+ }
+
+ test("number of entries after with multiple ledgers") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesAfter(fakeStats, 1000, 20)
+
+ assert(result == 130)
+ }
+
+ test("number of entries after beginning of topic") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesAfter(fakeStats, -1, -1)
+
+ assert(result == 150)
+ }
+
+ test("number of entries after end of topic") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesAfter(fakeStats, 3000, 50)
+
+ assert(result == 0)
+ }
+
+ test("number of entries after with ledger id below boundary") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesAfter(fakeStats, -2, 0)
+
+ assert(result == 150)
+ }
+
+ test("number of entries after with entry id below boundary") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesAfter(fakeStats, 2000, -2)
+
+ assert(result == 100)
+ }
+
+ test("number of entries after with ledger id above boundary") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesAfter(fakeStats, 6000, 0)
+
+ assert(result == 0)
+ }
+
+ test("number of entries after with entry id above boundary") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesAfter(fakeStats, 2000, 200)
+
+ assert(result == 50)
+ }
+}
diff --git a/src/test/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/TopicStateTestFixture.scala b/src/test/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/TopicStateTestFixture.scala
new file mode 100644
index 00000000..5f2878b2
--- /dev/null
+++ b/src/test/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/TopicStateTestFixture.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.pulsar.topicinternalstats.forward
+
+import java.util
+
+import org.apache.pulsar.common.policies.data.ManagedLedgerInternalStats.LedgerInfo
+import org.apache.pulsar.common.policies.data.PersistentTopicInternalStats
+
+object TopicStateFixture {
+
+ def createTopicState(topicInternalStats: PersistentTopicInternalStats,
+ ledgerId: Long,
+ entryId: Long): TopicState = {
+ TopicState(topicInternalStats, ledgerId, entryId)
+ }
+
+ def createPersistentTopicInternalStat(ledgers: LedgerInfo*): PersistentTopicInternalStats = {
+ val result = new PersistentTopicInternalStats()
+
+ result.currentLedgerEntries = if (ledgers.isEmpty) {
+ 0
+ } else {
+ ledgers.last.entries
+ }
+
+    if (ledgers.nonEmpty) {
+      // Simulate a bug in the Pulsar admin interface:
+      // the last ledger in the ledger list reports 0 as its
+      // entry count, while the actual count is carried in
+      // currentLedgerEntries instead.
+      val lastLedger = ledgers.last
+      lastLedger.entries = 0
+    }
+ result.ledgers = util.Arrays.asList(ledgers: _*)
+ result
+ }
+
+ def createLedgerInfo(ledgerId: Long, entries: Long): LedgerInfo = {
+ val result = new LedgerInfo()
+ result.ledgerId = ledgerId
+ result.entries = entries
+ result
+ }
+}
+
|