diff --git a/src/main/scala/org/apache/spark/shuffle/S3SingleSpillShuffleMapOutputWriter.scala b/src/main/scala/org/apache/spark/shuffle/S3SingleSpillShuffleMapOutputWriter.scala
index 7577851..3845626 100644
--- a/src/main/scala/org/apache/spark/shuffle/S3SingleSpillShuffleMapOutputWriter.scala
+++ b/src/main/scala/org/apache/spark/shuffle/S3SingleSpillShuffleMapOutputWriter.scala
@@ -5,14 +5,17 @@ package org.apache.spark.shuffle
 
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
 import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter
 import org.apache.spark.shuffle.helper.{S3ShuffleDispatcher, S3ShuffleHelper}
 import org.apache.spark.storage.ShuffleDataBlockId
 import org.apache.spark.util.Utils
 
 import java.io.{File, FileInputStream}
+import java.nio.file.{Files, Path}
 
-class S3SingleSpillShuffleMapOutputWriter(shuffleId: Int, mapId: Long) extends SingleSpillShuffleMapOutputWriter {
+class S3SingleSpillShuffleMapOutputWriter(shuffleId: Int, mapId: Long) extends SingleSpillShuffleMapOutputWriter with Logging {
 
   private lazy val dispatcher = S3ShuffleDispatcher.get
 
@@ -21,12 +24,34 @@ class S3SingleSpillShuffleMapOutputWriter(shuffleId: Int, mapId: Long) extends S
       partitionLengths: Array[Long],
       checksums: Array[Long]
   ): Unit = {
-    val in = new FileInputStream(mapSpillFile)
     val block = ShuffleDataBlockId(shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
-    val out = new S3MeasureOutputStream(dispatcher.createBlock(block), block.name)
-    // Note: HDFS does not exposed a nio-buffered write interface.
-    Utils.copyStream(in, out, closeStreams = true)
+    if (dispatcher.rootIsLocal) {
+      // Use NIO to move the file if the folder is local.
+      val now = System.nanoTime()
+      val path = dispatcher.getPath(block)
+      val fileDestination = path.toUri.getRawPath
+      val dir = path.getParent
+      if (!dispatcher.fs.exists(dir)) {
+        dispatcher.fs.mkdirs(dir)
+      }
+      Files.move(mapSpillFile.toPath, Path.of(fileDestination))
+      val timings = System.nanoTime() - now
+
+      val bytes = partitionLengths.sum
+      val tc = TaskContext.get()
+      val sId = tc.stageId()
+      val sAt = tc.stageAttemptNumber()
+      val t = timings / 1000000
+      val bw = bytes.toDouble / (t.toDouble / 1000) / (1024 * 1024)
+      logInfo(s"Statistics: Stage ${sId}.${sAt} TID ${tc.taskAttemptId()} -- " +
+        s"Writing ${block.name} ${bytes} took ${t} ms (${bw} MiB/s)")
+    } else {
+      // Copy using a stream.
+      val in = new FileInputStream(mapSpillFile)
+      val out = new S3MeasureOutputStream(dispatcher.createBlock(block), block.name)
+      Utils.copyStream(in, out, closeStreams = true)
+    }
 
     if (dispatcher.checksumEnabled) {
       S3ShuffleHelper.writeChecksum(shuffleId, mapId, checksums)
diff --git a/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala b/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala
index 38824d9..433bc0f 100644
--- a/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala
+++ b/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala
@@ -46,6 +46,7 @@ class S3ShuffleDispatcher extends Logging {
   }
   private val rootDir_ = if (useSparkShuffleFetch) fallbackStoragePath else conf.get("spark.shuffle.s3.rootDir", defaultValue = "sparkS3shuffle/")
   val rootDir: String = if (rootDir_.endsWith("/")) rootDir_ else rootDir_ + "/"
+  val rootIsLocal: Boolean = URI.create(rootDir).getScheme == "file"
 
   // Optional
   val bufferSize: Int = conf.getInt("spark.shuffle.s3.bufferSize", defaultValue = 8 * 1024 * 1024)
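
For context, the heart of this patch is the branch between a rename-style NIO move (when the shuffle root URI has the "file" scheme) and the existing streamed copy. Below is a minimal, self-contained sketch of that pattern; rootIsLocal mirrors the patch, while LocalMoveSketch and copyViaStream are hypothetical names used only for illustration:

import java.io.File
import java.net.URI
import java.nio.file.{Files, Path, StandardCopyOption}

object LocalMoveSketch {
  // Mirrors the patch: a root URI with the "file" scheme means the shuffle
  // folder lives on the local filesystem, so a rename-style move is possible.
  def rootIsLocal(rootDir: String): Boolean =
    URI.create(rootDir).getScheme == "file"

  // Hypothetical stand-in for the plugin's stream-copy fallback (S3/HDFS roots).
  def copyViaStream(spill: File, destination: Path): Unit = {
    Files.createDirectories(destination.getParent)
    Files.copy(spill.toPath, destination, StandardCopyOption.REPLACE_EXISTING)
  }

  def transfer(rootDir: String, spill: File, destination: Path): Unit = {
    if (rootIsLocal(rootDir)) {
      // Local root: Files.move within the same filesystem is a rename, so
      // the spill file's bytes are never re-read or re-written.
      Files.createDirectories(destination.getParent)
      Files.move(spill.toPath, destination)
    } else {
      copyViaStream(spill, destination)
    }
  }
}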
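
One side note on the statistics added by the patch: t is the elapsed time truncated to whole milliseconds, so a sub-millisecond move yields t = 0 and the throughput prints as Infinity MiB/s; guarding with something like math.max(t, 1) would keep the log line meaningful.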