From 7878da25b88a3757db52c6fa28b0cbb9bbeb26ad Mon Sep 17 00:00:00 2001 From: Benoit Marty Date: Fri, 14 Feb 2020 18:18:10 +0100 Subject: [PATCH] Use doSync<>() to reduce boilerplate and add test checks and more cleanup --- .../internal/crypto/ssss/QuadSTests.kt | 280 +++++++----------- 1 file changed, 114 insertions(+), 166 deletions(-) diff --git a/matrix-sdk-android/src/androidTest/java/im/vector/matrix/android/internal/crypto/ssss/QuadSTests.kt b/matrix-sdk-android/src/androidTest/java/im/vector/matrix/android/internal/crypto/ssss/QuadSTests.kt index 57c065aaf8..82594f2645 100644 --- a/matrix-sdk-android/src/androidTest/java/im/vector/matrix/android/internal/crypto/ssss/QuadSTests.kt +++ b/matrix-sdk-android/src/androidTest/java/im/vector/matrix/android/internal/crypto/ssss/QuadSTests.kt @@ -40,8 +40,10 @@ import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.GlobalScope import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking -import org.junit.Assert -import org.junit.Assert.fail +import org.amshove.kluent.shouldBe +import org.junit.Assert.assertEquals +import org.junit.Assert.assertNotNull +import org.junit.Assert.assertNull import org.junit.FixMethodOrder import org.junit.Test import org.junit.runner.RunWith @@ -64,28 +66,13 @@ class QuadSTests : InstrumentedTest { fun test_Generate4SKey() { val aliceSession = mTestHelper.createAccount(TestConstants.USER_ALICE, SessionTestParams(true)) - val aliceLatch = CountDownLatch(1) - val quadS = aliceSession.sharedSecretStorageService - var recoveryKey: String? = null - val TEST_KEY_ID = "my.test.Key" - quadS.generateKey(TEST_KEY_ID, "Test Key", emptyKeySigner, - object : MatrixCallback { - override fun onSuccess(data: SsssKeyCreationInfo) { - recoveryKey = data.recoveryKey - aliceLatch.countDown() - } - - override fun onFailure(failure: Throwable) { - Assert.fail("onFailure " + failure.localizedMessage) - aliceLatch.countDown() - } - }) - - mTestHelper.await(aliceLatch) + val ssssKeyCreationInfo = mTestHelper.doSync { + quadS.generateKey(TEST_KEY_ID, "Test Key", emptyKeySigner, it) + } // Assert Account data is updated val accountDataLock = CountDownLatch(1) @@ -104,18 +91,18 @@ class QuadSTests : InstrumentedTest { mTestHelper.await(accountDataLock) - Assert.assertNotNull("Key should be stored in account data", accountData) + assertNotNull("Key should be stored in account data", accountData) val parsed = SecretStorageKeyContent.fromJson(accountData!!.content) - Assert.assertNotNull("Key Content cannot be parsed", parsed) - Assert.assertEquals("Unexpected Algorithm", SSSS_ALGORITHM_CURVE25519_AES_SHA2, parsed!!.algorithm) - Assert.assertEquals("Unexpected key name", "Test Key", parsed.name) - Assert.assertNull("Key was not generated from passphrase", parsed.passphrase) - Assert.assertNotNull("Pubkey should be defined", parsed.publicKey) + assertNotNull("Key Content cannot be parsed", parsed) + assertEquals("Unexpected Algorithm", SSSS_ALGORITHM_CURVE25519_AES_SHA2, parsed!!.algorithm) + assertEquals("Unexpected key name", "Test Key", parsed.name) + assertNull("Key was not generated from passphrase", parsed.passphrase) + assertNotNull("Pubkey should be defined", parsed.publicKey) - val privateKeySpec = Curve25519AesSha2KeySpec.fromRecoveryKey(recoveryKey!!) + val privateKeySpec = Curve25519AesSha2KeySpec.fromRecoveryKey(ssssKeyCreationInfo.recoveryKey) DefaultSharedSecretStorageService.withOlmDecryption { olmPkDecryption -> val pubKey = olmPkDecryption.setPrivateKey(privateKeySpec!!.privateKey) - Assert.assertEquals("Unexpected Public Key", pubKey, parsed.publicKey) + assertEquals("Unexpected Public Key", pubKey, parsed.publicKey) } // Set as default key @@ -137,8 +124,8 @@ class QuadSTests : InstrumentedTest { mTestHelper.await(defaultDataLock) - Assert.assertNotNull(defaultKeyAccountData?.content) - Assert.assertEquals("Unexpected default key ${defaultKeyAccountData?.content}", TEST_KEY_ID, defaultKeyAccountData?.content?.get("key")) + assertNotNull(defaultKeyAccountData?.content) + assertEquals("Unexpected default key ${defaultKeyAccountData?.content}", TEST_KEY_ID, defaultKeyAccountData?.content?.get("key")) mTestHelper.signout(aliceSession) } @@ -150,52 +137,40 @@ class QuadSTests : InstrumentedTest { val info = generatedSecret(aliceSession, keyId, true) // Store a secret - - val storeCountDownLatch = CountDownLatch(1) val clearSecret = Base64.encodeToString("42".toByteArray(), Base64.NO_PADDING or Base64.NO_WRAP) - aliceSession.sharedSecretStorageService.storeSecret( - "secret.of.life", - clearSecret, - null, // default key - TestMatrixCallback(storeCountDownLatch) - ) + mTestHelper.doSync { + aliceSession.sharedSecretStorageService.storeSecret( + "secret.of.life", + clearSecret, + null, // default key + it + ) + } val secretAccountData = assertAccountData(aliceSession, "secret.of.life") val encryptedContent = secretAccountData.content.get("encrypted") as? Map<*, *> - Assert.assertNotNull("Element should be encrypted", encryptedContent) - Assert.assertNotNull("Secret should be encrypted with default key", encryptedContent?.get(keyId)) + assertNotNull("Element should be encrypted", encryptedContent) + assertNotNull("Secret should be encrypted with default key", encryptedContent?.get(keyId)) val secret = EncryptedSecretContent.fromJson(encryptedContent?.get(keyId)) - Assert.assertNotNull(secret?.ciphertext) - Assert.assertNotNull(secret?.mac) - Assert.assertNotNull(secret?.ephemeral) + assertNotNull(secret?.ciphertext) + assertNotNull(secret?.mac) + assertNotNull(secret?.ephemeral) // Try to decrypt?? val keySpec = Curve25519AesSha2KeySpec.fromRecoveryKey(info.recoveryKey) - var decryptedSecret: String? = null + val decryptedSecret = mTestHelper.doSync { + aliceSession.sharedSecretStorageService.getSecret("secret.of.life", + null, // default key + keySpec!!, + it + ) + } - val decryptCountDownLatch = CountDownLatch(1) - aliceSession.sharedSecretStorageService.getSecret("secret.of.life", - null, // default key - keySpec!!, - object : MatrixCallback { - override fun onFailure(failure: Throwable) { - fail("Fail to decrypt -> " + failure.localizedMessage) - decryptCountDownLatch.countDown() - } - - override fun onSuccess(data: String) { - decryptedSecret = data - decryptCountDownLatch.countDown() - } - } - ) - mTestHelper.await(decryptCountDownLatch) - - Assert.assertEquals("Secret mismatch", clearSecret, decryptedSecret) + assertEquals("Secret mismatch", clearSecret, decryptedSecret) mTestHelper.signout(aliceSession) } @@ -207,16 +182,14 @@ class QuadSTests : InstrumentedTest { val TEST_KEY_ID = "my.test.Key" - val countDownLatch = CountDownLatch(1) - quadS.generateKey(TEST_KEY_ID, "Test Key", emptyKeySigner, - TestMatrixCallback(countDownLatch)) - - mTestHelper.await(countDownLatch) + mTestHelper.doSync { + quadS.generateKey(TEST_KEY_ID, "Test Key", emptyKeySigner, it) + } // Test that we don't need to wait for an account data sync to access directly the keyid from DB - val defaultLatch = CountDownLatch(1) - quadS.setDefaultKey(TEST_KEY_ID, TestMatrixCallback(defaultLatch)) - mTestHelper.await(defaultLatch) + mTestHelper.doSync { + quadS.setDefaultKey(TEST_KEY_ID, it) + } mTestHelper.signout(aliceSession) } @@ -231,38 +204,39 @@ class QuadSTests : InstrumentedTest { val mySecretText = "Lorem ipsum dolor sit amet, consectetur adipiscing elit" - val storeLatch = CountDownLatch(1) - aliceSession.sharedSecretStorageService.storeSecret( - "my.secret", - mySecretText.toByteArray().toBase64NoPadding(), - listOf(keyId1, keyId2), - TestMatrixCallback(storeLatch) - ) - mTestHelper.await(storeLatch) + mTestHelper.doSync { + aliceSession.sharedSecretStorageService.storeSecret( + "my.secret", + mySecretText.toByteArray().toBase64NoPadding(), + listOf(keyId1, keyId2), + it + ) + } val accountDataEvent = aliceSession.getAccountDataEvent("my.secret") val encryptedContent = accountDataEvent?.content?.get("encrypted") as? Map<*, *> - Assert.assertEquals("Content should contains two encryptions", 2, encryptedContent?.keys?.size ?: 0) + assertEquals("Content should contains two encryptions", 2, encryptedContent?.keys?.size ?: 0) - Assert.assertNotNull(encryptedContent?.get(keyId1)) - Assert.assertNotNull(encryptedContent?.get(keyId2)) + assertNotNull(encryptedContent?.get(keyId1)) + assertNotNull(encryptedContent?.get(keyId2)) // Assert that can decrypt with both keys - val decryptCountDownLatch = CountDownLatch(2) - aliceSession.sharedSecretStorageService.getSecret("my.secret", - keyId1, - Curve25519AesSha2KeySpec.fromRecoveryKey(key1Info.recoveryKey)!!, - TestMatrixCallback(decryptCountDownLatch) - ) + mTestHelper.doSync { + aliceSession.sharedSecretStorageService.getSecret("my.secret", + keyId1, + Curve25519AesSha2KeySpec.fromRecoveryKey(key1Info.recoveryKey)!!, + it + ) + } - aliceSession.sharedSecretStorageService.getSecret("my.secret", - keyId2, - Curve25519AesSha2KeySpec.fromRecoveryKey(key2Info.recoveryKey)!!, - TestMatrixCallback(decryptCountDownLatch) - ) - - mTestHelper.await(decryptCountDownLatch) + mTestHelper.doSync { + aliceSession.sharedSecretStorageService.getSecret("my.secret", + keyId2, + Curve25519AesSha2KeySpec.fromRecoveryKey(key2Info.recoveryKey)!!, + it + ) + } mTestHelper.signout(aliceSession) } @@ -276,16 +250,17 @@ class QuadSTests : InstrumentedTest { val mySecretText = "Lorem ipsum dolor sit amet, consectetur adipiscing elit" - val storeLatch = CountDownLatch(1) - aliceSession.sharedSecretStorageService.storeSecret( - "my.secret", - mySecretText.toByteArray().toBase64NoPadding(), - listOf(keyId1), - TestMatrixCallback(storeLatch) - ) - mTestHelper.await(storeLatch) + mTestHelper.doSync { + aliceSession.sharedSecretStorageService.storeSecret( + "my.secret", + mySecretText.toByteArray().toBase64NoPadding(), + listOf(keyId1), + it + ) + } - val decryptCountDownLatch = CountDownLatch(2) + val decryptCountDownLatch = CountDownLatch(1) + var error = false aliceSession.sharedSecretStorageService.getSecret("my.secret", keyId1, Curve25519AesSha2KeySpec.fromPassphrase( @@ -296,29 +271,32 @@ class QuadSTests : InstrumentedTest { object : MatrixCallback { override fun onSuccess(data: String) { decryptCountDownLatch.countDown() - fail("Should not be able to decrypt") } override fun onFailure(failure: Throwable) { - Assert.assertTrue(true) + error = true decryptCountDownLatch.countDown() } } ) - // Now try with correct key - aliceSession.sharedSecretStorageService.getSecret("my.secret", - keyId1, - Curve25519AesSha2KeySpec.fromPassphrase( - passphrase, - key1Info.content?.passphrase?.salt ?: "", - key1Info.content?.passphrase?.iterations ?: 0, - null), - TestMatrixCallback(decryptCountDownLatch) - ) - mTestHelper.await(decryptCountDownLatch) + error shouldBe true + + // Now try with correct key + mTestHelper.doSync { + aliceSession.sharedSecretStorageService.getSecret("my.secret", + keyId1, + Curve25519AesSha2KeySpec.fromPassphrase( + passphrase, + key1Info.content?.passphrase?.salt ?: "", + key1Info.content?.passphrase?.iterations ?: 0, + null), + it + ) + } + mTestHelper.signout(aliceSession) } @@ -338,7 +316,7 @@ class QuadSTests : InstrumentedTest { GlobalScope.launch(Dispatchers.Main) { liveAccountData.observeForever(accountDataObserver) } mTestHelper.await(accountDataLock) - Assert.assertNotNull("Account Data type:$type should be found", accountData) + assertNotNull("Account Data type:$type should be found", accountData) return accountData!! } @@ -346,64 +324,34 @@ class QuadSTests : InstrumentedTest { private fun generatedSecret(session: Session, keyId: String, asDefault: Boolean = true): SsssKeyCreationInfo { val quadS = session.sharedSecretStorageService - var creationInfo: SsssKeyCreationInfo? = null - - val generateLatch = CountDownLatch(1) - - quadS.generateKey(keyId, keyId, emptyKeySigner, - object : MatrixCallback { - override fun onSuccess(data: SsssKeyCreationInfo) { - creationInfo = data - generateLatch.countDown() - } - - override fun onFailure(failure: Throwable) { - Assert.fail("onFailure " + failure.localizedMessage) - generateLatch.countDown() - } - }) - - mTestHelper.await(generateLatch) - - Assert.assertNotNull(creationInfo) + val creationInfo = mTestHelper.doSync { + quadS.generateKey(keyId, keyId, emptyKeySigner, it) + } assertAccountData(session, "m.secret_storage.key.$keyId") + if (asDefault) { - val setDefaultLatch = CountDownLatch(1) - quadS.setDefaultKey(keyId, TestMatrixCallback(setDefaultLatch)) - mTestHelper.await(setDefaultLatch) + mTestHelper.doSync { + quadS.setDefaultKey(keyId, it) + } assertAccountData(session, DefaultSharedSecretStorageService.DEFAULT_KEY_ID) } - return creationInfo!! + return creationInfo } private fun generatedSecretFromPassphrase(session: Session, passphrase: String, keyId: String, asDefault: Boolean = true): SsssKeyCreationInfo { val quadS = session.sharedSecretStorageService - var creationInfo: SsssKeyCreationInfo? = null - - val generateLatch = CountDownLatch(1) - - quadS.generateKeyWithPassphrase(keyId, keyId, - passphrase, - emptyKeySigner, - null, - object : MatrixCallback { - override fun onSuccess(data: SsssKeyCreationInfo) { - creationInfo = data - generateLatch.countDown() - } - - override fun onFailure(failure: Throwable) { - Assert.fail("onFailure " + failure.localizedMessage) - generateLatch.countDown() - } - }) - - mTestHelper.await(generateLatch) - - Assert.assertNotNull(creationInfo) + val creationInfo = mTestHelper.doSync { + quadS.generateKeyWithPassphrase( + keyId, + keyId, + passphrase, + emptyKeySigner, + null, + it) + } assertAccountData(session, "m.secret_storage.key.$keyId") if (asDefault) { @@ -413,6 +361,6 @@ class QuadSTests : InstrumentedTest { assertAccountData(session, DefaultSharedSecretStorageService.DEFAULT_KEY_ID) } - return creationInfo!! + return creationInfo } }