Add test and handle system upgrade

This commit is contained in:
Benoit Marty 2021-05-11 13:55:29 +02:00 committed by Benoit Marty
parent cef4cf09ec
commit 91be2b6f3f
4 changed files with 228 additions and 20 deletions

View File

@ -27,6 +27,7 @@ Bugfix 🐛:
- Properly clean the back stack if the user cancel registration when waiting for email validation - Properly clean the back stack if the user cancel registration when waiting for email validation
- Fix read marker visibility/position when filtering some events - Fix read marker visibility/position when filtering some events
- Fix user invitation in case of restricted profile api (#3306) - Fix user invitation in case of restricted profile api (#3306)
- Make sure the SDK can retrieve the secret storage if the system is upgraded (#3304)
Translations 🗣: Translations 🗣:
- -

View File

@ -0,0 +1,184 @@
/*
* Copyright (c) 2021 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.matrix.android.sdk.internal.session.securestorage
import android.os.Build
import androidx.test.ext.junit.runners.AndroidJUnit4
import org.amshove.kluent.shouldBeEqualTo
import org.junit.FixMethodOrder
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.MethodSorters
import org.matrix.android.sdk.InstrumentedTest
import org.matrix.android.sdk.internal.crypto.crosssigning.fromBase64
import org.matrix.android.sdk.internal.crypto.crosssigning.toBase64NoPadding
import java.io.ByteArrayOutputStream
import java.util.UUID
@RunWith(AndroidJUnit4::class)
@FixMethodOrder(MethodSorters.JVM)
class SecretStoringUtilsTest : InstrumentedTest {
private val buildVersionSdkIntProvider = TestBuildVersionSdkIntProvider()
private val secretStoringUtils = SecretStoringUtils(context(), buildVersionSdkIntProvider)
companion object {
const val TEST_STR = "This is something I want to store safely!"
}
@Test
fun testStringNominalCaseApi21() {
val alias = generateAlias()
buildVersionSdkIntProvider.value = Build.VERSION_CODES.LOLLIPOP
// Encrypt
val encrypted = secretStoringUtils.securelyStoreString(TEST_STR, alias)
// Decrypt
val decrypted = secretStoringUtils.loadSecureSecret(encrypted, alias)
decrypted shouldBeEqualTo TEST_STR
secretStoringUtils.safeDeleteKey(alias)
}
@Test
fun testStringNominalCaseApi23() {
val alias = generateAlias()
buildVersionSdkIntProvider.value = Build.VERSION_CODES.M
// Encrypt
val encrypted = secretStoringUtils.securelyStoreString(TEST_STR, alias)
// Decrypt
val decrypted = secretStoringUtils.loadSecureSecret(encrypted, alias)
decrypted shouldBeEqualTo TEST_STR
secretStoringUtils.safeDeleteKey(alias)
}
@Test
fun testStringNominalCaseApi30() {
val alias = generateAlias()
buildVersionSdkIntProvider.value = Build.VERSION_CODES.R
// Encrypt
val encrypted = secretStoringUtils.securelyStoreString(TEST_STR, alias)
// Decrypt
val decrypted = secretStoringUtils.loadSecureSecret(encrypted, alias)
decrypted shouldBeEqualTo TEST_STR
secretStoringUtils.safeDeleteKey(alias)
}
@Test
fun testStringMigration21_23() {
val alias = generateAlias()
buildVersionSdkIntProvider.value = Build.VERSION_CODES.LOLLIPOP
// Encrypt
val encrypted = secretStoringUtils.securelyStoreString(TEST_STR, alias)
// Simulate a system upgrade
buildVersionSdkIntProvider.value = Build.VERSION_CODES.M
// Decrypt
val decrypted = secretStoringUtils.loadSecureSecret(encrypted, alias)
decrypted shouldBeEqualTo TEST_STR
secretStoringUtils.safeDeleteKey(alias)
}
@Test
fun testObjectNominalCaseApi21() {
val alias = generateAlias()
buildVersionSdkIntProvider.value = Build.VERSION_CODES.LOLLIPOP
// Encrypt
val encrypted = ByteArrayOutputStream().also { outputStream ->
outputStream.use {
secretStoringUtils.securelyStoreObject(TEST_STR, alias, it)
}
}
.toByteArray()
.toBase64NoPadding()
// Decrypt
val decrypted = encrypted.fromBase64().inputStream().use {
secretStoringUtils.loadSecureSecret<String>(it, alias)
}
decrypted shouldBeEqualTo TEST_STR
secretStoringUtils.safeDeleteKey(alias)
}
@Test
fun testObjectNominalCaseApi23() {
val alias = generateAlias()
buildVersionSdkIntProvider.value = Build.VERSION_CODES.M
// Encrypt
val encrypted = ByteArrayOutputStream().also { outputStream ->
outputStream.use {
secretStoringUtils.securelyStoreObject(TEST_STR, alias, it)
}
}
.toByteArray()
.toBase64NoPadding()
// Decrypt
val decrypted = encrypted.fromBase64().inputStream().use {
secretStoringUtils.loadSecureSecret<String>(it, alias)
}
decrypted shouldBeEqualTo TEST_STR
secretStoringUtils.safeDeleteKey(alias)
}
@Test
fun testObjectNominalCaseApi30() {
val alias = generateAlias()
buildVersionSdkIntProvider.value = Build.VERSION_CODES.R
// Encrypt
val encrypted = ByteArrayOutputStream().also { outputStream ->
outputStream.use {
secretStoringUtils.securelyStoreObject(TEST_STR, alias, it)
}
}
.toByteArray()
.toBase64NoPadding()
// Decrypt
val decrypted = encrypted.fromBase64().inputStream().use {
secretStoringUtils.loadSecureSecret<String>(it, alias)
}
decrypted shouldBeEqualTo TEST_STR
secretStoringUtils.safeDeleteKey(alias)
}
@Test
fun testObjectMigration21_23() {
val alias = generateAlias()
buildVersionSdkIntProvider.value = Build.VERSION_CODES.LOLLIPOP
// Encrypt
val encrypted = ByteArrayOutputStream().also { outputStream ->
outputStream.use {
secretStoringUtils.securelyStoreObject(TEST_STR, alias, it)
}
}
.toByteArray()
.toBase64NoPadding()
// Simulate a system upgrade
buildVersionSdkIntProvider.value = Build.VERSION_CODES.M
// Decrypt
val decrypted = encrypted.fromBase64().inputStream().use {
secretStoringUtils.loadSecureSecret<String>(it, alias)
}
decrypted shouldBeEqualTo TEST_STR
secretStoringUtils.safeDeleteKey(alias)
}
private fun generateAlias() = UUID.randomUUID().toString()
}

View File

@ -0,0 +1,25 @@
/*
* Copyright (c) 2021 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.matrix.android.sdk.internal.session.securestorage
import org.matrix.android.sdk.internal.util.system.BuildVersionSdkIntProvider
class TestBuildVersionSdkIntProvider : BuildVersionSdkIntProvider {
var value: Int = 0
override fun get() = value
}

View File

@ -34,6 +34,7 @@ import java.io.InputStream
import java.io.ObjectInputStream import java.io.ObjectInputStream
import java.io.ObjectOutputStream import java.io.ObjectOutputStream
import java.io.OutputStream import java.io.OutputStream
import java.lang.IllegalArgumentException
import java.math.BigInteger import java.math.BigInteger
import java.security.KeyPairGenerator import java.security.KeyPairGenerator
import java.security.KeyStore import java.security.KeyStore
@ -134,9 +135,13 @@ internal class SecretStoringUtils @Inject constructor(
@SuppressLint("NewApi") @SuppressLint("NewApi")
@Throws(Exception::class) @Throws(Exception::class)
fun loadSecureSecret(encrypted: ByteArray, keyAlias: String): String { fun loadSecureSecret(encrypted: ByteArray, keyAlias: String): String {
return when { encrypted.inputStream().use { inputStream ->
buildVersionSdkIntProvider.get() >= Build.VERSION_CODES.M -> decryptStringM(encrypted, keyAlias) // First get the format
else -> decryptString(encrypted, keyAlias) return when (val format = inputStream.read().toByte()) {
FORMAT_API_M -> decryptStringM(inputStream, keyAlias)
FORMAT_1 -> decryptString(inputStream, keyAlias)
else -> throw IllegalArgumentException("Unknown format $format")
}
} }
} }
@ -150,9 +155,11 @@ internal class SecretStoringUtils @Inject constructor(
@SuppressLint("NewApi") @SuppressLint("NewApi")
fun <T> loadSecureSecret(inputStream: InputStream, keyAlias: String): T? { fun <T> loadSecureSecret(inputStream: InputStream, keyAlias: String): T? {
return when { // First get the format
buildVersionSdkIntProvider.get() >= Build.VERSION_CODES.M -> loadSecureObjectM(keyAlias, inputStream) return when (val format = inputStream.read().toByte()) {
else -> loadSecureObject(keyAlias, inputStream) FORMAT_API_M -> loadSecureObjectM(keyAlias, inputStream)
FORMAT_1 -> loadSecureObject(keyAlias, inputStream)
else -> throw IllegalArgumentException("Unknown format $format")
} }
} }
@ -196,7 +203,7 @@ internal class SecretStoringUtils @Inject constructor(
.setAlias(alias) .setAlias(alias)
.setSubject(X500Principal("CN=$alias")) .setSubject(X500Principal("CN=$alias"))
.setSerialNumber(BigInteger.TEN) .setSerialNumber(BigInteger.TEN)
// .setEncryptionRequired() requires that the phone as a pin/schema // .setEncryptionRequired() requires that the phone has a pin/schema
.setStartDate(start.time) .setStartDate(start.time)
.setEndDate(end.time) .setEndDate(end.time)
.build() .build()
@ -220,8 +227,8 @@ internal class SecretStoringUtils @Inject constructor(
} }
@RequiresApi(Build.VERSION_CODES.M) @RequiresApi(Build.VERSION_CODES.M)
private fun decryptStringM(encryptedChunk: ByteArray, keyAlias: String): String { private fun decryptStringM(inputStream: InputStream, keyAlias: String): String {
val (iv, encryptedText) = formatMExtract(encryptedChunk.inputStream()) val (iv, encryptedText) = formatMExtract(inputStream)
val secretKey = getOrGenerateSymmetricKeyForAliasM(keyAlias) val secretKey = getOrGenerateSymmetricKeyForAliasM(keyAlias)
@ -249,8 +256,8 @@ internal class SecretStoringUtils @Inject constructor(
return format1Make(encryptedKey, iv, encryptedBytes) return format1Make(encryptedKey, iv, encryptedBytes)
} }
private fun decryptString(data: ByteArray, keyAlias: String): String { private fun decryptString(inputStream: InputStream, keyAlias: String): String {
val (encryptedKey, iv, encrypted) = format1Extract(ByteArrayInputStream(data)) val (encryptedKey, iv, encrypted) = format1Extract(inputStream)
// we need to decrypt the key // we need to decrypt the key
val sKeyBytes = rsaDecrypt(keyAlias, ByteArrayInputStream(encryptedKey)) val sKeyBytes = rsaDecrypt(keyAlias, ByteArrayInputStream(encryptedKey))
@ -315,9 +322,6 @@ internal class SecretStoringUtils @Inject constructor(
private fun <T> loadSecureObjectM(keyAlias: String, inputStream: InputStream): T? { private fun <T> loadSecureObjectM(keyAlias: String, inputStream: InputStream): T? {
val secretKey = getOrGenerateSymmetricKeyForAliasM(keyAlias) val secretKey = getOrGenerateSymmetricKeyForAliasM(keyAlias)
val format = inputStream.read()
assert(format.toByte() == FORMAT_API_M)
val ivSize = inputStream.read() val ivSize = inputStream.read()
val iv = ByteArray(ivSize) val iv = ByteArray(ivSize)
inputStream.read(iv, 0, ivSize) inputStream.read(iv, 0, ivSize)
@ -380,9 +384,6 @@ internal class SecretStoringUtils @Inject constructor(
} }
private fun formatMExtract(bis: InputStream): Pair<ByteArray, ByteArray> { private fun formatMExtract(bis: InputStream): Pair<ByteArray, ByteArray> {
val format = bis.read().toByte()
assert(format == FORMAT_API_M)
val ivSize = bis.read() val ivSize = bis.read()
val iv = ByteArray(ivSize) val iv = ByteArray(ivSize)
bis.read(iv, 0, ivSize) bis.read(iv, 0, ivSize)
@ -401,9 +402,6 @@ internal class SecretStoringUtils @Inject constructor(
} }
private fun format1Extract(bis: InputStream): Triple<ByteArray, ByteArray, ByteArray> { private fun format1Extract(bis: InputStream): Triple<ByteArray, ByteArray, ByteArray> {
val format = bis.read()
assert(format.toByte() == FORMAT_1)
val keySizeBig = bis.read() val keySizeBig = bis.read()
val keySizeLow = bis.read() val keySizeLow = bis.read()
val encryptedKeySize = keySizeBig.shl(8) + keySizeLow val encryptedKeySize = keySizeBig.shl(8) + keySizeLow