GitHub user vanzin commented on a diff in the pull request:

    https://github.com/apache/spark/pull/8880#discussion_r47144401
  
    --- Diff: yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleEncryptionSuite.scala ---
    @@ -0,0 +1,226 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.spark.deploy.yarn
    +
    +import java.io.{File, FileInputStream, IOException, InputStream, 
OutputStream}
    +import java.util.{ArrayList, LinkedList, List, Random, UUID}
    +import java.security.PrivilegedExceptionAction
    +
    +import scala.runtime.AbstractFunction1
    +
    +import com.google.common.collect.HashMultiset
    +import com.google.common.io.ByteStreams
    +import org.apache.hadoop.security.{Credentials, UserGroupInformation}
    +import org.junit.Assert.assertEquals
    +import org.mockito.Mock
    +import org.mockito.MockitoAnnotations
    +import org.mockito.invocation.InvocationOnMock
    +import org.mockito.stubbing.Answer
    +import org.mockito.Answers.RETURNS_SMART_NULLS
    +import org.mockito.Matchers.any
    +import org.mockito.Matchers.anyInt
    +import org.mockito.Mockito.doAnswer
    +import org.mockito.Mockito.when
    +import org.scalatest.{BeforeAndAfterAll, Matchers}
    +
    +import org.apache.spark._
    +import org.apache.spark.crypto.{CryptoConf, CryptoStreamUtils}
    +import org.apache.spark.deploy.SparkHadoopUtil
    +import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics}
    +import org.apache.spark.io.CompressionCodec
    +import org.apache.spark.memory.{TestMemoryManager, TaskMemoryManager}
    +import org.apache.spark.network.util.LimitedInputStream
    +import org.apache.spark.serializer.{KryoSerializer, SerializerInstance}
    +import org.apache.spark.shuffle.IndexShuffleBlockResolver
    +import org.apache.spark.shuffle.sort.{SerializedShuffleHandle, 
UnsafeShuffleWriter}
    +import org.apache.spark.storage.{BlockId, BlockManager, DiskBlockManager, 
DiskBlockObjectWriter,
    +  TempShuffleBlockId}
    +import org.apache.spark.util.Utils
    +
    +class YarnShuffleEncryptionSuite extends SparkFunSuite with Matchers with 
BeforeAndAfterAll{
    +  val NUM_PARTITITONS = 4
    +  val conf = new SparkConf()
    +  val tempDir = Utils.createTempDir("test", "test")
    +  val mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir)
    +  val memoryManager = new TestMemoryManager(conf)
    +  val taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
    +  val serializer = new KryoSerializer(new SparkConf())
    +  val hashPartitioner = new HashPartitioner(NUM_PARTITITONS)
    +  val taskMetrics = new TaskMetrics()
    +  val spillFilesCreated = new LinkedList[File]()
    +
    +  var partitionSizesInMergedFile: Array[Long] = null
    +  val ugi = UserGroupInformation.createUserForTesting("testuser", 
Array[String] {
    +    "testgroup"
    +  })
    +
    +  @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: 
BlockManager = _
    +  @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: 
DiskBlockManager = _
    +  @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext 
= _
    +  @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: 
IndexShuffleBlockResolver = _
    +  @Mock(
    +    answer = RETURNS_SMART_NULLS) private var shuffleDep: 
ShuffleDependency[Object, Object,
    +      Object] = _
    +
    +  override def beforeAll(): Unit = {
    +    MockitoAnnotations.initMocks(this)
    +    System.setProperty("SPARK_YARN_MODE", "true")
    +    ugi.doAs(new PrivilegedExceptionAction[Unit]() {
    +      override def run(): Unit = {
    +        initialShuffleWriter()
    +      }
    +    })
    +  }
    +
    +  override def afterAll(): Unit = {
    +    Utils.deleteRecursively(tempDir);
    +    val leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
    +    if (leakedMemory != 0) {
    +      fail("Test leaked " + leakedMemory + " bytes of managed memory");
    +    }
    +  }
    +
    +  test("Test Yarn shuffle encryption basic write") {
    +    val doAsAction = new PrivilegedExceptionAction[Unit] {
    +      override def run(): Unit = {
    +        testYarnShuffleEncryptionWrite()
    +      }
    +    }
    +    ugi.doAs(doAsAction)
    +  }
    +
    +  private[this] def testYarnShuffleEncryptionWrite(): Unit = {
    +    val dataToWrite = new ArrayList[Product2[Object, Object]]()
    +    for (i <- 1 to NUM_PARTITITONS) {
    +      val current = i - 1
    +      dataToWrite.add(
    +        new Tuple2[Object, Object](current.asInstanceOf[Object], 
current.asInstanceOf[Object]))
    +    }
    +    val writer = createWriter(true)
    +    writer.write(dataToWrite.iterator())
    +    writer.stop(true)
    +
    +    assertEquals(HashMultiset.create(dataToWrite), 
HashMultiset.create(readRecordsFromFile()))
    +  }
    +
    +  @throws(classOf[IOException])
    +  private[this] def createWriter(isShuffleEncrypted: Boolean): 
UnsafeShuffleWriter[Object,
    +      Object] = {
    +    conf.set("spark.shuffle.encryption.enabled", 
String.valueOf(isShuffleEncrypted))
    +    new UnsafeShuffleWriter[Object, Object](
    +      blockManager,
    +      blockResolver,
    +      taskMemoryManager,
    +      new SerializedShuffleHandle[Object, Object](0, 1, shuffleDep),
    +      0, // map id
    +      taskContext,
    +      conf
    +    )
    +  }
    +
    +  private[this] def readRecordsFromFile(): List[Tuple2[Object, Object]] = {
    +    val recordsList = new ArrayList[Tuple2[Object, Object]]()
    +    var startOffset = 0L
    +    for (i <- 1 to NUM_PARTITITONS) {
    +      val partitionSize = partitionSizesInMergedFile(i - 1)
    +      if (partitionSize > 0) {
    +        var in: InputStream = new FileInputStream(mergedOutputFile)
    +        ByteStreams.skipFully(in, startOffset)
    +        in = new LimitedInputStream(in, partitionSize)
    +        if (CryptoConf.isShuffleEncryptionEnabled(conf)) {
    +          in = CryptoStreamUtils.createCryptoInputStream(in, conf)
    +        }
    +        if (conf.getBoolean("spark.shuffle.compress", true)) {
    +          in = CompressionCodec.createCodec(conf).compressedInputStream(in)
    +        }
    +        val recordsStream = serializer.newInstance().deserializeStream(in)
    +        val records = recordsStream.asKeyValueIterator
    +        while (records.hasNext) {
    +          val record = records.next()
    +          assertEquals(i - 1, hashPartitioner.getPartition(record._1))
    +          recordsList.add(record.asInstanceOf[Tuple2[Object, Object]])
    +        }
    +        recordsStream.close()
    +        startOffset += partitionSize
    +      }
    +    }
    +    recordsList
    +  }
    +
    +  private[this] def initialShuffleWriter(): Unit = {
    +    val keys = new Array[Byte](16)
    +    new Random().nextBytes(keys)
    +    val creds = new Credentials()
    +    creds.addSecretKey(CryptoConf.SPARK_SHUFFLE_TOKEN, keys)
    +    SparkHadoopUtil.get.addCurrentUserCredentials(creds)
    +    when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
    +    when(blockManager.conf).thenReturn(conf)
    +    when(blockManager.getDiskWriter(any(classOf[BlockId]), 
any(classOf[File]),
    +      any(classOf[SerializerInstance]), anyInt, 
any(classOf[ShuffleWriteMetrics]))).thenAnswer(
    +          new Answer[DiskBlockObjectWriter]() {
    +            @throws(classOf[Throwable])
    +            override def answer(invocationOnMock: InvocationOnMock): 
DiskBlockObjectWriter = {
    +              val args = invocationOnMock.getArguments()
    +              new DiskBlockObjectWriter(args(1).asInstanceOf[File],
    +                args(2).asInstanceOf[SerializerInstance],
    +                args(3).asInstanceOf[Integer], new CompressStream(), false,
    +                args(4).asInstanceOf[ShuffleWriteMetrics], 
args(0).asInstanceOf[BlockId],
    +                conf)
    +            }
    +          })
    +
    +    when(blockResolver.getDataFile(anyInt(), 
anyInt())).thenReturn(mergedOutputFile);
    +    doAnswer(new Answer[Unit]() {
    +      @throws(classOf[Throwable])
    +      override def answer(invocationOnMock: InvocationOnMock): Unit = {
    +        partitionSizesInMergedFile = 
(invocationOnMock.getArguments()(2)).asInstanceOf[Array[Long]]
    +        val tmp = invocationOnMock.getArguments()(3)
    +        mergedOutputFile.delete()
    +        tmp.asInstanceOf[File].renameTo(mergedOutputFile)
    +      }
    +    }).when(blockResolver).writeIndexFileAndCommit(anyInt(), anyInt(), 
any(classOf[Array[Long]]),
    +          any(classOf[File]))
    +
    +    when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
    +      new Answer[Tuple2[TempShuffleBlockId, File]]() {
    +        @throws(classOf[Throwable])
    +        override def answer(invocationOnMock: InvocationOnMock): 
Tuple2[TempShuffleBlockId, File]
    --- End diff --
    
    Can you write `(TempShuffleBlockId, File)` instead of `Tuple2[TempShuffleBlockId, File]`? The tuple type syntax is shorthand for `Tuple2`, so the shorter form lets the `= {` fit on this line.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working,
please contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to