From 65782ce938cdcd784e4fbed480a21ec2e537060c Mon Sep 17 00:00:00 2001 From: Minjae Gwon Date: Tue, 28 Nov 2023 02:14:43 +0900 Subject: [PATCH] fix: partition --- core/src/main/scala/Key.scala | 10 +++++++-- master/src/main/scala/Master.scala | 35 +++++++++++++++++------------- 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/Key.scala b/core/src/main/scala/Key.scala index 0c11567..9d0a2c9 100644 --- a/core/src/main/scala/Key.scala +++ b/core/src/main/scala/Key.scala @@ -8,17 +8,23 @@ object Key { def fromByteString(byteString: ByteString): Key = new Key( byteString.toByteArray ) + + def min(size: Int = 10) = new Key(Array.fill[Byte](size)(0x00.toByte)) + + def max(size: Int = 10) = new Key(Array.fill[Byte](size)(0xff.toByte)) } -class Key(val underlying: Array[Byte]) extends AnyVal with Ordered[Key] { +case class Key(underlying: Array[Byte]) extends AnyVal with Ordered[Key] { def is(that: Key): Boolean = underlying sameElements that.underlying def hex: String = underlying.map("%02x" format _).mkString + def prior: Key = new Key(underlying.init :+ (underlying.last - 1).toByte) + override def compare(that: Key): Int = underlying .zip(that.underlying) .map { case (a, b) => - a - b + a.toChar - b.toChar } .find(_ != 0) .getOrElse(0) diff --git a/master/src/main/scala/Master.scala b/master/src/main/scala/Master.scala index 2bbc7a2..48d9342 100644 --- a/master/src/main/scala/Master.scala +++ b/master/src/main/scala/Master.scala @@ -13,19 +13,25 @@ import scala.concurrent.ExecutionContextExecutor object Master extends Logging { private def workersWithKeyRange( keys: List[Key], - workers: List[WorkerMetadata] - ): List[WorkerMetadata] = - keys - .sliding( - keys.size / workers.size, - keys.size / workers.size + workers: List[WorkerMetadata], + min: Key, + max: Key + ): List[WorkerMetadata] = { + val keysWithLowerBound = keys :+ min + + val startKeys = keysWithLowerBound.sorted + .grouped( + (keysWithLowerBound.size.toDouble / workers.size.ceil).ceil.toInt ) .toList - .map(keys => KeyRange.tupled(keys.head, keys.last)) - .zip(workers) - .map { case (keyRange, worker) => - worker.copy(keyRange = Some(keyRange)) - } + .map(_.head) + + val pairs = startKeys.zip(startKeys.tail.map(_.prior) :+ max) + + workers.zip(pairs).map { case (worker, (min, max)) => + worker.copy(keyRange = Some(KeyRange(min, max))) + } + } def main(args: Array[String]): Unit = { val masterArguments = new MasterArguments(args) @@ -73,11 +79,10 @@ object Master extends Logging { .flatMap(_.sampledKeys) .map(Key.fromByteString) - logger.info("[Master] Sampled") - - val sortedSampledKeys = sampledKeys.sorted + logger.info(s"[Master] Sampled $sampledKeys") - val workers = workersWithKeyRange(sortedSampledKeys, workerInfo) + val workers = + workersWithKeyRange(sampledKeys, workerInfo, Key.min(), Key.max()) logger.info(s"[Master] Key ranges with worker: $workers")