Skip to content

Commit

Permalink
SNOW-838143 Support PKCS#8 RSA Private Key (#39)
Browse files Browse the repository at this point in the history
* support pkcs#1 private key

* also report pkcs8 error message

* update comment
  • Loading branch information
sfc-gh-bli authored Jun 15, 2023
1 parent 016dfa9 commit e273f3c
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 32 deletions.
77 changes: 47 additions & 30 deletions src/main/scala/com/snowflake/snowpark/internal/ParameterUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package com.snowflake.snowpark.internal

import net.snowflake.client.core.SFSessionProperty

import java.security.spec.RSAPrivateCrtKeySpec
import java.security.spec.{PKCS8EncodedKeySpec, RSAPrivateCrtKeySpec}
import java.security.{GeneralSecurityException, KeyFactory, PrivateKey}
import java.util.Properties
import org.apache.commons.codec.binary.Base64
Expand Down Expand Up @@ -104,38 +104,55 @@ private[snowpark] object ParameterUtils extends Logging {
}

private[snowpark] def parsePrivateKey(key: String): PrivateKey = {
// try to parse pkcs#8 format first,
// if it fails, then try to parse pkcs#1 format.
try {
val decoded = Base64.decodeBase64(key)
val derReader = new DerInputStream(decoded)
val seq = derReader.getSequence(0)

if (seq.length < 9) {
throw new GeneralSecurityException("Could not parse a PKCS1 private key.")
}

// seq(0) is version, skip
val modulus = seq(1).getBigInteger
val publicExp = seq(2).getBigInteger
val privateExp = seq(3).getBigInteger
val prime1 = seq(4).getBigInteger
val prime2 = seq(5).getBigInteger
val exp1 = seq(6).getBigInteger
val exp2 = seq(7).getBigInteger
val crtCoef = seq(8).getBigInteger
val keySpec = new RSAPrivateCrtKeySpec(
modulus,
publicExp,
privateExp,
prime1,
prime2,
exp1,
exp2,
crtCoef)
val keyFactory = KeyFactory.getInstance("RSA")
keyFactory.generatePrivate(keySpec)
val kf = KeyFactory.getInstance("RSA")
val keySpec = new PKCS8EncodedKeySpec(decoded)
kf.generatePrivate(keySpec)
} catch {
case e: Exception =>
throw ErrorMessage.MISC_INVALID_RSA_PRIVATE_KEY(e.getMessage)
case pkcs8Exception: Exception =>
// try to read PKCS#1 key
try {
val decoded = Base64.decodeBase64(key)
val derReader = new DerInputStream(decoded)
val seq = derReader.getSequence(0)

if (seq.length < 9) {
throw new GeneralSecurityException("Could not parse a PKCS1 private key.")
}

// seq(0) is version, skip
val modulus = seq(1).getBigInteger
val publicExp = seq(2).getBigInteger
val privateExp = seq(3).getBigInteger
val prime1 = seq(4).getBigInteger
val prime2 = seq(5).getBigInteger
val exp1 = seq(6).getBigInteger
val exp2 = seq(7).getBigInteger
val crtCoef = seq(8).getBigInteger
val keySpec = new RSAPrivateCrtKeySpec(
modulus,
publicExp,
privateExp,
prime1,
prime2,
exp1,
exp2,
crtCoef)
val keyFactory = KeyFactory.getInstance("RSA")
keyFactory.generatePrivate(keySpec)
} catch {
case pkcs1Exception: Exception =>
val errorMessage =
s"""Failed to parse PKCS#8 RSA Private key
|${pkcs8Exception.getMessage}
|Failed to parse PKCS#1 RSA Private key
|${pkcs1Exception.getMessage}
|""".stripMargin
throw ErrorMessage.MISC_INVALID_RSA_PRIVATE_KEY(errorMessage)
}
}
}

Expand Down
26 changes: 24 additions & 2 deletions src/test/scala/com/snowflake/snowpark/ParameterSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ package com.snowflake.snowpark
import com.snowflake.snowpark.internal.{ParameterUtils, ServerConnection}
import net.snowflake.client.core.SFSessionProperty

import java.security.KeyPairGenerator
import java.security.spec.PKCS8EncodedKeySpec
import java.util.Base64

class ParameterSuite extends SNTestBase {

val options: Map[String, String] = Session.loadConfFromFile(defaultProfile)
Expand Down Expand Up @@ -61,8 +65,26 @@ class ParameterSuite extends SNTestBase {
// scalastyle:on
}

assertThrows[Exception](
val ex = intercept[SnowparkClientException] {
ParameterUtils
.jdbcConfig(optionWithoutKey + ("privatekey" -> "wrong key"), isScalaAPI = true))
.jdbcConfig(optionWithoutKey + ("privatekey" -> "wrong key"), isScalaAPI = true)
}
assert(ex.message.contains("Failed to parse PKCS#8 RSA Private key"))
assert(ex.message.contains("Failed to parse PKCS#1 RSA Private key"))
}

test("enable to read PKCS#8 private keys") {
// no need to verify PKCS#1 format key additionally,
// since all Github Action tests use PKCS#1 key to authenticate with Snowflake server.
ParameterUtils.parsePrivateKey(generatePKCS8Key())
}

private def generatePKCS8Key(): String = {
val keyPairGenerator = KeyPairGenerator.getInstance("RSA")
keyPairGenerator.initialize(2048)
val keyPair = keyPairGenerator.generateKeyPair()
val privateKey = keyPair.getPrivate
val encodedKeySpec = new PKCS8EncodedKeySpec(privateKey.getEncoded)
Base64.getEncoder.encodeToString(encodedKeySpec.getEncoded)
}
}

0 comments on commit e273f3c

Please sign in to comment.