Merge pull request #5380 from vector-im/feature/bca/crypto_fix_rolling_uisi

Refactoring for safer olm and megolm session usage
This commit is contained in:
Benoit Marty 2022-03-10 11:13:07 +01:00 committed by GitHub
commit ce4ad8819a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 1520 additions and 518 deletions

View File

@ -23,7 +23,7 @@ object TestConstants {
const val TESTS_HOME_SERVER_URL = "http://10.0.2.2:8080" const val TESTS_HOME_SERVER_URL = "http://10.0.2.2:8080"
// Time out to use when waiting for server response. // Time out to use when waiting for server response.
private const val AWAIT_TIME_OUT_MILLIS = 30_000 private const val AWAIT_TIME_OUT_MILLIS = 60_000
// Time out to use when waiting for server response, when the debugger is connected. 10 minutes // Time out to use when waiting for server response, when the debugger is connected. 10 minutes
private const val AWAIT_TIME_OUT_WITH_DEBUGGER_MILLIS = 10 * 60_000 private const val AWAIT_TIME_OUT_WITH_DEBUGGER_MILLIS = 10 * 60_000

View File

@ -0,0 +1,648 @@
/*
* Copyright 2022 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.matrix.android.sdk.internal.crypto
import android.util.Log
import androidx.test.filters.LargeTest
import kotlinx.coroutines.delay
import org.amshove.kluent.fail
import org.amshove.kluent.internal.assertEquals
import org.junit.Assert
import org.junit.FixMethodOrder
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.runners.MethodSorters
import org.matrix.android.sdk.InstrumentedTest
import org.matrix.android.sdk.api.session.Session
import org.matrix.android.sdk.api.session.crypto.MXCryptoError
import org.matrix.android.sdk.api.session.events.model.EventType
import org.matrix.android.sdk.api.session.events.model.toModel
import org.matrix.android.sdk.api.session.room.Room
import org.matrix.android.sdk.api.session.room.failure.JoinRoomFailure
import org.matrix.android.sdk.api.session.room.model.Membership
import org.matrix.android.sdk.api.session.room.model.message.MessageContent
import org.matrix.android.sdk.api.session.room.send.SendState
import org.matrix.android.sdk.api.session.room.timeline.TimelineSettings
import org.matrix.android.sdk.common.CommonTestHelper
import org.matrix.android.sdk.common.CryptoTestHelper
import org.matrix.android.sdk.common.SessionTestParams
import org.matrix.android.sdk.common.TestMatrixCallback
import org.matrix.android.sdk.internal.crypto.algorithms.olm.OlmDecryptionResult
import org.matrix.android.sdk.internal.crypto.keysbackup.model.MegolmBackupCreationInfo
import org.matrix.android.sdk.internal.crypto.keysbackup.model.rest.KeysVersion
import org.matrix.android.sdk.internal.crypto.keysbackup.model.rest.KeysVersionResult
import org.matrix.android.sdk.internal.crypto.model.ImportRoomKeysResult
import org.matrix.android.sdk.internal.crypto.model.event.EncryptedEventContent
@RunWith(JUnit4::class)
@FixMethodOrder(MethodSorters.JVM)
@LargeTest
class E2eeSanityTests : InstrumentedTest {
private val testHelper = CommonTestHelper(context())
private val cryptoTestHelper = CryptoTestHelper(testHelper)
/**
* Simple test that create an e2ee room.
* Some new members are added, and a message is sent.
* We check that the message is e2e and can be decrypted.
*
* Additional users join, we check that they can't decrypt history
*
* Alice sends a new message, then check that the new one can be decrypted
*/
@Test
fun testSendingE2EEMessages() {
val cryptoTestData = cryptoTestHelper.doE2ETestWithAliceAndBobInARoom(true)
val aliceSession = cryptoTestData.firstSession
val e2eRoomID = cryptoTestData.roomId
val aliceRoomPOV = aliceSession.getRoom(e2eRoomID)!!
// add some more users and invite them
val otherAccounts = listOf("benoit", "valere", "ganfra") // , "adam", "manu")
.map {
testHelper.createAccount(it, SessionTestParams(true))
}
Log.v("#E2E TEST", "All accounts created")
// we want to invite them in the room
otherAccounts.forEach {
testHelper.runBlockingTest {
Log.v("#E2E TEST", "Alice invites ${it.myUserId}")
aliceRoomPOV.invite(it.myUserId)
}
}
// All user should accept invite
otherAccounts.forEach { otherSession ->
waitForAndAcceptInviteInRoom(otherSession, e2eRoomID)
Log.v("#E2E TEST", "${otherSession.myUserId} joined room $e2eRoomID")
}
// check that alice see them as joined (not really necessary?)
ensureMembersHaveJoined(aliceSession, otherAccounts, e2eRoomID)
Log.v("#E2E TEST", "All users have joined the room")
Log.v("#E2E TEST", "Alice is sending the message")
val text = "This is my message"
val sentEventId: String? = sendMessageInRoom(aliceRoomPOV, text)
// val sentEvent = testHelper.sendTextMessage(aliceRoomPOV, "Hello all", 1).first()
Assert.assertTrue("Message should be sent", sentEventId != null)
// All should be able to decrypt
otherAccounts.forEach { otherSession ->
testHelper.waitWithLatch { latch ->
testHelper.retryPeriodicallyWithLatch(latch) {
val timeLineEvent = otherSession.getRoom(e2eRoomID)?.getTimeLineEvent(sentEventId!!)
timeLineEvent != null &&
timeLineEvent.isEncrypted() &&
timeLineEvent.root.getClearType() == EventType.MESSAGE
}
}
}
// Add a new user to the room, and check that he can't decrypt
val newAccount = listOf("adam") // , "adam", "manu")
.map {
testHelper.createAccount(it, SessionTestParams(true))
}
newAccount.forEach {
testHelper.runBlockingTest {
Log.v("#E2E TEST", "Alice invites ${it.myUserId}")
aliceRoomPOV.invite(it.myUserId)
}
}
newAccount.forEach {
waitForAndAcceptInviteInRoom(it, e2eRoomID)
}
ensureMembersHaveJoined(aliceSession, newAccount, e2eRoomID)
// wait a bit
testHelper.runBlockingTest {
delay(3_000)
}
// check that messages are encrypted (uisi)
newAccount.forEach { otherSession ->
testHelper.waitWithLatch { latch ->
testHelper.retryPeriodicallyWithLatch(latch) {
val timeLineEvent = otherSession.getRoom(e2eRoomID)?.getTimeLineEvent(sentEventId!!).also {
Log.v("#E2E TEST", "Event seen by new user ${it?.root?.getClearType()}|${it?.root?.mCryptoError}")
}
timeLineEvent != null &&
timeLineEvent.root.getClearType() == EventType.ENCRYPTED &&
timeLineEvent.root.mCryptoError == MXCryptoError.ErrorType.UNKNOWN_INBOUND_SESSION_ID
}
}
}
// Let alice send a new message
Log.v("#E2E TEST", "Alice sends a new message")
val secondMessage = "2 This is my message"
val secondSentEventId: String? = sendMessageInRoom(aliceRoomPOV, secondMessage)
// new members should be able to decrypt it
newAccount.forEach { otherSession ->
testHelper.waitWithLatch { latch ->
testHelper.retryPeriodicallyWithLatch(latch) {
val timeLineEvent = otherSession.getRoom(e2eRoomID)?.getTimeLineEvent(secondSentEventId!!).also {
Log.v("#E2E TEST", "Second Event seen by new user ${it?.root?.getClearType()}|${it?.root?.mCryptoError}")
}
timeLineEvent != null &&
timeLineEvent.root.getClearType() == EventType.MESSAGE &&
secondMessage.equals(timeLineEvent.root.getClearContent().toModel<MessageContent>()?.body)
}
}
}
otherAccounts.forEach {
testHelper.signOutAndClose(it)
}
newAccount.forEach { testHelper.signOutAndClose(it) }
cryptoTestData.cleanUp(testHelper)
}
/**
* Quick test for basic keybackup
* 1. Create e2e between Alice and Bob
* 2. Alice sends 3 messages, using 3 different sessions
* 3. Ensure bob can decrypt
* 4. Create backup for bob and uplaod keys
*
* 5. Sign out alice and bob to ensure no gossiping will happen
*
* 6. Let bob sign in with a new session
* 7. Ensure history is UISI
* 8. Import backup
* 9. Check that new session can decrypt
*/
@Test
fun testBasicBackupImport() {
val cryptoTestData = cryptoTestHelper.doE2ETestWithAliceAndBobInARoom(true)
val aliceSession = cryptoTestData.firstSession
val bobSession = cryptoTestData.secondSession!!
val e2eRoomID = cryptoTestData.roomId
Log.v("#E2E TEST", "Create and start keybackup for bob ...")
val keysBackupService = bobSession.cryptoService().keysBackupService()
val keyBackupPassword = "FooBarBaz"
val megolmBackupCreationInfo = testHelper.doSync<MegolmBackupCreationInfo> {
keysBackupService.prepareKeysBackupVersion(keyBackupPassword, null, it)
}
val version = testHelper.doSync<KeysVersion> {
keysBackupService.createKeysBackupVersion(megolmBackupCreationInfo, it)
}
Log.v("#E2E TEST", "... Key backup started and enabled for bob")
// Bob session should now have
val aliceRoomPOV = aliceSession.getRoom(e2eRoomID)!!
// let's send a few message to bob
val sentEventIds = mutableListOf<String>()
val messagesText = listOf("1. Hello", "2. Bob", "3. Good morning")
messagesText.forEach { text ->
val sentEventId = sendMessageInRoom(aliceRoomPOV, text)!!.also {
sentEventIds.add(it)
}
testHelper.waitWithLatch { latch ->
testHelper.retryPeriodicallyWithLatch(latch) {
val timeLineEvent = bobSession.getRoom(e2eRoomID)?.getTimeLineEvent(sentEventId)
timeLineEvent != null &&
timeLineEvent.isEncrypted() &&
timeLineEvent.root.getClearType() == EventType.MESSAGE
}
}
// we want more so let's discard the session
aliceSession.cryptoService().discardOutboundSession(e2eRoomID)
testHelper.runBlockingTest {
delay(1_000)
}
}
Log.v("#E2E TEST", "Bob received all and can decrypt")
// Let's wait a bit to be sure that bob has backed up the session
Log.v("#E2E TEST", "Force key backup for Bob...")
testHelper.waitWithLatch { latch ->
keysBackupService.backupAllGroupSessions(
null,
TestMatrixCallback(latch, true)
)
}
Log.v("#E2E TEST", "... Keybackup done for Bob")
// Now lets logout both alice and bob to ensure that we won't have any gossiping
val bobUserId = bobSession.myUserId
Log.v("#E2E TEST", "Logout alice and bob...")
testHelper.signOutAndClose(aliceSession)
testHelper.signOutAndClose(bobSession)
Log.v("#E2E TEST", "..Logout alice and bob...")
testHelper.runBlockingTest {
delay(1_000)
}
// Create a new session for bob
Log.v("#E2E TEST", "Create a new session for Bob")
val newBobSession = testHelper.logIntoAccount(bobUserId, SessionTestParams(true))
// check that bob can't currently decrypt
Log.v("#E2E TEST", "check that bob can't currently decrypt")
sentEventIds.forEach { sentEventId ->
testHelper.waitWithLatch { latch ->
testHelper.retryPeriodicallyWithLatch(latch) {
val timeLineEvent = newBobSession.getRoom(e2eRoomID)?.getTimeLineEvent(sentEventId)?.also {
Log.v("#E2E TEST", "Event seen by new user ${it.root.getClearType()}|${it.root.mCryptoError}")
}
timeLineEvent != null &&
timeLineEvent.root.getClearType() == EventType.ENCRYPTED
}
}
}
// after initial sync events are not decrypted, so we have to try manually
ensureCannotDecrypt(sentEventIds, newBobSession, e2eRoomID, MXCryptoError.ErrorType.UNKNOWN_INBOUND_SESSION_ID)
// Let's now import keys from backup
newBobSession.cryptoService().keysBackupService().let { keysBackupService ->
val keyVersionResult = testHelper.doSync<KeysVersionResult?> {
keysBackupService.getVersion(version.version, it)
}
val importedResult = testHelper.doSync<ImportRoomKeysResult> {
keysBackupService.restoreKeyBackupWithPassword(keyVersionResult!!,
keyBackupPassword,
null,
null,
null, it)
}
assertEquals(3, importedResult.totalNumberOfKeys)
}
// ensure bob can now decrypt
ensureCanDecrypt(sentEventIds, newBobSession, e2eRoomID, messagesText)
testHelper.signOutAndClose(newBobSession)
}
/**
* Check that a new verified session that was not supposed to get the keys initially will
* get them from an older one.
*/
@Test
fun testSimpleGossip() {
val cryptoTestData = cryptoTestHelper.doE2ETestWithAliceAndBobInARoom(true)
val aliceSession = cryptoTestData.firstSession
val bobSession = cryptoTestData.secondSession!!
val e2eRoomID = cryptoTestData.roomId
val aliceRoomPOV = aliceSession.getRoom(e2eRoomID)!!
cryptoTestHelper.initializeCrossSigning(bobSession)
// let's send a few message to bob
val sentEventIds = mutableListOf<String>()
val messagesText = listOf("1. Hello", "2. Bob")
Log.v("#E2E TEST", "Alice sends some messages")
messagesText.forEach { text ->
val sentEventId = sendMessageInRoom(aliceRoomPOV, text)!!.also {
sentEventIds.add(it)
}
testHelper.waitWithLatch { latch ->
testHelper.retryPeriodicallyWithLatch(latch) {
val timeLineEvent = bobSession.getRoom(e2eRoomID)?.getTimeLineEvent(sentEventId)
timeLineEvent != null &&
timeLineEvent.isEncrypted() &&
timeLineEvent.root.getClearType() == EventType.MESSAGE
}
}
}
// Ensure bob can decrypt
ensureIsDecrypted(sentEventIds, bobSession, e2eRoomID)
// Let's now add a new bob session
// Create a new session for bob
Log.v("#E2E TEST", "Create a new session for Bob")
val newBobSession = testHelper.logIntoAccount(bobSession.myUserId, SessionTestParams(true))
// check that new bob can't currently decrypt
Log.v("#E2E TEST", "check that new bob can't currently decrypt")
ensureCannotDecrypt(sentEventIds, newBobSession, e2eRoomID, MXCryptoError.ErrorType.UNKNOWN_INBOUND_SESSION_ID)
// Try to request
sentEventIds.forEach { sentEventId ->
val event = newBobSession.getRoom(e2eRoomID)!!.getTimeLineEvent(sentEventId)!!.root
newBobSession.cryptoService().requestRoomKeyForEvent(event)
}
// wait a bit
testHelper.runBlockingTest {
delay(10_000)
}
// Ensure that new bob still can't decrypt (keys must have been withheld)
ensureCannotDecrypt(sentEventIds, newBobSession, e2eRoomID, MXCryptoError.ErrorType.KEYS_WITHHELD)
// Now mark new bob session as verified
bobSession.cryptoService().verificationService().markedLocallyAsManuallyVerified(newBobSession.myUserId, newBobSession.sessionParams.deviceId!!)
newBobSession.cryptoService().verificationService().markedLocallyAsManuallyVerified(bobSession.myUserId, bobSession.sessionParams.deviceId!!)
// now let new session re-request
sentEventIds.forEach { sentEventId ->
val event = newBobSession.getRoom(e2eRoomID)!!.getTimeLineEvent(sentEventId)!!.root
newBobSession.cryptoService().reRequestRoomKeyForEvent(event)
}
// wait a bit
testHelper.runBlockingTest {
delay(10_000)
}
ensureCanDecrypt(sentEventIds, newBobSession, e2eRoomID, messagesText)
cryptoTestData.cleanUp(testHelper)
testHelper.signOutAndClose(newBobSession)
}
/**
* Test that if a better key is forwared (lower index, it is then used)
*/
@Test
fun testForwardBetterKey() {
val cryptoTestData = cryptoTestHelper.doE2ETestWithAliceAndBobInARoom(true)
val aliceSession = cryptoTestData.firstSession
val bobSessionWithBetterKey = cryptoTestData.secondSession!!
val e2eRoomID = cryptoTestData.roomId
val aliceRoomPOV = aliceSession.getRoom(e2eRoomID)!!
cryptoTestHelper.initializeCrossSigning(bobSessionWithBetterKey)
// let's send a few message to bob
var firstEventId: String
val firstMessage = "1. Hello"
Log.v("#E2E TEST", "Alice sends some messages")
firstMessage.let { text ->
firstEventId = sendMessageInRoom(aliceRoomPOV, text)!!
testHelper.waitWithLatch { latch ->
testHelper.retryPeriodicallyWithLatch(latch) {
val timeLineEvent = bobSessionWithBetterKey.getRoom(e2eRoomID)?.getTimeLineEvent(firstEventId)
timeLineEvent != null &&
timeLineEvent.isEncrypted() &&
timeLineEvent.root.getClearType() == EventType.MESSAGE
}
}
}
// Ensure bob can decrypt
ensureIsDecrypted(listOf(firstEventId), bobSessionWithBetterKey, e2eRoomID)
// Let's add a new unverified session from bob
val newBobSession = testHelper.logIntoAccount(bobSessionWithBetterKey.myUserId, SessionTestParams(true))
// check that new bob can't currently decrypt
Log.v("#E2E TEST", "check that new bob can't currently decrypt")
ensureCannotDecrypt(listOf(firstEventId), newBobSession, e2eRoomID, MXCryptoError.ErrorType.UNKNOWN_INBOUND_SESSION_ID)
// Now let alice send a new message. this time the new bob session will be able to decrypt
var secondEventId: String
val secondMessage = "2. New Device?"
Log.v("#E2E TEST", "Alice sends some messages")
secondMessage.let { text ->
secondEventId = sendMessageInRoom(aliceRoomPOV, text)!!
testHelper.waitWithLatch { latch ->
testHelper.retryPeriodicallyWithLatch(latch) {
val timeLineEvent = newBobSession.getRoom(e2eRoomID)?.getTimeLineEvent(secondEventId)
timeLineEvent != null &&
timeLineEvent.isEncrypted() &&
timeLineEvent.root.getClearType() == EventType.MESSAGE
}
}
}
// check that both messages have same sessionId (it's just that we don't have index 0)
val firstEventNewBobPov = newBobSession.getRoom(e2eRoomID)?.getTimeLineEvent(firstEventId)
val secondEventNewBobPov = newBobSession.getRoom(e2eRoomID)?.getTimeLineEvent(secondEventId)
val firstSessionId = firstEventNewBobPov!!.root.content.toModel<EncryptedEventContent>()!!.sessionId!!
val secondSessionId = secondEventNewBobPov!!.root.content.toModel<EncryptedEventContent>()!!.sessionId!!
Assert.assertTrue("Should be the same session id", firstSessionId == secondSessionId)
// Confirm we can decrypt one but not the other
testHelper.runBlockingTest {
try {
newBobSession.cryptoService().decryptEvent(firstEventNewBobPov.root, "")
fail("Should not be able to decrypt event")
} catch (error: MXCryptoError) {
val errorType = (error as? MXCryptoError.Base)?.errorType
assertEquals(MXCryptoError.ErrorType.UNKNOWN_MESSAGE_INDEX, errorType)
}
}
testHelper.runBlockingTest {
try {
newBobSession.cryptoService().decryptEvent(secondEventNewBobPov.root, "")
} catch (error: MXCryptoError) {
fail("Should be able to decrypt event")
}
}
// Now let's verify bobs session, and re-request keys
bobSessionWithBetterKey.cryptoService()
.verificationService()
.markedLocallyAsManuallyVerified(newBobSession.myUserId, newBobSession.sessionParams.deviceId!!)
newBobSession.cryptoService()
.verificationService()
.markedLocallyAsManuallyVerified(bobSessionWithBetterKey.myUserId, bobSessionWithBetterKey.sessionParams.deviceId!!)
// now let new session request
newBobSession.cryptoService().requestRoomKeyForEvent(firstEventNewBobPov.root)
// wait a bit
testHelper.runBlockingTest {
delay(10_000)
}
// old session should have shared the key at earliest known index now
// we should be able to decrypt both
testHelper.runBlockingTest {
try {
newBobSession.cryptoService().decryptEvent(firstEventNewBobPov.root, "")
} catch (error: MXCryptoError) {
fail("Should be able to decrypt first event now $error")
}
}
testHelper.runBlockingTest {
try {
newBobSession.cryptoService().decryptEvent(secondEventNewBobPov.root, "")
} catch (error: MXCryptoError) {
fail("Should be able to decrypt event $error")
}
}
cryptoTestData.cleanUp(testHelper)
testHelper.signOutAndClose(newBobSession)
}
private fun sendMessageInRoom(aliceRoomPOV: Room, text: String): String? {
aliceRoomPOV.sendTextMessage(text)
var sentEventId: String? = null
testHelper.waitWithLatch(4 * 60_000) {
val timeline = aliceRoomPOV.createTimeline(null, TimelineSettings(60))
timeline.start()
testHelper.retryPeriodicallyWithLatch(it) {
val decryptedMsg = timeline.getSnapshot()
.filter { it.root.getClearType() == EventType.MESSAGE }
.also {
Log.v("#E2E TEST", "Timeline snapshot is ${it.map { "${it.root.type}|${it.root.sendState}" }.joinToString(",", "[", "]")}")
}
.filter { it.root.sendState == SendState.SYNCED }
.firstOrNull { it.root.getClearContent().toModel<MessageContent>()?.body?.startsWith(text) == true }
sentEventId = decryptedMsg?.eventId
decryptedMsg != null
}
timeline.dispose()
}
return sentEventId
}
private fun ensureMembersHaveJoined(aliceSession: Session, otherAccounts: List<Session>, e2eRoomID: String) {
testHelper.waitWithLatch {
testHelper.retryPeriodicallyWithLatch(it) {
otherAccounts.map {
aliceSession.getRoomMember(it.myUserId, e2eRoomID)?.membership
}.all {
it == Membership.JOIN
}
}
}
}
private fun waitForAndAcceptInviteInRoom(otherSession: Session, e2eRoomID: String) {
testHelper.waitWithLatch {
testHelper.retryPeriodicallyWithLatch(it) {
val roomSummary = otherSession.getRoomSummary(e2eRoomID)
(roomSummary != null && roomSummary.membership == Membership.INVITE).also {
if (it) {
Log.v("#E2E TEST", "${otherSession.myUserId} can see the invite from alice")
}
}
}
}
testHelper.runBlockingTest(60_000) {
Log.v("#E2E TEST", "${otherSession.myUserId} tries to join room $e2eRoomID")
try {
otherSession.joinRoom(e2eRoomID)
} catch (ex: JoinRoomFailure.JoinedWithTimeout) {
// it's ok we will wait after
}
}
Log.v("#E2E TEST", "${otherSession.myUserId} waiting for join echo ...")
testHelper.waitWithLatch {
testHelper.retryPeriodicallyWithLatch(it) {
val roomSummary = otherSession.getRoomSummary(e2eRoomID)
roomSummary != null && roomSummary.membership == Membership.JOIN
}
}
}
private fun ensureCanDecrypt(sentEventIds: MutableList<String>, session: Session, e2eRoomID: String, messagesText: List<String>) {
sentEventIds.forEachIndexed { index, sentEventId ->
testHelper.waitWithLatch { latch ->
testHelper.retryPeriodicallyWithLatch(latch) {
val event = session.getRoom(e2eRoomID)!!.getTimeLineEvent(sentEventId)!!.root
testHelper.runBlockingTest {
try {
session.cryptoService().decryptEvent(event, "").let { result ->
event.mxDecryptionResult = OlmDecryptionResult(
payload = result.clearEvent,
senderKey = result.senderCurve25519Key,
keysClaimed = result.claimedEd25519Key?.let { mapOf("ed25519" to it) },
forwardingCurve25519KeyChain = result.forwardingCurve25519KeyChain
)
}
} catch (error: MXCryptoError) {
// nop
}
}
event.getClearType() == EventType.MESSAGE &&
messagesText[index] == event.getClearContent()?.toModel<MessageContent>()?.body
}
}
}
}
private fun ensureIsDecrypted(sentEventIds: List<String>, session: Session, e2eRoomID: String) {
testHelper.waitWithLatch { latch ->
sentEventIds.forEach { sentEventId ->
testHelper.retryPeriodicallyWithLatch(latch) {
val timeLineEvent = session.getRoom(e2eRoomID)?.getTimeLineEvent(sentEventId)
timeLineEvent != null &&
timeLineEvent.isEncrypted() &&
timeLineEvent.root.getClearType() == EventType.MESSAGE
}
}
}
}
private fun ensureCannotDecrypt(sentEventIds: List<String>, newBobSession: Session, e2eRoomID: String, expectedError: MXCryptoError.ErrorType?) {
sentEventIds.forEach { sentEventId ->
val event = newBobSession.getRoom(e2eRoomID)!!.getTimeLineEvent(sentEventId)!!.root
testHelper.runBlockingTest {
try {
newBobSession.cryptoService().decryptEvent(event, "")
fail("Should not be able to decrypt event")
} catch (error: MXCryptoError) {
val errorType = (error as? MXCryptoError.Base)?.errorType
if (expectedError == null) {
Assert.assertNotNull(errorType)
} else {
assertEquals(expectedError, errorType, "Message expected to be UISI")
}
}
}
}
}
}

View File

@ -21,7 +21,6 @@ import androidx.test.ext.junit.runners.AndroidJUnit4
import org.junit.Assert.assertEquals import org.junit.Assert.assertEquals
import org.junit.Assert.assertNotNull import org.junit.Assert.assertNotNull
import org.junit.FixMethodOrder import org.junit.FixMethodOrder
import org.junit.Ignore
import org.junit.Test import org.junit.Test
import org.junit.runner.RunWith import org.junit.runner.RunWith
import org.junit.runners.MethodSorters import org.junit.runners.MethodSorters
@ -41,7 +40,6 @@ class PreShareKeysTest : InstrumentedTest {
private val cryptoTestHelper = CryptoTestHelper(testHelper) private val cryptoTestHelper = CryptoTestHelper(testHelper)
@Test @Test
@Ignore("This test will be ignored until it is fixed")
fun ensure_outbound_session_happy_path() { fun ensure_outbound_session_happy_path() {
val testData = cryptoTestHelper.doE2ETestWithAliceAndBobInARoom(true) val testData = cryptoTestHelper.doE2ETestWithAliceAndBobInARoom(true)
val e2eRoomID = testData.roomId val e2eRoomID = testData.roomId
@ -92,7 +90,7 @@ class PreShareKeysTest : InstrumentedTest {
// Just send a real message as test // Just send a real message as test
val sentEvent = testHelper.sendTextMessage(aliceSession.getRoom(e2eRoomID)!!, "Allo", 1).first() val sentEvent = testHelper.sendTextMessage(aliceSession.getRoom(e2eRoomID)!!, "Allo", 1).first()
assertEquals(megolmSessionId, sentEvent.root.content.toModel<EncryptedEventContent>()?.sessionId, "Unexpected megolm session") assertEquals("Unexpected megolm session", megolmSessionId, sentEvent.root.content.toModel<EncryptedEventContent>()?.sessionId,)
testHelper.waitWithLatch { latch -> testHelper.waitWithLatch { latch ->
testHelper.retryPeriodicallyWithLatch(latch) { testHelper.retryPeriodicallyWithLatch(latch) {
bobSession.getRoom(e2eRoomID)?.getTimelineEvent(sentEvent.eventId)?.root?.getClearType() == EventType.MESSAGE bobSession.getRoom(e2eRoomID)?.getTimelineEvent(sentEvent.eventId)?.root?.getClearType() == EventType.MESSAGE

View File

@ -21,7 +21,6 @@ import org.amshove.kluent.shouldBe
import org.junit.Assert import org.junit.Assert
import org.junit.Before import org.junit.Before
import org.junit.FixMethodOrder import org.junit.FixMethodOrder
import org.junit.Ignore
import org.junit.Test import org.junit.Test
import org.junit.runner.RunWith import org.junit.runner.RunWith
import org.junit.runners.MethodSorters import org.junit.runners.MethodSorters
@ -85,7 +84,6 @@ class UnwedgingTest : InstrumentedTest {
* -> This is automatically fixed after SDKs restarted the olm session * -> This is automatically fixed after SDKs restarted the olm session
*/ */
@Test @Test
@Ignore("This test will be ignored until it is fixed")
fun testUnwedging() { fun testUnwedging() {
val cryptoTestData = cryptoTestHelper.doE2ETestWithAliceAndBobInARoom() val cryptoTestData = cryptoTestHelper.doE2ETestWithAliceAndBobInARoom()
@ -94,9 +92,7 @@ class UnwedgingTest : InstrumentedTest {
val bobSession = cryptoTestData.secondSession!! val bobSession = cryptoTestData.secondSession!!
val aliceCryptoStore = (aliceSession.cryptoService() as DefaultCryptoService).cryptoStoreForTesting val aliceCryptoStore = (aliceSession.cryptoService() as DefaultCryptoService).cryptoStoreForTesting
val olmDevice = (aliceSession.cryptoService() as DefaultCryptoService).olmDeviceForTest
// bobSession.cryptoService().setWarnOnUnknownDevices(false)
// aliceSession.cryptoService().setWarnOnUnknownDevices(false)
val roomFromBobPOV = bobSession.getRoom(aliceRoomId)!! val roomFromBobPOV = bobSession.getRoom(aliceRoomId)!!
val roomFromAlicePOV = aliceSession.getRoom(aliceRoomId)!! val roomFromAlicePOV = aliceSession.getRoom(aliceRoomId)!!
@ -175,6 +171,7 @@ class UnwedgingTest : InstrumentedTest {
Timber.i("## CRYPTO | testUnwedging: wedge the session now. Set crypto state like after the first message") Timber.i("## CRYPTO | testUnwedging: wedge the session now. Set crypto state like after the first message")
aliceCryptoStore.storeSession(OlmSessionWrapper(deserializeFromRealm<OlmSession>(oldSession)!!), bobSession.cryptoService().getMyDevice().identityKey()!!) aliceCryptoStore.storeSession(OlmSessionWrapper(deserializeFromRealm<OlmSession>(oldSession)!!), bobSession.cryptoService().getMyDevice().identityKey()!!)
olmDevice.clearOlmSessionCache()
Thread.sleep(6_000) Thread.sleep(6_000)
// Force new session, and key share // Force new session, and key share
@ -227,8 +224,10 @@ class UnwedgingTest : InstrumentedTest {
testHelper.waitWithLatch { testHelper.waitWithLatch {
testHelper.retryPeriodicallyWithLatch(it) { testHelper.retryPeriodicallyWithLatch(it) {
// we should get back the key and be able to decrypt // we should get back the key and be able to decrypt
val result = tryOrNull { val result = testHelper.runBlockingTest {
bobSession.cryptoService().decryptEvent(messagesReceivedByBob[0].root, "") tryOrNull {
bobSession.cryptoService().decryptEvent(messagesReceivedByBob[0].root, "")
}
} }
Timber.i("## CRYPTO | testUnwedging: decrypt result ${result?.clearEvent}") Timber.i("## CRYPTO | testUnwedging: decrypt result ${result?.clearEvent}")
result != null result != null

View File

@ -97,7 +97,9 @@ class KeyShareTests : InstrumentedTest {
assert(receivedEvent!!.isEncrypted()) assert(receivedEvent!!.isEncrypted())
try { try {
aliceSession2.cryptoService().decryptEvent(receivedEvent.root, "foo") commonTestHelper.runBlockingTest {
aliceSession2.cryptoService().decryptEvent(receivedEvent.root, "foo")
}
fail("should fail") fail("should fail")
} catch (failure: Throwable) { } catch (failure: Throwable) {
} }
@ -152,7 +154,9 @@ class KeyShareTests : InstrumentedTest {
} }
try { try {
aliceSession2.cryptoService().decryptEvent(receivedEvent.root, "foo") commonTestHelper.runBlockingTest {
aliceSession2.cryptoService().decryptEvent(receivedEvent.root, "foo")
}
fail("should fail") fail("should fail")
} catch (failure: Throwable) { } catch (failure: Throwable) {
} }
@ -189,7 +193,9 @@ class KeyShareTests : InstrumentedTest {
} }
try { try {
aliceSession2.cryptoService().decryptEvent(receivedEvent.root, "foo") commonTestHelper.runBlockingTest {
aliceSession2.cryptoService().decryptEvent(receivedEvent.root, "foo")
}
} catch (failure: Throwable) { } catch (failure: Throwable) {
fail("should have been able to decrypt") fail("should have been able to decrypt")
} }
@ -384,7 +390,11 @@ class KeyShareTests : InstrumentedTest {
val roomRoomBobPov = aliceSession.getRoom(roomId) val roomRoomBobPov = aliceSession.getRoom(roomId)
val beforeJoin = roomRoomBobPov!!.getTimelineEvent(secondEventId) val beforeJoin = roomRoomBobPov!!.getTimelineEvent(secondEventId)
var dRes = tryOrNull { bobSession.cryptoService().decryptEvent(beforeJoin!!.root, "") } var dRes = tryOrNull {
commonTestHelper.runBlockingTest {
bobSession.cryptoService().decryptEvent(beforeJoin!!.root, "")
}
}
assert(dRes == null) assert(dRes == null)
@ -395,7 +405,11 @@ class KeyShareTests : InstrumentedTest {
Thread.sleep(3_000) Thread.sleep(3_000)
// With the bug the first session would have improperly reshare that key :/ // With the bug the first session would have improperly reshare that key :/
dRes = tryOrNull { bobSession.cryptoService().decryptEvent(beforeJoin.root, "") } dRes = tryOrNull {
commonTestHelper.runBlockingTest {
bobSession.cryptoService().decryptEvent(beforeJoin.root, "")
}
}
Log.d("#TEST", "KS: sgould not decrypt that ${beforeJoin.root.getClearContent().toModel<MessageContent>()?.body}") Log.d("#TEST", "KS: sgould not decrypt that ${beforeJoin.root.getClearContent().toModel<MessageContent>()?.body}")
assert(dRes?.clearEvent == null) assert(dRes?.clearEvent == null)
} }

View File

@ -93,7 +93,9 @@ class WithHeldTests : InstrumentedTest {
// Bob should not be able to decrypt because the keys is withheld // Bob should not be able to decrypt because the keys is withheld
try { try {
// .. might need to wait a bit for stability? // .. might need to wait a bit for stability?
bobUnverifiedSession.cryptoService().decryptEvent(eventBobPOV.root, "") testHelper.runBlockingTest {
bobUnverifiedSession.cryptoService().decryptEvent(eventBobPOV.root, "")
}
Assert.fail("This session should not be able to decrypt") Assert.fail("This session should not be able to decrypt")
} catch (failure: Throwable) { } catch (failure: Throwable) {
val type = (failure as MXCryptoError.Base).errorType val type = (failure as MXCryptoError.Base).errorType
@ -118,7 +120,9 @@ class WithHeldTests : InstrumentedTest {
// Previous message should still be undecryptable (partially withheld session) // Previous message should still be undecryptable (partially withheld session)
try { try {
// .. might need to wait a bit for stability? // .. might need to wait a bit for stability?
bobUnverifiedSession.cryptoService().decryptEvent(eventBobPOV.root, "") testHelper.runBlockingTest {
bobUnverifiedSession.cryptoService().decryptEvent(eventBobPOV.root, "")
}
Assert.fail("This session should not be able to decrypt") Assert.fail("This session should not be able to decrypt")
} catch (failure: Throwable) { } catch (failure: Throwable) {
val type = (failure as MXCryptoError.Base).errorType val type = (failure as MXCryptoError.Base).errorType
@ -165,7 +169,9 @@ class WithHeldTests : InstrumentedTest {
val eventBobPOV = bobSession.getRoom(testData.roomId)?.getTimelineEvent(eventId) val eventBobPOV = bobSession.getRoom(testData.roomId)?.getTimelineEvent(eventId)
try { try {
// .. might need to wait a bit for stability? // .. might need to wait a bit for stability?
bobSession.cryptoService().decryptEvent(eventBobPOV!!.root, "") testHelper.runBlockingTest {
bobSession.cryptoService().decryptEvent(eventBobPOV!!.root, "")
}
Assert.fail("This session should not be able to decrypt") Assert.fail("This session should not be able to decrypt")
} catch (failure: Throwable) { } catch (failure: Throwable) {
val type = (failure as MXCryptoError.Base).errorType val type = (failure as MXCryptoError.Base).errorType
@ -233,7 +239,11 @@ class WithHeldTests : InstrumentedTest {
testHelper.retryPeriodicallyWithLatch(latch) { testHelper.retryPeriodicallyWithLatch(latch) {
val timeLineEvent = bobSecondSession.getRoom(testData.roomId)?.getTimelineEvent(eventId)?.also { val timeLineEvent = bobSecondSession.getRoom(testData.roomId)?.getTimelineEvent(eventId)?.also {
// try to decrypt and force key request // try to decrypt and force key request
tryOrNull { bobSecondSession.cryptoService().decryptEvent(it.root, "") } tryOrNull {
testHelper.runBlockingTest {
bobSecondSession.cryptoService().decryptEvent(it.root, "")
}
}
} }
sessionId = timeLineEvent?.root?.content?.toModel<EncryptedEventContent>()?.sessionId sessionId = timeLineEvent?.root?.content?.toModel<EncryptedEventContent>()?.sessionId
timeLineEvent != null timeLineEvent != null

View File

@ -121,7 +121,7 @@ interface CryptoService {
fun discardOutboundSession(roomId: String) fun discardOutboundSession(roomId: String)
@Throws(MXCryptoError::class) @Throws(MXCryptoError::class)
fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult suspend fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult
fun decryptEventAsync(event: Event, timeline: String, callback: MatrixCallback<MXEventDecryptionResult>) fun decryptEventAsync(event: Event, timeline: String, callback: MatrixCallback<MXEventDecryptionResult>)

View File

@ -434,6 +434,14 @@ internal class DefaultCryptoService @Inject constructor(
val currentCount = syncResponse.deviceOneTimeKeysCount.signedCurve25519 ?: 0 val currentCount = syncResponse.deviceOneTimeKeysCount.signedCurve25519 ?: 0
oneTimeKeysUploader.updateOneTimeKeyCount(currentCount) oneTimeKeysUploader.updateOneTimeKeyCount(currentCount)
} }
// unwedge if needed
try {
eventDecryptor.unwedgeDevicesIfNeeded()
} catch (failure: Throwable) {
Timber.tag(loggerTag.value).w("unwedgeDevicesIfNeeded failed")
}
// There is a limit of to_device events returned per sync. // There is a limit of to_device events returned per sync.
// If we are in a case of such limited to_device sync we can't try to generate/upload // If we are in a case of such limited to_device sync we can't try to generate/upload
// new otk now, because there might be some pending olm pre-key to_device messages that would fail if we rotate // new otk now, because there might be some pending olm pre-key to_device messages that would fail if we rotate
@ -723,7 +731,7 @@ internal class DefaultCryptoService @Inject constructor(
* @return the MXEventDecryptionResult data, or throw in case of error * @return the MXEventDecryptionResult data, or throw in case of error
*/ */
@Throws(MXCryptoError::class) @Throws(MXCryptoError::class)
override fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult { override suspend fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult {
return internalDecryptEvent(event, timeline) return internalDecryptEvent(event, timeline)
} }
@ -746,7 +754,7 @@ internal class DefaultCryptoService @Inject constructor(
* @return the MXEventDecryptionResult data, or null in case of error * @return the MXEventDecryptionResult data, or null in case of error
*/ */
@Throws(MXCryptoError::class) @Throws(MXCryptoError::class)
private fun internalDecryptEvent(event: Event, timeline: String): MXEventDecryptionResult { private suspend fun internalDecryptEvent(event: Event, timeline: String): MXEventDecryptionResult {
return eventDecryptor.decryptEvent(event, timeline) return eventDecryptor.decryptEvent(event, timeline)
} }
@ -1364,6 +1372,9 @@ internal class DefaultCryptoService @Inject constructor(
@VisibleForTesting @VisibleForTesting
val cryptoStoreForTesting = cryptoStore val cryptoStoreForTesting = cryptoStore
@VisibleForTesting
val olmDeviceForTest = olmDevice
companion object { companion object {
const val CRYPTO_MIN_FORCE_SESSION_PERIOD_MILLIS = 3_600_000 // one hour const val CRYPTO_MIN_FORCE_SESSION_PERIOD_MILLIS = 3_600_000 // one hour
} }

View File

@ -21,14 +21,13 @@ import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
import org.matrix.android.sdk.api.MatrixCallback import org.matrix.android.sdk.api.MatrixCallback
import org.matrix.android.sdk.api.MatrixCoroutineDispatchers import org.matrix.android.sdk.api.MatrixCoroutineDispatchers
import org.matrix.android.sdk.api.logger.LoggerTag
import org.matrix.android.sdk.api.session.crypto.MXCryptoError import org.matrix.android.sdk.api.session.crypto.MXCryptoError
import org.matrix.android.sdk.api.session.events.model.Event import org.matrix.android.sdk.api.session.events.model.Event
import org.matrix.android.sdk.api.session.events.model.EventType import org.matrix.android.sdk.api.session.events.model.EventType
import org.matrix.android.sdk.api.session.events.model.toModel import org.matrix.android.sdk.api.session.events.model.toModel
import org.matrix.android.sdk.internal.crypto.actions.EnsureOlmSessionsForDevicesAction import org.matrix.android.sdk.internal.crypto.actions.EnsureOlmSessionsForDevicesAction
import org.matrix.android.sdk.internal.crypto.actions.MessageEncrypter import org.matrix.android.sdk.internal.crypto.actions.MessageEncrypter
import org.matrix.android.sdk.internal.crypto.model.CryptoDeviceInfo
import org.matrix.android.sdk.internal.crypto.model.MXOlmSessionResult
import org.matrix.android.sdk.internal.crypto.model.MXUsersDevicesMap import org.matrix.android.sdk.internal.crypto.model.MXUsersDevicesMap
import org.matrix.android.sdk.internal.crypto.model.event.OlmEventContent import org.matrix.android.sdk.internal.crypto.model.event.OlmEventContent
import org.matrix.android.sdk.internal.crypto.store.IMXCryptoStore import org.matrix.android.sdk.internal.crypto.store.IMXCryptoStore
@ -40,6 +39,8 @@ import javax.inject.Inject
private const val SEND_TO_DEVICE_RETRY_COUNT = 3 private const val SEND_TO_DEVICE_RETRY_COUNT = 3
private val loggerTag = LoggerTag("CryptoSyncHandler", LoggerTag.CRYPTO)
@SessionScope @SessionScope
internal class EventDecryptor @Inject constructor( internal class EventDecryptor @Inject constructor(
private val cryptoCoroutineScope: CoroutineScope, private val cryptoCoroutineScope: CoroutineScope,
@ -47,13 +48,22 @@ internal class EventDecryptor @Inject constructor(
private val roomDecryptorProvider: RoomDecryptorProvider, private val roomDecryptorProvider: RoomDecryptorProvider,
private val messageEncrypter: MessageEncrypter, private val messageEncrypter: MessageEncrypter,
private val sendToDeviceTask: SendToDeviceTask, private val sendToDeviceTask: SendToDeviceTask,
private val deviceListManager: DeviceListManager,
private val ensureOlmSessionsForDevicesAction: EnsureOlmSessionsForDevicesAction, private val ensureOlmSessionsForDevicesAction: EnsureOlmSessionsForDevicesAction,
private val cryptoStore: IMXCryptoStore private val cryptoStore: IMXCryptoStore
) { ) {
// The date of the last time we forced establishment /**
// of a new session for each user:device. * Rate limit unwedge attempt, should we persist that?
private val lastNewSessionForcedDates = MXUsersDevicesMap<Long>() */
private val lastNewSessionForcedDates = mutableMapOf<WedgedDeviceInfo, Long>()
data class WedgedDeviceInfo(
val userId: String,
val senderKey: String?
)
private val wedgedDevices = mutableListOf<WedgedDeviceInfo>()
/** /**
* Decrypt an event * Decrypt an event
@ -63,7 +73,7 @@ internal class EventDecryptor @Inject constructor(
* @return the MXEventDecryptionResult data, or throw in case of error * @return the MXEventDecryptionResult data, or throw in case of error
*/ */
@Throws(MXCryptoError::class) @Throws(MXCryptoError::class)
fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult { suspend fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult {
return internalDecryptEvent(event, timeline) return internalDecryptEvent(event, timeline)
} }
@ -91,38 +101,32 @@ internal class EventDecryptor @Inject constructor(
* @return the MXEventDecryptionResult data, or null in case of error * @return the MXEventDecryptionResult data, or null in case of error
*/ */
@Throws(MXCryptoError::class) @Throws(MXCryptoError::class)
private fun internalDecryptEvent(event: Event, timeline: String): MXEventDecryptionResult { private suspend fun internalDecryptEvent(event: Event, timeline: String): MXEventDecryptionResult {
val eventContent = event.content val eventContent = event.content
if (eventContent == null) { if (eventContent == null) {
Timber.e("## CRYPTO | decryptEvent : empty event content") Timber.tag(loggerTag.value).e("decryptEvent : empty event content")
throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_ENCRYPTED_MESSAGE, MXCryptoError.BAD_ENCRYPTED_MESSAGE_REASON) throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_ENCRYPTED_MESSAGE, MXCryptoError.BAD_ENCRYPTED_MESSAGE_REASON)
} else { } else {
val algorithm = eventContent["algorithm"]?.toString() val algorithm = eventContent["algorithm"]?.toString()
val alg = roomDecryptorProvider.getOrCreateRoomDecryptor(event.roomId, algorithm) val alg = roomDecryptorProvider.getOrCreateRoomDecryptor(event.roomId, algorithm)
if (alg == null) { if (alg == null) {
val reason = String.format(MXCryptoError.UNABLE_TO_DECRYPT_REASON, event.eventId, algorithm) val reason = String.format(MXCryptoError.UNABLE_TO_DECRYPT_REASON, event.eventId, algorithm)
Timber.e("## CRYPTO | decryptEvent() : $reason") Timber.tag(loggerTag.value).e("decryptEvent() : $reason")
throw MXCryptoError.Base(MXCryptoError.ErrorType.UNABLE_TO_DECRYPT, reason) throw MXCryptoError.Base(MXCryptoError.ErrorType.UNABLE_TO_DECRYPT, reason)
} else { } else {
try { try {
return alg.decryptEvent(event, timeline) return alg.decryptEvent(event, timeline)
} catch (mxCryptoError: MXCryptoError) { } catch (mxCryptoError: MXCryptoError) {
Timber.v("## CRYPTO | internalDecryptEvent : Failed to decrypt ${event.eventId} reason: $mxCryptoError") Timber.tag(loggerTag.value).d("internalDecryptEvent : Failed to decrypt ${event.eventId} reason: $mxCryptoError")
if (algorithm == MXCRYPTO_ALGORITHM_OLM) { if (algorithm == MXCRYPTO_ALGORITHM_OLM) {
if (mxCryptoError is MXCryptoError.Base && if (mxCryptoError is MXCryptoError.Base &&
mxCryptoError.errorType == MXCryptoError.ErrorType.BAD_ENCRYPTED_MESSAGE) { mxCryptoError.errorType == MXCryptoError.ErrorType.BAD_ENCRYPTED_MESSAGE) {
// need to find sending device // need to find sending device
cryptoCoroutineScope.launch(coroutineDispatchers.crypto) { val olmContent = event.content.toModel<OlmEventContent>()
val olmContent = event.content.toModel<OlmEventContent>() if (event.senderId != null && olmContent?.senderKey != null) {
cryptoStore.getUserDevices(event.senderId ?: "") markOlmSessionForUnwedging(event.senderId, olmContent.senderKey)
?.values } else {
?.firstOrNull { it.identityKey() == olmContent?.senderKey } Timber.tag(loggerTag.value).d("Can't mark as wedge malformed")
?.let {
markOlmSessionForUnwedging(event.senderId ?: "", it)
}
?: run {
Timber.i("## CRYPTO | internalDecryptEvent() : Failed to find sender crypto device for unwedging")
}
} }
} }
} }
@ -132,53 +136,91 @@ internal class EventDecryptor @Inject constructor(
} }
} }
// coroutineDispatchers.crypto scope private fun markOlmSessionForUnwedging(senderId: String, senderKey: String) {
private fun markOlmSessionForUnwedging(senderId: String, deviceInfo: CryptoDeviceInfo) { val info = WedgedDeviceInfo(senderId, senderKey)
val deviceKey = deviceInfo.identityKey() if (!wedgedDevices.contains(info)) {
Timber.tag(loggerTag.value).d("Marking device from $senderId key:$senderKey as wedged")
wedgedDevices.add(info)
}
}
val lastForcedDate = lastNewSessionForcedDates.getObject(senderId, deviceKey) ?: 0 // coroutineDispatchers.crypto scope
suspend fun unwedgeDevicesIfNeeded() {
// handle wedged devices
// Some olm decryption have failed and some device are wedged
// we should force start a new session for those
Timber.tag(loggerTag.value).v("Unwedging: ${wedgedDevices.size} are wedged")
// get the one that should be retried according to rate limit
val now = System.currentTimeMillis() val now = System.currentTimeMillis()
if (now - lastForcedDate < DefaultCryptoService.CRYPTO_MIN_FORCE_SESSION_PERIOD_MILLIS) { val toUnwedge = wedgedDevices.filter {
Timber.w("## CRYPTO | markOlmSessionForUnwedging: New session already forced with device at $lastForcedDate. Not forcing another") val lastForcedDate = lastNewSessionForcedDates[it] ?: 0
if (now - lastForcedDate < DefaultCryptoService.CRYPTO_MIN_FORCE_SESSION_PERIOD_MILLIS) {
Timber.tag(loggerTag.value).d("Unwedging, New session for $it already forced with device at $lastForcedDate")
return@filter false
}
// let's already mark that we tried now
lastNewSessionForcedDates[it] = now
true
}
if (toUnwedge.isEmpty()) {
Timber.tag(loggerTag.value).v("Nothing to unwedge")
return return
} }
Timber.tag(loggerTag.value).d("Unwedging, trying to create new session for ${toUnwedge.size} devices")
Timber.i("## CRYPTO | markOlmSessionForUnwedging from $senderId:${deviceInfo.deviceId}") toUnwedge
lastNewSessionForcedDates.setObject(senderId, deviceKey, now) .chunked(100) // safer to chunk if we ever have lots of wedged devices
.forEach { wedgedList ->
// offload this from crypto thread (?) val groupedByUserId = wedgedList.groupBy { it.userId }
cryptoCoroutineScope.launch(coroutineDispatchers.computation) { // lets download keys if needed
runCatching { ensureOlmSessionsForDevicesAction.handle(mapOf(senderId to listOf(deviceInfo)), force = true) }.fold( withContext(coroutineDispatchers.io) {
onSuccess = { sendDummyToDevice(ensured = it, deviceInfo, senderId) }, deviceListManager.downloadKeys(groupedByUserId.keys.toList(), false)
onFailure = {
Timber.e("## CRYPTO | markOlmSessionForUnwedging() : failed to ensure device info ${senderId}${deviceInfo.deviceId}")
} }
)
}
}
private suspend fun sendDummyToDevice(ensured: MXUsersDevicesMap<MXOlmSessionResult>, deviceInfo: CryptoDeviceInfo, senderId: String) { // find the matching devices
Timber.i("## CRYPTO | markOlmSessionForUnwedging() : ensureOlmSessionsForDevicesAction isEmpty:${ensured.isEmpty}") groupedByUserId
.map { groupedByUser ->
val userId = groupedByUser.key
val wedgeSenderKeysForUser = groupedByUser.value.map { it.senderKey }
val knownDevices = cryptoStore.getUserDevices(userId)?.values.orEmpty()
userId to wedgeSenderKeysForUser.mapNotNull { senderKey ->
knownDevices.firstOrNull { it.identityKey() == senderKey }
}
}
.toMap()
.let { deviceList ->
try {
// force creating new outbound session and mark them as most recent to
// be used for next encryption (dummy)
val sessionToUse = ensureOlmSessionsForDevicesAction.handle(deviceList, true)
Timber.tag(loggerTag.value).d("Unwedging, found ${sessionToUse.map.size} to send dummy to")
// Now send a blank message on that session so the other side knows about it. // Now send a dummy message on that session so the other side knows about it.
// (The keyshare request is sent in the clear so that won't do) val payloadJson = mapOf(
// We send this first such that, as long as the toDevice messages arrive in the "type" to EventType.DUMMY
// same order we sent them, the other end will get this first, set up the new session, )
// then get the keyshare request and send the key over this new session (because it val sendToDeviceMap = MXUsersDevicesMap<Any>()
// is the session it has most recently received a message on). sessionToUse.map.values
val payloadJson = mapOf<String, Any>("type" to EventType.DUMMY) .flatMap { it.values }
.map { it.deviceInfo }
.forEach { deviceInfo ->
Timber.tag(loggerTag.value).v("encrypting dummy to ${deviceInfo.deviceId}")
val encodedPayload = messageEncrypter.encryptMessage(payloadJson, listOf(deviceInfo))
sendToDeviceMap.setObject(deviceInfo.userId, deviceInfo.deviceId, encodedPayload)
}
val encodedPayload = messageEncrypter.encryptMessage(payloadJson, listOf(deviceInfo)) // now let's send that
val sendToDeviceMap = MXUsersDevicesMap<Any>() val sendToDeviceParams = SendToDeviceTask.Params(EventType.ENCRYPTED, sendToDeviceMap)
sendToDeviceMap.setObject(senderId, deviceInfo.deviceId, encodedPayload) withContext(coroutineDispatchers.io) {
Timber.i("## CRYPTO | markOlmSessionForUnwedging() : sending dummy to $senderId:${deviceInfo.deviceId}") sendToDeviceTask.executeRetry(sendToDeviceParams, remainingRetry = SEND_TO_DEVICE_RETRY_COUNT)
withContext(coroutineDispatchers.io) { }
val sendToDeviceParams = SendToDeviceTask.Params(EventType.ENCRYPTED, sendToDeviceMap) } catch (failure: Throwable) {
try { deviceList.flatMap { it.value }.joinToString { it.shortDebugString() }.let {
sendToDeviceTask.executeRetry(sendToDeviceParams, remainingRetry = SEND_TO_DEVICE_RETRY_COUNT) Timber.tag(loggerTag.value).e(failure, "## Failed to unwedge devices: $it}")
} catch (failure: Throwable) { }
Timber.e(failure, "## CRYPTO | markOlmSessionForUnwedging() : failed to send dummy to $senderId:${deviceInfo.deviceId}") }
} }
} }
} }
} }

View File

@ -19,8 +19,10 @@ package org.matrix.android.sdk.internal.crypto
import android.util.LruCache import android.util.LruCache
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.sync.Mutex
import org.matrix.android.sdk.api.MatrixCoroutineDispatchers import org.matrix.android.sdk.api.MatrixCoroutineDispatchers
import org.matrix.android.sdk.api.extensions.tryOrNull import org.matrix.android.sdk.api.extensions.tryOrNull
import org.matrix.android.sdk.api.logger.LoggerTag
import org.matrix.android.sdk.internal.crypto.model.OlmInboundGroupSessionWrapper2 import org.matrix.android.sdk.internal.crypto.model.OlmInboundGroupSessionWrapper2
import org.matrix.android.sdk.internal.crypto.store.IMXCryptoStore import org.matrix.android.sdk.internal.crypto.store.IMXCryptoStore
import timber.log.Timber import timber.log.Timber
@ -28,6 +30,13 @@ import java.util.Timer
import java.util.TimerTask import java.util.TimerTask
import javax.inject.Inject import javax.inject.Inject
data class InboundGroupSessionHolder(
val wrapper: OlmInboundGroupSessionWrapper2,
val mutex: Mutex = Mutex()
)
private val loggerTag = LoggerTag("InboundGroupSessionStore", LoggerTag.CRYPTO)
/** /**
* Allows to cache and batch store operations on inbound group session store. * Allows to cache and batch store operations on inbound group session store.
* Because it is used in the decrypt flow, that can be called quite rapidly * Because it is used in the decrypt flow, that can be called quite rapidly
@ -42,12 +51,13 @@ internal class InboundGroupSessionStore @Inject constructor(
val senderKey: String val senderKey: String
) )
private val sessionCache = object : LruCache<CacheKey, OlmInboundGroupSessionWrapper2>(30) { private val sessionCache = object : LruCache<CacheKey, InboundGroupSessionHolder>(100) {
override fun entryRemoved(evicted: Boolean, key: CacheKey?, oldValue: OlmInboundGroupSessionWrapper2?, newValue: OlmInboundGroupSessionWrapper2?) { override fun entryRemoved(evicted: Boolean, key: CacheKey?, oldValue: InboundGroupSessionHolder?, newValue: InboundGroupSessionHolder?) {
if (evicted && oldValue != null) { if (oldValue != null) {
cryptoCoroutineScope.launch(coroutineDispatchers.crypto) { cryptoCoroutineScope.launch(coroutineDispatchers.crypto) {
Timber.v("## Inbound: entryRemoved ${oldValue.roomId}-${oldValue.senderKey}") Timber.tag(loggerTag.value).v("## Inbound: entryRemoved ${oldValue.wrapper.roomId}-${oldValue.wrapper.senderKey}")
store.storeInboundGroupSessions(listOf(oldValue)) store.storeInboundGroupSessions(listOf(oldValue).map { it.wrapper })
oldValue.wrapper.olmInboundGroupSession?.releaseSession()
} }
} }
} }
@ -59,27 +69,50 @@ internal class InboundGroupSessionStore @Inject constructor(
private val dirtySession = mutableListOf<OlmInboundGroupSessionWrapper2>() private val dirtySession = mutableListOf<OlmInboundGroupSessionWrapper2>()
@Synchronized @Synchronized
fun getInboundGroupSession(sessionId: String, senderKey: String): OlmInboundGroupSessionWrapper2? { fun clear() {
synchronized(sessionCache) { sessionCache.evictAll()
val known = sessionCache[CacheKey(sessionId, senderKey)]
Timber.v("## Inbound: getInboundGroupSession in cache ${known != null}")
return known ?: store.getInboundGroupSession(sessionId, senderKey)?.also {
Timber.v("## Inbound: getInboundGroupSession cache populate ${it.roomId}")
sessionCache.put(CacheKey(sessionId, senderKey), it)
}
}
} }
@Synchronized @Synchronized
fun storeInBoundGroupSession(wrapper: OlmInboundGroupSessionWrapper2, sessionId: String, senderKey: String) { fun getInboundGroupSession(sessionId: String, senderKey: String): InboundGroupSessionHolder? {
Timber.v("## Inbound: getInboundGroupSession mark as dirty ${wrapper.roomId}-${wrapper.senderKey}") val known = sessionCache[CacheKey(sessionId, senderKey)]
Timber.tag(loggerTag.value).v("## Inbound: getInboundGroupSession $sessionId in cache ${known != null}")
return known
?: store.getInboundGroupSession(sessionId, senderKey)?.also {
Timber.tag(loggerTag.value).v("## Inbound: getInboundGroupSession cache populate ${it.roomId}")
sessionCache.put(CacheKey(sessionId, senderKey), InboundGroupSessionHolder(it))
}?.let {
InboundGroupSessionHolder(it)
}
}
@Synchronized
fun replaceGroupSession(old: InboundGroupSessionHolder, new: InboundGroupSessionHolder, sessionId: String, senderKey: String) {
Timber.tag(loggerTag.value).v("## Replacing outdated session ${old.wrapper.roomId}-${old.wrapper.senderKey}")
dirtySession.remove(old.wrapper)
store.removeInboundGroupSession(sessionId, senderKey)
sessionCache.remove(CacheKey(sessionId, senderKey))
// release removed session
old.wrapper.olmInboundGroupSession?.releaseSession()
internalStoreGroupSession(new, sessionId, senderKey)
}
@Synchronized
fun storeInBoundGroupSession(holder: InboundGroupSessionHolder, sessionId: String, senderKey: String) {
internalStoreGroupSession(holder, sessionId, senderKey)
}
private fun internalStoreGroupSession(holder: InboundGroupSessionHolder, sessionId: String, senderKey: String) {
Timber.tag(loggerTag.value).v("## Inbound: getInboundGroupSession mark as dirty ${holder.wrapper.roomId}-${holder.wrapper.senderKey}")
// We want to batch this a bit for performances // We want to batch this a bit for performances
dirtySession.add(wrapper) dirtySession.add(holder.wrapper)
if (sessionCache[CacheKey(sessionId, senderKey)] == null) { if (sessionCache[CacheKey(sessionId, senderKey)] == null) {
// first time seen, put it in memory cache while waiting for batch insert // first time seen, put it in memory cache while waiting for batch insert
// If it's already known, no need to update cache it's already there // If it's already known, no need to update cache it's already there
sessionCache.put(CacheKey(sessionId, senderKey), wrapper) sessionCache.put(CacheKey(sessionId, senderKey), holder)
} }
timerTask?.cancel() timerTask?.cancel()
@ -96,7 +129,7 @@ internal class InboundGroupSessionStore @Inject constructor(
val toSave = mutableListOf<OlmInboundGroupSessionWrapper2>().apply { addAll(dirtySession) } val toSave = mutableListOf<OlmInboundGroupSessionWrapper2>().apply { addAll(dirtySession) }
dirtySession.clear() dirtySession.clear()
cryptoCoroutineScope.launch(coroutineDispatchers.crypto) { cryptoCoroutineScope.launch(coroutineDispatchers.crypto) {
Timber.v("## Inbound: getInboundGroupSession batching save of ${dirtySession.size}") Timber.tag(loggerTag.value).v("## Inbound: getInboundGroupSession batching save of ${toSave.size}")
tryOrNull { tryOrNull {
store.storeInboundGroupSessions(toSave) store.storeInboundGroupSessions(toSave)
} }

View File

@ -16,6 +16,11 @@
package org.matrix.android.sdk.internal.crypto package org.matrix.android.sdk.internal.crypto
import androidx.annotation.VisibleForTesting
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import org.matrix.android.sdk.api.extensions.tryOrNull
import org.matrix.android.sdk.api.logger.LoggerTag
import org.matrix.android.sdk.api.session.crypto.MXCryptoError import org.matrix.android.sdk.api.session.crypto.MXCryptoError
import org.matrix.android.sdk.api.util.JSON_DICT_PARAMETERIZED_TYPE import org.matrix.android.sdk.api.util.JSON_DICT_PARAMETERIZED_TYPE
import org.matrix.android.sdk.api.util.JsonDict import org.matrix.android.sdk.api.util.JsonDict
@ -40,6 +45,8 @@ import timber.log.Timber
import java.net.URLEncoder import java.net.URLEncoder
import javax.inject.Inject import javax.inject.Inject
private val loggerTag = LoggerTag("MXOlmDevice", LoggerTag.CRYPTO)
// The libolm wrapper. // The libolm wrapper.
@SessionScope @SessionScope
internal class MXOlmDevice @Inject constructor( internal class MXOlmDevice @Inject constructor(
@ -47,9 +54,12 @@ internal class MXOlmDevice @Inject constructor(
* The store where crypto data is saved. * The store where crypto data is saved.
*/ */
private val store: IMXCryptoStore, private val store: IMXCryptoStore,
private val olmSessionStore: OlmSessionStore,
private val inboundGroupSessionStore: InboundGroupSessionStore private val inboundGroupSessionStore: InboundGroupSessionStore
) { ) {
val mutex = Mutex()
/** /**
* @return the Curve25519 key for the account. * @return the Curve25519 key for the account.
*/ */
@ -93,26 +103,26 @@ internal class MXOlmDevice @Inject constructor(
try { try {
store.getOrCreateOlmAccount() store.getOrCreateOlmAccount()
} catch (e: Exception) { } catch (e: Exception) {
Timber.e(e, "MXOlmDevice : cannot initialize olmAccount") Timber.tag(loggerTag.value).e(e, "MXOlmDevice : cannot initialize olmAccount")
} }
try { try {
olmUtility = OlmUtility() olmUtility = OlmUtility()
} catch (e: Exception) { } catch (e: Exception) {
Timber.e(e, "## MXOlmDevice : OlmUtility failed with error") Timber.tag(loggerTag.value).e(e, "## MXOlmDevice : OlmUtility failed with error")
olmUtility = null olmUtility = null
} }
try { try {
deviceCurve25519Key = store.getOlmAccount().identityKeys()[OlmAccount.JSON_KEY_IDENTITY_KEY] deviceCurve25519Key = store.doWithOlmAccount { it.identityKeys()[OlmAccount.JSON_KEY_IDENTITY_KEY] }
} catch (e: Exception) { } catch (e: Exception) {
Timber.e(e, "## MXOlmDevice : cannot find ${OlmAccount.JSON_KEY_IDENTITY_KEY} with error") Timber.tag(loggerTag.value).e(e, "## MXOlmDevice : cannot find ${OlmAccount.JSON_KEY_IDENTITY_KEY} with error")
} }
try { try {
deviceEd25519Key = store.getOlmAccount().identityKeys()[OlmAccount.JSON_KEY_FINGER_PRINT_KEY] deviceEd25519Key = store.doWithOlmAccount { it.identityKeys()[OlmAccount.JSON_KEY_FINGER_PRINT_KEY] }
} catch (e: Exception) { } catch (e: Exception) {
Timber.e(e, "## MXOlmDevice : cannot find ${OlmAccount.JSON_KEY_FINGER_PRINT_KEY} with error") Timber.tag(loggerTag.value).e(e, "## MXOlmDevice : cannot find ${OlmAccount.JSON_KEY_FINGER_PRINT_KEY} with error")
} }
} }
@ -121,9 +131,9 @@ internal class MXOlmDevice @Inject constructor(
*/ */
fun getOneTimeKeys(): Map<String, Map<String, String>>? { fun getOneTimeKeys(): Map<String, Map<String, String>>? {
try { try {
return store.getOlmAccount().oneTimeKeys() return store.doWithOlmAccount { it.oneTimeKeys() }
} catch (e: Exception) { } catch (e: Exception) {
Timber.e(e, "## getOneTimeKeys() : failed") Timber.tag(loggerTag.value).e(e, "## getOneTimeKeys() : failed")
} }
return null return null
@ -133,7 +143,7 @@ internal class MXOlmDevice @Inject constructor(
* @return The maximum number of one-time keys the olm account can store. * @return The maximum number of one-time keys the olm account can store.
*/ */
fun getMaxNumberOfOneTimeKeys(): Long { fun getMaxNumberOfOneTimeKeys(): Long {
return store.getOlmAccount().maxOneTimeKeys() return store.doWithOlmAccount { it.maxOneTimeKeys() }
} }
/** /**
@ -143,9 +153,9 @@ internal class MXOlmDevice @Inject constructor(
*/ */
fun getFallbackKey(): MutableMap<String, MutableMap<String, String>>? { fun getFallbackKey(): MutableMap<String, MutableMap<String, String>>? {
try { try {
return store.getOlmAccount().fallbackKey() return store.doWithOlmAccount { it.fallbackKey() }
} catch (e: Exception) { } catch (e: Exception) {
Timber.e("## getFallbackKey() : failed") Timber.tag(loggerTag.value).e("## getFallbackKey() : failed")
} }
return null return null
} }
@ -158,12 +168,14 @@ internal class MXOlmDevice @Inject constructor(
fun generateFallbackKeyIfNeeded(): Boolean { fun generateFallbackKeyIfNeeded(): Boolean {
try { try {
if (!hasUnpublishedFallbackKey()) { if (!hasUnpublishedFallbackKey()) {
store.getOlmAccount().generateFallbackKey() store.doWithOlmAccount {
store.saveOlmAccount() it.generateFallbackKey()
store.saveOlmAccount()
}
return true return true
} }
} catch (e: Exception) { } catch (e: Exception) {
Timber.e("## generateFallbackKey() : failed") Timber.tag(loggerTag.value).e("## generateFallbackKey() : failed")
} }
return false return false
} }
@ -174,10 +186,12 @@ internal class MXOlmDevice @Inject constructor(
fun forgetFallbackKey() { fun forgetFallbackKey() {
try { try {
store.getOlmAccount().forgetFallbackKey() store.doWithOlmAccount {
store.saveOlmAccount() it.forgetFallbackKey()
store.saveOlmAccount()
}
} catch (e: Exception) { } catch (e: Exception) {
Timber.e("## forgetFallbackKey() : failed") Timber.tag(loggerTag.value).e("## forgetFallbackKey() : failed")
} }
} }
@ -190,6 +204,8 @@ internal class MXOlmDevice @Inject constructor(
it.groupSession.releaseSession() it.groupSession.releaseSession()
} }
outboundGroupSessionCache.clear() outboundGroupSessionCache.clear()
inboundGroupSessionStore.clear()
olmSessionStore.clear()
} }
/** /**
@ -200,9 +216,9 @@ internal class MXOlmDevice @Inject constructor(
*/ */
fun signMessage(message: String): String? { fun signMessage(message: String): String? {
try { try {
return store.getOlmAccount().signMessage(message) return store.doWithOlmAccount { it.signMessage(message) }
} catch (e: Exception) { } catch (e: Exception) {
Timber.e(e, "## signMessage() : failed") Timber.tag(loggerTag.value).e(e, "## signMessage() : failed")
} }
return null return null
@ -213,10 +229,12 @@ internal class MXOlmDevice @Inject constructor(
*/ */
fun markKeysAsPublished() { fun markKeysAsPublished() {
try { try {
store.getOlmAccount().markOneTimeKeysAsPublished() store.doWithOlmAccount {
store.saveOlmAccount() it.markOneTimeKeysAsPublished()
store.saveOlmAccount()
}
} catch (e: Exception) { } catch (e: Exception) {
Timber.e(e, "## markKeysAsPublished() : failed") Timber.tag(loggerTag.value).e(e, "## markKeysAsPublished() : failed")
} }
} }
@ -227,10 +245,12 @@ internal class MXOlmDevice @Inject constructor(
*/ */
fun generateOneTimeKeys(numKeys: Int) { fun generateOneTimeKeys(numKeys: Int) {
try { try {
store.getOlmAccount().generateOneTimeKeys(numKeys) store.doWithOlmAccount {
store.saveOlmAccount() it.generateOneTimeKeys(numKeys)
store.saveOlmAccount()
}
} catch (e: Exception) { } catch (e: Exception) {
Timber.e(e, "## generateOneTimeKeys() : failed") Timber.tag(loggerTag.value).e(e, "## generateOneTimeKeys() : failed")
} }
} }
@ -243,12 +263,14 @@ internal class MXOlmDevice @Inject constructor(
* @return the session id for the outbound session. * @return the session id for the outbound session.
*/ */
fun createOutboundSession(theirIdentityKey: String, theirOneTimeKey: String): String? { fun createOutboundSession(theirIdentityKey: String, theirOneTimeKey: String): String? {
Timber.v("## createOutboundSession() ; theirIdentityKey $theirIdentityKey theirOneTimeKey $theirOneTimeKey") Timber.tag(loggerTag.value).d("## createOutboundSession() ; theirIdentityKey $theirIdentityKey theirOneTimeKey $theirOneTimeKey")
var olmSession: OlmSession? = null var olmSession: OlmSession? = null
try { try {
olmSession = OlmSession() olmSession = OlmSession()
olmSession.initOutboundSession(store.getOlmAccount(), theirIdentityKey, theirOneTimeKey) store.doWithOlmAccount { olmAccount ->
olmSession.initOutboundSession(olmAccount, theirIdentityKey, theirOneTimeKey)
}
val olmSessionWrapper = OlmSessionWrapper(olmSession, 0) val olmSessionWrapper = OlmSessionWrapper(olmSession, 0)
@ -257,14 +279,14 @@ internal class MXOlmDevice @Inject constructor(
// this session // this session
olmSessionWrapper.onMessageReceived() olmSessionWrapper.onMessageReceived()
store.storeSession(olmSessionWrapper, theirIdentityKey) olmSessionStore.storeSession(olmSessionWrapper, theirIdentityKey)
val sessionIdentifier = olmSession.sessionIdentifier() val sessionIdentifier = olmSession.sessionIdentifier()
Timber.v("## createOutboundSession() ; olmSession.sessionIdentifier: $sessionIdentifier") Timber.tag(loggerTag.value).v("## createOutboundSession() ; olmSession.sessionIdentifier: $sessionIdentifier")
return sessionIdentifier return sessionIdentifier
} catch (e: Exception) { } catch (e: Exception) {
Timber.e(e, "## createOutboundSession() failed") Timber.tag(loggerTag.value).e(e, "## createOutboundSession() failed")
olmSession?.releaseSession() olmSession?.releaseSession()
} }
@ -281,34 +303,38 @@ internal class MXOlmDevice @Inject constructor(
* @return {{payload: string, session_id: string}} decrypted payload, and session id of new session. * @return {{payload: string, session_id: string}} decrypted payload, and session id of new session.
*/ */
fun createInboundSession(theirDeviceIdentityKey: String, messageType: Int, ciphertext: String): Map<String, String>? { fun createInboundSession(theirDeviceIdentityKey: String, messageType: Int, ciphertext: String): Map<String, String>? {
Timber.v("## createInboundSession() : theirIdentityKey: $theirDeviceIdentityKey") Timber.tag(loggerTag.value).d("## createInboundSession() : theirIdentityKey: $theirDeviceIdentityKey")
var olmSession: OlmSession? = null var olmSession: OlmSession? = null
try { try {
try { try {
olmSession = OlmSession() olmSession = OlmSession()
olmSession.initInboundSessionFrom(store.getOlmAccount(), theirDeviceIdentityKey, ciphertext) store.doWithOlmAccount { olmAccount ->
olmSession.initInboundSessionFrom(olmAccount, theirDeviceIdentityKey, ciphertext)
}
} catch (e: Exception) { } catch (e: Exception) {
Timber.e(e, "## createInboundSession() : the session creation failed") Timber.tag(loggerTag.value).e(e, "## createInboundSession() : the session creation failed")
return null return null
} }
Timber.v("## createInboundSession() : sessionId: ${olmSession.sessionIdentifier()}") Timber.tag(loggerTag.value).v("## createInboundSession() : sessionId: ${olmSession.sessionIdentifier()}")
try { try {
store.getOlmAccount().removeOneTimeKeys(olmSession) store.doWithOlmAccount { olmAccount ->
store.saveOlmAccount() olmAccount.removeOneTimeKeys(olmSession)
store.saveOlmAccount()
}
} catch (e: Exception) { } catch (e: Exception) {
Timber.e(e, "## createInboundSession() : removeOneTimeKeys failed") Timber.tag(loggerTag.value).e(e, "## createInboundSession() : removeOneTimeKeys failed")
} }
Timber.v("## createInboundSession() : ciphertext: $ciphertext") Timber.tag(loggerTag.value).v("## createInboundSession() : ciphertext: $ciphertext")
try { try {
val sha256 = olmUtility!!.sha256(URLEncoder.encode(ciphertext, "utf-8")) val sha256 = olmUtility!!.sha256(URLEncoder.encode(ciphertext, "utf-8"))
Timber.v("## createInboundSession() :ciphertext: SHA256: $sha256") Timber.tag(loggerTag.value).v("## createInboundSession() :ciphertext: SHA256: $sha256")
} catch (e: Exception) { } catch (e: Exception) {
Timber.e(e, "## createInboundSession() :ciphertext: cannot encode ciphertext") Timber.tag(loggerTag.value).e(e, "## createInboundSession() :ciphertext: cannot encode ciphertext")
} }
val olmMessage = OlmMessage() val olmMessage = OlmMessage()
@ -324,9 +350,9 @@ internal class MXOlmDevice @Inject constructor(
// This counts as a received message: set last received message time to now // This counts as a received message: set last received message time to now
olmSessionWrapper.onMessageReceived() olmSessionWrapper.onMessageReceived()
store.storeSession(olmSessionWrapper, theirDeviceIdentityKey) olmSessionStore.storeSession(olmSessionWrapper, theirDeviceIdentityKey)
} catch (e: Exception) { } catch (e: Exception) {
Timber.e(e, "## createInboundSession() : decryptMessage failed") Timber.tag(loggerTag.value).e(e, "## createInboundSession() : decryptMessage failed")
} }
val res = HashMap<String, String>() val res = HashMap<String, String>()
@ -343,7 +369,7 @@ internal class MXOlmDevice @Inject constructor(
return res return res
} catch (e: Exception) { } catch (e: Exception) {
Timber.e(e, "## createInboundSession() : OlmSession creation failed") Timber.tag(loggerTag.value).e(e, "## createInboundSession() : OlmSession creation failed")
olmSession?.releaseSession() olmSession?.releaseSession()
} }
@ -357,8 +383,8 @@ internal class MXOlmDevice @Inject constructor(
* @param theirDeviceIdentityKey the Curve25519 identity key for the remote device. * @param theirDeviceIdentityKey the Curve25519 identity key for the remote device.
* @return a list of known session ids for the device. * @return a list of known session ids for the device.
*/ */
fun getSessionIds(theirDeviceIdentityKey: String): List<String>? { fun getSessionIds(theirDeviceIdentityKey: String): List<String> {
return store.getDeviceSessionIds(theirDeviceIdentityKey) return olmSessionStore.getDeviceSessionIds(theirDeviceIdentityKey)
} }
/** /**
@ -368,7 +394,7 @@ internal class MXOlmDevice @Inject constructor(
* @return the session id, or null if no established session. * @return the session id, or null if no established session.
*/ */
fun getSessionId(theirDeviceIdentityKey: String): String? { fun getSessionId(theirDeviceIdentityKey: String): String? {
return store.getLastUsedSessionId(theirDeviceIdentityKey) return olmSessionStore.getLastUsedSessionId(theirDeviceIdentityKey)
} }
/** /**
@ -379,30 +405,30 @@ internal class MXOlmDevice @Inject constructor(
* @param payloadString the payload to be encrypted and sent * @param payloadString the payload to be encrypted and sent
* @return the cipher text * @return the cipher text
*/ */
fun encryptMessage(theirDeviceIdentityKey: String, sessionId: String, payloadString: String): Map<String, Any>? { suspend fun encryptMessage(theirDeviceIdentityKey: String, sessionId: String, payloadString: String): Map<String, Any>? {
var res: MutableMap<String, Any>? = null
val olmMessage: OlmMessage
val olmSessionWrapper = getSessionForDevice(theirDeviceIdentityKey, sessionId) val olmSessionWrapper = getSessionForDevice(theirDeviceIdentityKey, sessionId)
if (olmSessionWrapper != null) { if (olmSessionWrapper != null) {
try { try {
Timber.v("## encryptMessage() : olmSession.sessionIdentifier: $sessionId") Timber.tag(loggerTag.value).v("## encryptMessage() : olmSession.sessionIdentifier: $sessionId")
// Timber.v("## encryptMessage() : payloadString: " + payloadString);
olmMessage = olmSessionWrapper.olmSession.encryptMessage(payloadString) val olmMessage = olmSessionWrapper.mutex.withLock {
store.storeSession(olmSessionWrapper, theirDeviceIdentityKey) olmSessionWrapper.olmSession.encryptMessage(payloadString)
res = HashMap() }
return mapOf(
res["body"] = olmMessage.mCipherText "body" to olmMessage.mCipherText,
res["type"] = olmMessage.mType "type" to olmMessage.mType,
} catch (e: Exception) { ).also {
Timber.e(e, "## encryptMessage() : failed") olmSessionStore.storeSession(olmSessionWrapper, theirDeviceIdentityKey)
}
} catch (e: Throwable) {
Timber.tag(loggerTag.value).e(e, "## encryptMessage() : failed to encrypt olm with device|session:$theirDeviceIdentityKey|$sessionId")
return null
} }
} else { } else {
Timber.e("## encryptMessage() : Failed to encrypt unknown session $sessionId") Timber.tag(loggerTag.value).e("## encryptMessage() : Failed to encrypt unknown session $sessionId")
return null
} }
return res
} }
/** /**
@ -414,7 +440,8 @@ internal class MXOlmDevice @Inject constructor(
* @param sessionId the id of the active session. * @param sessionId the id of the active session.
* @return the decrypted payload. * @return the decrypted payload.
*/ */
fun decryptMessage(ciphertext: String, messageType: Int, sessionId: String, theirDeviceIdentityKey: String): String? { @kotlin.jvm.Throws
suspend fun decryptMessage(ciphertext: String, messageType: Int, sessionId: String, theirDeviceIdentityKey: String): String? {
var payloadString: String? = null var payloadString: String? = null
val olmSessionWrapper = getSessionForDevice(theirDeviceIdentityKey, sessionId) val olmSessionWrapper = getSessionForDevice(theirDeviceIdentityKey, sessionId)
@ -424,13 +451,13 @@ internal class MXOlmDevice @Inject constructor(
olmMessage.mCipherText = ciphertext olmMessage.mCipherText = ciphertext
olmMessage.mType = messageType.toLong() olmMessage.mType = messageType.toLong()
try { payloadString =
payloadString = olmSessionWrapper.olmSession.decryptMessage(olmMessage) olmSessionWrapper.mutex.withLock {
olmSessionWrapper.onMessageReceived() olmSessionWrapper.olmSession.decryptMessage(olmMessage).also {
store.storeSession(olmSessionWrapper, theirDeviceIdentityKey) olmSessionWrapper.onMessageReceived()
} catch (e: Exception) { }
Timber.e(e, "## decryptMessage() : decryptMessage failed") }
} olmSessionStore.storeSession(olmSessionWrapper, theirDeviceIdentityKey)
} }
return payloadString return payloadString
@ -469,7 +496,7 @@ internal class MXOlmDevice @Inject constructor(
store.storeCurrentOutboundGroupSessionForRoom(roomId, session) store.storeCurrentOutboundGroupSessionForRoom(roomId, session)
return session.sessionIdentifier() return session.sessionIdentifier()
} catch (e: Exception) { } catch (e: Exception) {
Timber.e(e, "createOutboundGroupSession") Timber.tag(loggerTag.value).e(e, "createOutboundGroupSession")
session?.releaseSession() session?.releaseSession()
} }
@ -521,7 +548,7 @@ internal class MXOlmDevice @Inject constructor(
try { try {
return outboundGroupSessionCache[sessionId]!!.groupSession.sessionKey() return outboundGroupSessionCache[sessionId]!!.groupSession.sessionKey()
} catch (e: Exception) { } catch (e: Exception) {
Timber.e(e, "## getSessionKey() : failed") Timber.tag(loggerTag.value).e(e, "## getSessionKey() : failed")
} }
} }
return null return null
@ -550,8 +577,8 @@ internal class MXOlmDevice @Inject constructor(
if (sessionId.isNotEmpty() && payloadString.isNotEmpty()) { if (sessionId.isNotEmpty() && payloadString.isNotEmpty()) {
try { try {
return outboundGroupSessionCache[sessionId]!!.groupSession.encryptMessage(payloadString) return outboundGroupSessionCache[sessionId]!!.groupSession.encryptMessage(payloadString)
} catch (e: Exception) { } catch (e: Throwable) {
Timber.e(e, "## encryptGroupMessage() : failed") Timber.tag(loggerTag.value).e(e, "## encryptGroupMessage() : failed")
} }
} }
return null return null
@ -578,52 +605,64 @@ internal class MXOlmDevice @Inject constructor(
forwardingCurve25519KeyChain: List<String>, forwardingCurve25519KeyChain: List<String>,
keysClaimed: Map<String, String>, keysClaimed: Map<String, String>,
exportFormat: Boolean): Boolean { exportFormat: Boolean): Boolean {
val session = OlmInboundGroupSessionWrapper2(sessionKey, exportFormat) val candidateSession = OlmInboundGroupSessionWrapper2(sessionKey, exportFormat)
runCatching { getInboundGroupSession(sessionId, senderKey, roomId) } val existingSessionHolder = tryOrNull { getInboundGroupSession(sessionId, senderKey, roomId) }
.fold( val existingSession = existingSessionHolder?.wrapper
{ // If we have an existing one we should check if the new one is not better
// If we already have this session, consider updating it if (existingSession != null) {
Timber.e("## addInboundGroupSession() : Update for megolm session $senderKey/$sessionId") Timber.tag(loggerTag.value).d("## addInboundGroupSession() check if known session is better than candidate session")
try {
val existingFirstKnown = existingSession.firstKnownIndex ?: return false.also {
// This is quite unexpected, could throw if native was released?
Timber.tag(loggerTag.value).e("## addInboundGroupSession() null firstKnownIndex on existing session")
candidateSession.olmInboundGroupSession?.releaseSession()
// Probably should discard it?
}
val newKnownFirstIndex = candidateSession.firstKnownIndex
// If our existing session is better we keep it
if (newKnownFirstIndex != null && existingFirstKnown <= newKnownFirstIndex) {
Timber.tag(loggerTag.value).d("## addInboundGroupSession() : ignore session our is better $senderKey/$sessionId")
candidateSession.olmInboundGroupSession?.releaseSession()
return false
}
} catch (failure: Throwable) {
Timber.tag(loggerTag.value).e("## addInboundGroupSession() Failed to add inbound: ${failure.localizedMessage}")
candidateSession.olmInboundGroupSession?.releaseSession()
return false
}
}
val existingFirstKnown = it.firstKnownIndex!! Timber.tag(loggerTag.value).d("## addInboundGroupSession() : Candidate session should be added $senderKey/$sessionId")
val newKnownFirstIndex = session.firstKnownIndex
// If our existing session is better we keep it // sanity check on the new session
if (newKnownFirstIndex != null && existingFirstKnown <= newKnownFirstIndex) { val candidateOlmInboundSession = candidateSession.olmInboundGroupSession
session.olmInboundGroupSession?.releaseSession() if (null == candidateOlmInboundSession) {
return false Timber.tag(loggerTag.value).e("## addInboundGroupSession : invalid session <null>")
}
},
{
// Nothing to do in case of error
}
)
// sanity check
if (null == session.olmInboundGroupSession) {
Timber.e("## addInboundGroupSession : invalid session")
return false return false
} }
try { try {
if (session.olmInboundGroupSession!!.sessionIdentifier() != sessionId) { if (candidateOlmInboundSession.sessionIdentifier() != sessionId) {
Timber.e("## addInboundGroupSession : ERROR: Mismatched group session ID from senderKey: $senderKey") Timber.tag(loggerTag.value).e("## addInboundGroupSession : ERROR: Mismatched group session ID from senderKey: $senderKey")
session.olmInboundGroupSession!!.releaseSession() candidateOlmInboundSession.releaseSession()
return false return false
} }
} catch (e: Exception) { } catch (e: Throwable) {
session.olmInboundGroupSession?.releaseSession() candidateOlmInboundSession.releaseSession()
Timber.e(e, "## addInboundGroupSession : sessionIdentifier() failed") Timber.tag(loggerTag.value).e(e, "## addInboundGroupSession : sessionIdentifier() failed")
return false return false
} }
session.senderKey = senderKey candidateSession.senderKey = senderKey
session.roomId = roomId candidateSession.roomId = roomId
session.keysClaimed = keysClaimed candidateSession.keysClaimed = keysClaimed
session.forwardingCurve25519KeyChain = forwardingCurve25519KeyChain candidateSession.forwardingCurve25519KeyChain = forwardingCurve25519KeyChain
inboundGroupSessionStore.storeInBoundGroupSession(session, sessionId, senderKey) if (existingSession != null) {
// store.storeInboundGroupSessions(listOf(session)) inboundGroupSessionStore.replaceGroupSession(existingSessionHolder, InboundGroupSessionHolder(candidateSession), sessionId, senderKey)
} else {
inboundGroupSessionStore.storeInBoundGroupSession(InboundGroupSessionHolder(candidateSession), sessionId, senderKey)
}
return true return true
} }
@ -638,57 +677,70 @@ internal class MXOlmDevice @Inject constructor(
val sessions = ArrayList<OlmInboundGroupSessionWrapper2>(megolmSessionsData.size) val sessions = ArrayList<OlmInboundGroupSessionWrapper2>(megolmSessionsData.size)
for (megolmSessionData in megolmSessionsData) { for (megolmSessionData in megolmSessionsData) {
val sessionId = megolmSessionData.sessionId val sessionId = megolmSessionData.sessionId ?: continue
val senderKey = megolmSessionData.senderKey val senderKey = megolmSessionData.senderKey ?: continue
val roomId = megolmSessionData.roomId val roomId = megolmSessionData.roomId
var session: OlmInboundGroupSessionWrapper2? = null var candidateSessionToImport: OlmInboundGroupSessionWrapper2? = null
try { try {
session = OlmInboundGroupSessionWrapper2(megolmSessionData) candidateSessionToImport = OlmInboundGroupSessionWrapper2(megolmSessionData)
} catch (e: Exception) { } catch (e: Exception) {
Timber.e(e, "## importInboundGroupSession() : Update for megolm session $senderKey/$sessionId") Timber.tag(loggerTag.value).e(e, "## importInboundGroupSession() : Update for megolm session $senderKey/$sessionId")
} }
// sanity check // sanity check
if (session?.olmInboundGroupSession == null) { if (candidateSessionToImport?.olmInboundGroupSession == null) {
Timber.e("## importInboundGroupSession : invalid session") Timber.tag(loggerTag.value).e("## importInboundGroupSession : invalid session")
continue continue
} }
val candidateOlmInboundGroupSession = candidateSessionToImport.olmInboundGroupSession
try { try {
if (session.olmInboundGroupSession?.sessionIdentifier() != sessionId) { if (candidateOlmInboundGroupSession?.sessionIdentifier() != sessionId) {
Timber.e("## importInboundGroupSession : ERROR: Mismatched group session ID from senderKey: $senderKey") Timber.tag(loggerTag.value).e("## importInboundGroupSession : ERROR: Mismatched group session ID from senderKey: $senderKey")
if (session.olmInboundGroupSession != null) session.olmInboundGroupSession!!.releaseSession() candidateOlmInboundGroupSession?.releaseSession()
continue continue
} }
} catch (e: Exception) { } catch (e: Exception) {
Timber.e(e, "## importInboundGroupSession : sessionIdentifier() failed") Timber.tag(loggerTag.value).e(e, "## importInboundGroupSession : sessionIdentifier() failed")
session.olmInboundGroupSession!!.releaseSession() candidateOlmInboundGroupSession?.releaseSession()
continue continue
} }
runCatching { getInboundGroupSession(sessionId, senderKey, roomId) } val existingSessionHolder = tryOrNull { getInboundGroupSession(sessionId, senderKey, roomId) }
.fold( val existingSession = existingSessionHolder?.wrapper
{
// If we already have this session, consider updating it
Timber.e("## importInboundGroupSession() : Update for megolm session $senderKey/$sessionId")
// For now we just ignore updates. TODO: implement something here if (existingSession == null) {
if (it.firstKnownIndex!! <= session.firstKnownIndex!!) { // Session does not already exist, add it
// Ignore this, keep existing Timber.tag(loggerTag.value).d("## importInboundGroupSession() : importing new megolm session $senderKey/$sessionId")
session.olmInboundGroupSession!!.releaseSession() sessions.add(candidateSessionToImport)
} else { } else {
sessions.add(session) Timber.tag(loggerTag.value).e("## importInboundGroupSession() : Update for megolm session $senderKey/$sessionId")
} val existingFirstKnown = tryOrNull { existingSession.firstKnownIndex }
Unit val candidateFirstKnownIndex = tryOrNull { candidateSessionToImport.firstKnownIndex }
},
{
// Session does not already exist, add it
sessions.add(session)
}
) if (existingFirstKnown == null || candidateFirstKnownIndex == null) {
// should not happen?
candidateSessionToImport.olmInboundGroupSession?.releaseSession()
Timber.tag(loggerTag.value)
.w("## importInboundGroupSession() : Can't check session null index $existingFirstKnown/$candidateFirstKnownIndex")
} else {
if (existingFirstKnown <= candidateSessionToImport.firstKnownIndex!!) {
// Ignore this, keep existing
candidateOlmInboundGroupSession.releaseSession()
} else {
// update cache with better session
inboundGroupSessionStore.replaceGroupSession(
existingSessionHolder,
InboundGroupSessionHolder(candidateSessionToImport),
sessionId,
senderKey
)
sessions.add(candidateSessionToImport)
}
}
}
} }
store.storeInboundGroupSessions(sessions) store.storeInboundGroupSessions(sessions)
@ -696,18 +748,6 @@ internal class MXOlmDevice @Inject constructor(
return sessions return sessions
} }
/**
* Remove an inbound group session
*
* @param sessionId the session identifier.
* @param sessionKey base64-encoded secret key.
*/
fun removeInboundGroupSession(sessionId: String?, sessionKey: String?) {
if (null != sessionId && null != sessionKey) {
store.removeInboundGroupSession(sessionId, sessionKey)
}
}
/** /**
* Decrypt a received message with an inbound group session. * Decrypt a received message with an inbound group session.
* *
@ -719,19 +759,24 @@ internal class MXOlmDevice @Inject constructor(
* @return the decrypting result. Nil if the sessionId is unknown. * @return the decrypting result. Nil if the sessionId is unknown.
*/ */
@Throws(MXCryptoError::class) @Throws(MXCryptoError::class)
fun decryptGroupMessage(body: String, suspend fun decryptGroupMessage(body: String,
roomId: String, roomId: String,
timeline: String?, timeline: String?,
sessionId: String, sessionId: String,
senderKey: String): OlmDecryptionResult { senderKey: String): OlmDecryptionResult {
val session = getInboundGroupSession(sessionId, senderKey, roomId) val sessionHolder = getInboundGroupSession(sessionId, senderKey, roomId)
val wrapper = sessionHolder.wrapper
val inboundGroupSession = wrapper.olmInboundGroupSession
?: throw MXCryptoError.Base(MXCryptoError.ErrorType.UNABLE_TO_DECRYPT, "Session is null")
// Check that the room id matches the original one for the session. This stops // Check that the room id matches the original one for the session. This stops
// the HS pretending a message was targeting a different room. // the HS pretending a message was targeting a different room.
if (roomId == session.roomId) { if (roomId == wrapper.roomId) {
val decryptResult = try { val decryptResult = try {
session.olmInboundGroupSession!!.decryptMessage(body) sessionHolder.mutex.withLock {
inboundGroupSession.decryptMessage(body)
}
} catch (e: OlmException) { } catch (e: OlmException) {
Timber.e(e, "## decryptGroupMessage () : decryptMessage failed") Timber.tag(loggerTag.value).e(e, "## decryptGroupMessage () : decryptMessage failed")
throw MXCryptoError.OlmError(e) throw MXCryptoError.OlmError(e)
} }
@ -742,32 +787,32 @@ internal class MXOlmDevice @Inject constructor(
if (timelineSet.contains(messageIndexKey)) { if (timelineSet.contains(messageIndexKey)) {
val reason = String.format(MXCryptoError.DUPLICATE_MESSAGE_INDEX_REASON, decryptResult.mIndex) val reason = String.format(MXCryptoError.DUPLICATE_MESSAGE_INDEX_REASON, decryptResult.mIndex)
Timber.e("## decryptGroupMessage() : $reason") Timber.tag(loggerTag.value).e("## decryptGroupMessage() : $reason")
throw MXCryptoError.Base(MXCryptoError.ErrorType.DUPLICATED_MESSAGE_INDEX, reason) throw MXCryptoError.Base(MXCryptoError.ErrorType.DUPLICATED_MESSAGE_INDEX, reason)
} }
timelineSet.add(messageIndexKey) timelineSet.add(messageIndexKey)
} }
inboundGroupSessionStore.storeInBoundGroupSession(session, sessionId, senderKey) inboundGroupSessionStore.storeInBoundGroupSession(sessionHolder, sessionId, senderKey)
val payload = try { val payload = try {
val adapter = MoshiProvider.providesMoshi().adapter<JsonDict>(JSON_DICT_PARAMETERIZED_TYPE) val adapter = MoshiProvider.providesMoshi().adapter<JsonDict>(JSON_DICT_PARAMETERIZED_TYPE)
val payloadString = convertFromUTF8(decryptResult.mDecryptedMessage) val payloadString = convertFromUTF8(decryptResult.mDecryptedMessage)
adapter.fromJson(payloadString) adapter.fromJson(payloadString)
} catch (e: Exception) { } catch (e: Exception) {
Timber.e("## decryptGroupMessage() : fails to parse the payload") Timber.tag(loggerTag.value).e("## decryptGroupMessage() : fails to parse the payload")
throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_DECRYPTED_FORMAT, MXCryptoError.BAD_DECRYPTED_FORMAT_TEXT_REASON) throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_DECRYPTED_FORMAT, MXCryptoError.BAD_DECRYPTED_FORMAT_TEXT_REASON)
} }
return OlmDecryptionResult( return OlmDecryptionResult(
payload, payload,
session.keysClaimed, wrapper.keysClaimed,
senderKey, senderKey,
session.forwardingCurve25519KeyChain wrapper.forwardingCurve25519KeyChain
) )
} else { } else {
val reason = String.format(MXCryptoError.INBOUND_SESSION_MISMATCH_ROOM_ID_REASON, roomId, session.roomId) val reason = String.format(MXCryptoError.INBOUND_SESSION_MISMATCH_ROOM_ID_REASON, roomId, wrapper.roomId)
Timber.e("## decryptGroupMessage() : $reason") Timber.tag(loggerTag.value).e("## decryptGroupMessage() : $reason")
throw MXCryptoError.Base(MXCryptoError.ErrorType.INBOUND_SESSION_MISMATCH_ROOM_ID, reason) throw MXCryptoError.Base(MXCryptoError.ErrorType.INBOUND_SESSION_MISMATCH_ROOM_ID, reason)
} }
} }
@ -819,7 +864,7 @@ internal class MXOlmDevice @Inject constructor(
private fun getSessionForDevice(theirDeviceIdentityKey: String, sessionId: String): OlmSessionWrapper? { private fun getSessionForDevice(theirDeviceIdentityKey: String, sessionId: String): OlmSessionWrapper? {
// sanity check // sanity check
return if (theirDeviceIdentityKey.isEmpty() || sessionId.isEmpty()) null else { return if (theirDeviceIdentityKey.isEmpty() || sessionId.isEmpty()) null else {
store.getDeviceSession(sessionId, theirDeviceIdentityKey) olmSessionStore.getDeviceSession(sessionId, theirDeviceIdentityKey)
} }
} }
@ -832,25 +877,26 @@ internal class MXOlmDevice @Inject constructor(
* @param senderKey the base64-encoded curve25519 key of the sender. * @param senderKey the base64-encoded curve25519 key of the sender.
* @return the inbound group session. * @return the inbound group session.
*/ */
fun getInboundGroupSession(sessionId: String?, senderKey: String?, roomId: String?): OlmInboundGroupSessionWrapper2 { fun getInboundGroupSession(sessionId: String?, senderKey: String?, roomId: String?): InboundGroupSessionHolder {
if (sessionId.isNullOrBlank() || senderKey.isNullOrBlank()) { if (sessionId.isNullOrBlank() || senderKey.isNullOrBlank()) {
throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_SENDER_KEY, MXCryptoError.ERROR_MISSING_PROPERTY_REASON) throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_SENDER_KEY, MXCryptoError.ERROR_MISSING_PROPERTY_REASON)
} }
val session = inboundGroupSessionStore.getInboundGroupSession(sessionId, senderKey) val holder = inboundGroupSessionStore.getInboundGroupSession(sessionId, senderKey)
val session = holder?.wrapper
if (session != null) { if (session != null) {
// Check that the room id matches the original one for the session. This stops // Check that the room id matches the original one for the session. This stops
// the HS pretending a message was targeting a different room. // the HS pretending a message was targeting a different room.
if (roomId != session.roomId) { if (roomId != session.roomId) {
val errorDescription = String.format(MXCryptoError.INBOUND_SESSION_MISMATCH_ROOM_ID_REASON, roomId, session.roomId) val errorDescription = String.format(MXCryptoError.INBOUND_SESSION_MISMATCH_ROOM_ID_REASON, roomId, session.roomId)
Timber.e("## getInboundGroupSession() : $errorDescription") Timber.tag(loggerTag.value).e("## getInboundGroupSession() : $errorDescription")
throw MXCryptoError.Base(MXCryptoError.ErrorType.INBOUND_SESSION_MISMATCH_ROOM_ID, errorDescription) throw MXCryptoError.Base(MXCryptoError.ErrorType.INBOUND_SESSION_MISMATCH_ROOM_ID, errorDescription)
} else { } else {
return session return holder
} }
} else { } else {
Timber.w("## getInboundGroupSession() : Cannot retrieve inbound group session $sessionId") Timber.tag(loggerTag.value).w("## getInboundGroupSession() : UISI $sessionId")
throw MXCryptoError.Base(MXCryptoError.ErrorType.UNKNOWN_INBOUND_SESSION_ID, MXCryptoError.UNKNOWN_INBOUND_SESSION_ID_REASON) throw MXCryptoError.Base(MXCryptoError.ErrorType.UNKNOWN_INBOUND_SESSION_ID, MXCryptoError.UNKNOWN_INBOUND_SESSION_ID_REASON)
} }
} }
@ -866,4 +912,9 @@ internal class MXOlmDevice @Inject constructor(
fun hasInboundSessionKeys(roomId: String, senderKey: String, sessionId: String): Boolean { fun hasInboundSessionKeys(roomId: String, senderKey: String, sessionId: String): Boolean {
return runCatching { getInboundGroupSession(sessionId, senderKey, roomId) }.isSuccess return runCatching { getInboundGroupSession(sessionId, senderKey, roomId) }.isSuccess
} }
@VisibleForTesting
fun clearOlmSessionCache() {
olmSessionStore.clear()
}
} }

View File

@ -0,0 +1,159 @@
/*
* Copyright 2022 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.matrix.android.sdk.internal.crypto
import org.matrix.android.sdk.api.logger.LoggerTag
import org.matrix.android.sdk.internal.crypto.model.OlmSessionWrapper
import org.matrix.android.sdk.internal.crypto.store.IMXCryptoStore
import org.matrix.olm.OlmSession
import timber.log.Timber
import javax.inject.Inject
private val loggerTag = LoggerTag("OlmSessionStore", LoggerTag.CRYPTO)
/**
* Keep the used olm session in memory and load them from the data layer when needed
* Access is synchronized for thread safety
*/
internal class OlmSessionStore @Inject constructor(private val store: IMXCryptoStore) {
/**
* map of device key to list of olm sessions (it is possible to have several active sessions with a device)
*/
private val olmSessions = HashMap<String, MutableList<OlmSessionWrapper>>()
/**
* Store a session between our own device and another device.
* This will be called after the session has been created but also every time it has been used
* in order to persist the correct state for next run
* @param olmSessionWrapper the end-to-end session.
* @param deviceKey the public key of the other device.
*/
@Synchronized
fun storeSession(olmSessionWrapper: OlmSessionWrapper, deviceKey: String) {
// This could be a newly created session or one that was just created
// Anyhow we should persist ratchet state for future app lifecycle
addNewSessionInCache(olmSessionWrapper, deviceKey)
store.storeSession(olmSessionWrapper, deviceKey)
}
/**
* Get all the Olm Sessions we are sharing with the given device.
*
* @param deviceKey the public key of the other device.
* @return A set of sessionId, or empty if device is not known
*/
@Synchronized
fun getDeviceSessionIds(deviceKey: String): List<String> {
// we need to get the persisted ids first
val persistedKnownSessions = store.getDeviceSessionIds(deviceKey)
.orEmpty()
.toMutableList()
// Do we have some in cache not yet persisted?
olmSessions.getOrPut(deviceKey) { mutableListOf() }.forEach { cached ->
getSafeSessionIdentifier(cached.olmSession)?.let { cachedSessionId ->
if (!persistedKnownSessions.contains(cachedSessionId)) {
persistedKnownSessions.add(cachedSessionId)
}
}
}
return persistedKnownSessions
}
/**
* Retrieve an end-to-end session between our own device and another
* device.
*
* @param sessionId the session Id.
* @param deviceKey the public key of the other device.
* @return the session wrapper if found
*/
@Synchronized
fun getDeviceSession(sessionId: String, deviceKey: String): OlmSessionWrapper? {
// get from cache or load and add to cache
return internalGetSession(sessionId, deviceKey)
}
/**
* Retrieve the last used sessionId, regarding `lastReceivedMessageTs`, or null if no session exist
*
* @param deviceKey the public key of the other device.
* @return last used sessionId, or null if not found
*/
@Synchronized
fun getLastUsedSessionId(deviceKey: String): String? {
// We want to avoid to load in memory old session if possible
val lastPersistedUsedSession = store.getLastUsedSessionId(deviceKey)
var candidate = lastPersistedUsedSession?.let { internalGetSession(it, deviceKey) }
// we should check if we have one in cache with a higher last message received?
olmSessions[deviceKey].orEmpty().forEach { inCache ->
if (inCache.lastReceivedMessageTs > (candidate?.lastReceivedMessageTs ?: 0L)) {
candidate = inCache
}
}
return candidate?.olmSession?.sessionIdentifier()
}
/**
* Release all sessions and clear cache
*/
@Synchronized
fun clear() {
olmSessions.entries.onEach { entry ->
entry.value.onEach { it.olmSession.releaseSession() }
}
olmSessions.clear()
}
private fun internalGetSession(sessionId: String, deviceKey: String): OlmSessionWrapper? {
return getSessionInCache(sessionId, deviceKey)
?: // deserialize from store
return store.getDeviceSession(sessionId, deviceKey)?.also {
addNewSessionInCache(it, deviceKey)
}
}
private fun getSessionInCache(sessionId: String, deviceKey: String): OlmSessionWrapper? {
return olmSessions[deviceKey]?.firstOrNull {
getSafeSessionIdentifier(it.olmSession) == sessionId
}
}
private fun getSafeSessionIdentifier(session: OlmSession): String? {
return try {
session.sessionIdentifier()
} catch (throwable: Throwable) {
Timber.tag(loggerTag.value).w("Failed to load sessionId from loaded olm session")
null
}
}
private fun addNewSessionInCache(session: OlmSessionWrapper, deviceKey: String) {
val sessionId = getSafeSessionIdentifier(session.olmSession) ?: return
olmSessions.getOrPut(deviceKey) { mutableListOf() }.let {
val existing = it.firstOrNull { getSafeSessionIdentifier(it.olmSession) == sessionId }
it.add(session)
// remove and release if was there but with different instance
if (existing != null && existing.olmSession != session.olmSession) {
// mm not sure when this could happen
// anyhow we should remove and release the one known
it.remove(existing)
existing.olmSession.releaseSession()
}
}
}
}

View File

@ -16,14 +16,18 @@
package org.matrix.android.sdk.internal.crypto.actions package org.matrix.android.sdk.internal.crypto.actions
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withContext
import org.matrix.android.sdk.api.MatrixCoroutineDispatchers
import org.matrix.android.sdk.api.logger.LoggerTag import org.matrix.android.sdk.api.logger.LoggerTag
import org.matrix.android.sdk.internal.crypto.MXOlmDevice import org.matrix.android.sdk.internal.crypto.MXOlmDevice
import org.matrix.android.sdk.internal.crypto.model.CryptoDeviceInfo import org.matrix.android.sdk.internal.crypto.model.CryptoDeviceInfo
import org.matrix.android.sdk.internal.crypto.model.MXKey import org.matrix.android.sdk.internal.crypto.model.MXKey
import org.matrix.android.sdk.internal.crypto.model.MXOlmSessionResult import org.matrix.android.sdk.internal.crypto.model.MXOlmSessionResult
import org.matrix.android.sdk.internal.crypto.model.MXUsersDevicesMap import org.matrix.android.sdk.internal.crypto.model.MXUsersDevicesMap
import org.matrix.android.sdk.internal.crypto.model.toDebugString
import org.matrix.android.sdk.internal.crypto.tasks.ClaimOneTimeKeysForUsersDeviceTask import org.matrix.android.sdk.internal.crypto.tasks.ClaimOneTimeKeysForUsersDeviceTask
import org.matrix.android.sdk.internal.session.SessionScope
import timber.log.Timber import timber.log.Timber
import javax.inject.Inject import javax.inject.Inject
@ -31,90 +35,90 @@ private const val ONE_TIME_KEYS_RETRY_COUNT = 3
private val loggerTag = LoggerTag("EnsureOlmSessionsForDevicesAction", LoggerTag.CRYPTO) private val loggerTag = LoggerTag("EnsureOlmSessionsForDevicesAction", LoggerTag.CRYPTO)
@SessionScope
internal class EnsureOlmSessionsForDevicesAction @Inject constructor( internal class EnsureOlmSessionsForDevicesAction @Inject constructor(
private val olmDevice: MXOlmDevice, private val olmDevice: MXOlmDevice,
private val coroutineDispatchers: MatrixCoroutineDispatchers,
private val oneTimeKeysForUsersDeviceTask: ClaimOneTimeKeysForUsersDeviceTask) { private val oneTimeKeysForUsersDeviceTask: ClaimOneTimeKeysForUsersDeviceTask) {
private val ensureMutex = Mutex()
/**
* We want to synchronize a bit here, because we are iterating to check existing olm session and
* also adding some
*/
suspend fun handle(devicesByUser: Map<String, List<CryptoDeviceInfo>>, force: Boolean = false): MXUsersDevicesMap<MXOlmSessionResult> { suspend fun handle(devicesByUser: Map<String, List<CryptoDeviceInfo>>, force: Boolean = false): MXUsersDevicesMap<MXOlmSessionResult> {
val devicesWithoutSession = ArrayList<CryptoDeviceInfo>() ensureMutex.withLock {
val results = MXUsersDevicesMap<MXOlmSessionResult>()
val deviceList = devicesByUser.flatMap { it.value }
Timber.tag(loggerTag.value)
.d("ensure olm forced:$force for ${deviceList.joinToString { it.shortDebugString() }}")
val devicesToCreateSessionWith = mutableListOf<CryptoDeviceInfo>()
if (force) {
// we take all devices and will query otk for them
devicesToCreateSessionWith.addAll(deviceList)
} else {
// only peek devices without active session
deviceList.forEach { deviceInfo ->
val deviceId = deviceInfo.deviceId
val userId = deviceInfo.userId
val key = deviceInfo.identityKey() ?: return@forEach Unit.also {
Timber.tag(loggerTag.value).w("Ignoring device ${deviceInfo.shortDebugString()} without identity key")
}
val results = MXUsersDevicesMap<MXOlmSessionResult>() // is there a session that as been already used?
val sessionId = olmDevice.getSessionId(key)
for ((userId, deviceList) in devicesByUser) { if (sessionId.isNullOrEmpty()) {
for (deviceInfo in deviceList) { Timber.tag(loggerTag.value).d("Found no existing olm session ${deviceInfo.shortDebugString()} add to claim list")
val deviceId = deviceInfo.deviceId devicesToCreateSessionWith.add(deviceInfo)
val key = deviceInfo.identityKey() } else {
if (key == null) { Timber.tag(loggerTag.value).d("using olm session $sessionId for (${deviceInfo.userId}|$deviceId)")
Timber.w("## CRYPTO | Ignoring device (${deviceInfo.userId}|$deviceId) without identity key") val olmSessionResult = MXOlmSessionResult(deviceInfo, sessionId)
continue results.setObject(userId, deviceId, olmSessionResult)
}
val sessionId = olmDevice.getSessionId(key)
if (sessionId.isNullOrEmpty() || force) {
Timber.tag(loggerTag.value).d("Found no existing olm session (${deviceInfo.userId}|$deviceId) (force=$force)")
devicesWithoutSession.add(deviceInfo)
} else {
Timber.tag(loggerTag.value).d("using olm session $sessionId for (${deviceInfo.userId}|$deviceId)")
}
val olmSessionResult = MXOlmSessionResult(deviceInfo, sessionId)
results.setObject(userId, deviceId, olmSessionResult)
}
}
Timber.tag(loggerTag.value).d("Devices without olm session (count:${devicesWithoutSession.size}) :" +
" ${devicesWithoutSession.joinToString { "${it.userId}|${it.deviceId}" }}")
if (devicesWithoutSession.size == 0) {
return results
}
// Prepare the request for claiming one-time keys
val usersDevicesToClaim = MXUsersDevicesMap<String>()
val oneTimeKeyAlgorithm = MXKey.KEY_SIGNED_CURVE_25519_TYPE
for (device in devicesWithoutSession) {
usersDevicesToClaim.setObject(device.userId, device.deviceId, oneTimeKeyAlgorithm)
}
// TODO: this has a race condition - if we try to send another message
// while we are claiming a key, we will end up claiming two and setting up
// two sessions.
//
// That should eventually resolve itself, but it's poor form.
Timber.tag(loggerTag.value).i("claimOneTimeKeysForUsersDevices() : ${usersDevicesToClaim.toDebugString()}")
val claimParams = ClaimOneTimeKeysForUsersDeviceTask.Params(usersDevicesToClaim)
val oneTimeKeys = oneTimeKeysForUsersDeviceTask.executeRetry(claimParams, remainingRetry = ONE_TIME_KEYS_RETRY_COUNT)
Timber.tag(loggerTag.value).v("claimOneTimeKeysForUsersDevices() : keysClaimResponse.oneTimeKeys: $oneTimeKeys")
for ((userId, deviceInfos) in devicesByUser) {
for (deviceInfo in deviceInfos) {
var oneTimeKey: MXKey? = null
val deviceIds = oneTimeKeys.getUserDeviceIds(userId)
if (null != deviceIds) {
for (deviceId in deviceIds) {
val olmSessionResult = results.getObject(userId, deviceId)
if (olmSessionResult?.sessionId != null && !force) {
// We already have a result for this device
continue
}
val key = oneTimeKeys.getObject(userId, deviceId)
if (key?.type == oneTimeKeyAlgorithm) {
oneTimeKey = key
}
if (oneTimeKey == null) {
Timber.tag(loggerTag.value).d("No one time key for $userId|$deviceId")
continue
}
// Update the result for this device in results
olmSessionResult?.sessionId = verifyKeyAndStartSession(oneTimeKey, userId, deviceInfo)
} }
} }
} }
if (devicesToCreateSessionWith.isEmpty()) {
// no session to create
return results
}
val usersDevicesToClaim = MXUsersDevicesMap<String>().apply {
devicesToCreateSessionWith.forEach {
setObject(it.userId, it.deviceId, MXKey.KEY_SIGNED_CURVE_25519_TYPE)
}
}
// Let's now claim one time keys
val claimParams = ClaimOneTimeKeysForUsersDeviceTask.Params(usersDevicesToClaim)
val oneTimeKeys = withContext(coroutineDispatchers.io) {
oneTimeKeysForUsersDeviceTask.executeRetry(claimParams, ONE_TIME_KEYS_RETRY_COUNT)
}
// let now start olm session using the new otks
devicesToCreateSessionWith.forEach { deviceInfo ->
val userId = deviceInfo.userId
val deviceId = deviceInfo.deviceId
// Did we get an OTK
val oneTimeKey = oneTimeKeys.getObject(userId, deviceId)
if (oneTimeKey == null) {
Timber.tag(loggerTag.value).d("No otk for ${deviceInfo.shortDebugString()}")
} else if (oneTimeKey.type != MXKey.KEY_SIGNED_CURVE_25519_TYPE) {
Timber.tag(loggerTag.value).d("Bad otk type (${oneTimeKey.type}) for ${deviceInfo.shortDebugString()}")
} else {
val olmSessionId = verifyKeyAndStartSession(oneTimeKey, userId, deviceInfo)
if (olmSessionId != null) {
val olmSessionResult = MXOlmSessionResult(deviceInfo, olmSessionId)
results.setObject(userId, deviceId, olmSessionResult)
} else {
Timber
.tag(loggerTag.value)
.d("## CRYPTO | cant unwedge failed to create outbound ${deviceInfo.shortDebugString()}")
}
}
}
return results
} }
return results
} }
private fun verifyKeyAndStartSession(oneTimeKey: MXKey, userId: String, deviceInfo: CryptoDeviceInfo): String? { private fun verifyKeyAndStartSession(oneTimeKey: MXKey, userId: String, deviceInfo: CryptoDeviceInfo): String? {

View File

@ -16,6 +16,7 @@
package org.matrix.android.sdk.internal.crypto.actions package org.matrix.android.sdk.internal.crypto.actions
import org.matrix.android.sdk.api.logger.LoggerTag
import org.matrix.android.sdk.api.session.events.model.Content import org.matrix.android.sdk.api.session.events.model.Content
import org.matrix.android.sdk.internal.crypto.MXCRYPTO_ALGORITHM_OLM import org.matrix.android.sdk.internal.crypto.MXCRYPTO_ALGORITHM_OLM
import org.matrix.android.sdk.internal.crypto.MXOlmDevice import org.matrix.android.sdk.internal.crypto.MXOlmDevice
@ -28,6 +29,8 @@ import org.matrix.android.sdk.internal.util.convertToUTF8
import timber.log.Timber import timber.log.Timber
import javax.inject.Inject import javax.inject.Inject
private val loggerTag = LoggerTag("MessageEncrypter", LoggerTag.CRYPTO)
internal class MessageEncrypter @Inject constructor( internal class MessageEncrypter @Inject constructor(
@UserId @UserId
private val userId: String, private val userId: String,
@ -42,7 +45,7 @@ internal class MessageEncrypter @Inject constructor(
* @param deviceInfos list of device infos to encrypt for. * @param deviceInfos list of device infos to encrypt for.
* @return the content for an m.room.encrypted event. * @return the content for an m.room.encrypted event.
*/ */
fun encryptMessage(payloadFields: Content, deviceInfos: List<CryptoDeviceInfo>): EncryptedMessage { suspend fun encryptMessage(payloadFields: Content, deviceInfos: List<CryptoDeviceInfo>): EncryptedMessage {
val deviceInfoParticipantKey = deviceInfos.associateBy { it.identityKey()!! } val deviceInfoParticipantKey = deviceInfos.associateBy { it.identityKey()!! }
val payloadJson = payloadFields.toMutableMap() val payloadJson = payloadFields.toMutableMap()
@ -66,7 +69,7 @@ internal class MessageEncrypter @Inject constructor(
val sessionId = olmDevice.getSessionId(deviceKey) val sessionId = olmDevice.getSessionId(deviceKey)
if (!sessionId.isNullOrEmpty()) { if (!sessionId.isNullOrEmpty()) {
Timber.v("Using sessionid $sessionId for device $deviceKey") Timber.tag(loggerTag.value).d("Using sessionid $sessionId for device $deviceKey")
payloadJson["recipient"] = deviceInfo.userId payloadJson["recipient"] = deviceInfo.userId
payloadJson["recipient_keys"] = mapOf("ed25519" to deviceInfo.fingerprint()!!) payloadJson["recipient_keys"] = mapOf("ed25519" to deviceInfo.fingerprint()!!)

View File

@ -36,7 +36,7 @@ internal interface IMXDecrypting {
* @return the decryption information, or an error * @return the decryption information, or an error
*/ */
@Throws(MXCryptoError::class) @Throws(MXCryptoError::class)
fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult suspend fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult
/** /**
* Handle a key event. * Handle a key event.

View File

@ -45,7 +45,7 @@ internal interface IMXGroupEncryption {
* *
* @return true in case of success * @return true in case of success
*/ */
suspend fun reshareKey(sessionId: String, suspend fun reshareKey(groupSessionId: String,
userId: String, userId: String,
deviceId: String, deviceId: String,
senderKey: String): Boolean senderKey: String): Boolean

View File

@ -19,6 +19,7 @@ package org.matrix.android.sdk.internal.crypto.algorithms.megolm
import dagger.Lazy import dagger.Lazy
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.sync.withLock
import org.matrix.android.sdk.api.MatrixCoroutineDispatchers import org.matrix.android.sdk.api.MatrixCoroutineDispatchers
import org.matrix.android.sdk.api.logger.LoggerTag import org.matrix.android.sdk.api.logger.LoggerTag
import org.matrix.android.sdk.api.session.crypto.MXCryptoError import org.matrix.android.sdk.api.session.crypto.MXCryptoError
@ -71,7 +72,7 @@ internal class MXMegolmDecryption(private val userId: String,
// private var pendingEvents: MutableMap<String /* senderKey|sessionId */, MutableMap<String /* timelineId */, MutableList<Event>>> = HashMap() // private var pendingEvents: MutableMap<String /* senderKey|sessionId */, MutableMap<String /* timelineId */, MutableList<Event>>> = HashMap()
@Throws(MXCryptoError::class) @Throws(MXCryptoError::class)
override fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult { override suspend fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult {
// If cross signing is enabled, we don't send request until the keys are trusted // If cross signing is enabled, we don't send request until the keys are trusted
// There could be a race effect here when xsigning is enabled, we should ensure that keys was downloaded once // There could be a race effect here when xsigning is enabled, we should ensure that keys was downloaded once
val requestOnFail = cryptoStore.getMyCrossSigningInfo()?.isTrusted() == true val requestOnFail = cryptoStore.getMyCrossSigningInfo()?.isTrusted() == true
@ -79,7 +80,7 @@ internal class MXMegolmDecryption(private val userId: String,
} }
@Throws(MXCryptoError::class) @Throws(MXCryptoError::class)
private fun decryptEvent(event: Event, timeline: String, requestKeysOnFail: Boolean): MXEventDecryptionResult { private suspend fun decryptEvent(event: Event, timeline: String, requestKeysOnFail: Boolean): MXEventDecryptionResult {
Timber.tag(loggerTag.value).v("decryptEvent ${event.eventId}, requestKeysOnFail:$requestKeysOnFail") Timber.tag(loggerTag.value).v("decryptEvent ${event.eventId}, requestKeysOnFail:$requestKeysOnFail")
if (event.roomId.isNullOrBlank()) { if (event.roomId.isNullOrBlank()) {
throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_FIELDS, MXCryptoError.MISSING_FIELDS_REASON) throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_FIELDS, MXCryptoError.MISSING_FIELDS_REASON)
@ -345,7 +346,22 @@ internal class MXMegolmDecryption(private val userId: String,
return return
} }
val userId = request.userId ?: return val userId = request.userId ?: return
cryptoCoroutineScope.launch(coroutineDispatchers.crypto) { cryptoCoroutineScope.launch(coroutineDispatchers.crypto) {
val body = request.requestBody
val sessionHolder = try {
olmDevice.getInboundGroupSession(body.sessionId, body.senderKey, body.roomId)
} catch (failure: Throwable) {
Timber.tag(loggerTag.value).e(failure, "shareKeysWithDevice: failed to get session for request $body")
return@launch
}
val export = sessionHolder.mutex.withLock {
sessionHolder.wrapper.exportKeys()
} ?: return@launch Unit.also {
Timber.tag(loggerTag.value).e("shareKeysWithDevice: failed to export group session ${body.sessionId}")
}
runCatching { deviceListManager.downloadKeys(listOf(userId), false) } runCatching { deviceListManager.downloadKeys(listOf(userId), false) }
.mapCatching { .mapCatching {
val deviceId = request.deviceId val deviceId = request.deviceId
@ -355,7 +371,6 @@ internal class MXMegolmDecryption(private val userId: String,
} else { } else {
val devicesByUser = mapOf(userId to listOf(deviceInfo)) val devicesByUser = mapOf(userId to listOf(deviceInfo))
val usersDeviceMap = ensureOlmSessionsForDevicesAction.handle(devicesByUser) val usersDeviceMap = ensureOlmSessionsForDevicesAction.handle(devicesByUser)
val body = request.requestBody
val olmSessionResult = usersDeviceMap.getObject(userId, deviceId) val olmSessionResult = usersDeviceMap.getObject(userId, deviceId)
if (olmSessionResult?.sessionId == null) { if (olmSessionResult?.sessionId == null) {
// no session with this device, probably because there // no session with this device, probably because there
@ -365,19 +380,10 @@ internal class MXMegolmDecryption(private val userId: String,
} }
Timber.tag(loggerTag.value).i("shareKeysWithDevice() : sharing session ${body.sessionId} with device $userId:$deviceId") Timber.tag(loggerTag.value).i("shareKeysWithDevice() : sharing session ${body.sessionId} with device $userId:$deviceId")
val payloadJson = mutableMapOf<String, Any>("type" to EventType.FORWARDED_ROOM_KEY) val payloadJson = mapOf(
runCatching { olmDevice.getInboundGroupSession(body.sessionId, body.senderKey, body.roomId) } "type" to EventType.FORWARDED_ROOM_KEY,
.fold( "content" to export
{ )
// TODO
payloadJson["content"] = it.exportKeys() ?: ""
},
{
// TODO
Timber.tag(loggerTag.value).e(it, "shareKeysWithDevice: failed to get session for request $body")
}
)
val encodedPayload = messageEncrypter.encryptMessage(payloadJson, listOf(deviceInfo)) val encodedPayload = messageEncrypter.encryptMessage(payloadJson, listOf(deviceInfo))
val sendToDeviceMap = MXUsersDevicesMap<Any>() val sendToDeviceMap = MXUsersDevicesMap<Any>()

View File

@ -18,6 +18,8 @@ package org.matrix.android.sdk.internal.crypto.algorithms.megolm
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withContext
import org.matrix.android.sdk.api.MatrixCoroutineDispatchers import org.matrix.android.sdk.api.MatrixCoroutineDispatchers
import org.matrix.android.sdk.api.logger.LoggerTag import org.matrix.android.sdk.api.logger.LoggerTag
import org.matrix.android.sdk.api.session.crypto.MXCryptoError import org.matrix.android.sdk.api.session.crypto.MXCryptoError
@ -88,7 +90,7 @@ internal class MXMegolmEncryption(
Timber.tag(loggerTag.value).v("encryptEventContent : getDevicesInRoom") Timber.tag(loggerTag.value).v("encryptEventContent : getDevicesInRoom")
val devices = getDevicesInRoom(userIds) val devices = getDevicesInRoom(userIds)
Timber.tag(loggerTag.value).d("encrypt event in room=$roomId - devices count in room ${devices.allowedDevices.toDebugCount()}") Timber.tag(loggerTag.value).d("encrypt event in room=$roomId - devices count in room ${devices.allowedDevices.toDebugCount()}")
Timber.tag(loggerTag.value).v("encryptEventContent ${System.currentTimeMillis() - ts}: getDevicesInRoom ${devices.allowedDevices.map}") Timber.tag(loggerTag.value).v("encryptEventContent ${System.currentTimeMillis() - ts}: getDevicesInRoom ${devices.allowedDevices.toDebugString()}")
val outboundSession = ensureOutboundSession(devices.allowedDevices) val outboundSession = ensureOutboundSession(devices.allowedDevices)
return encryptContent(outboundSession, eventType, eventContent) return encryptContent(outboundSession, eventType, eventContent)
@ -142,8 +144,9 @@ internal class MXMegolmEncryption(
Timber.tag(loggerTag.value).v("prepareNewSessionInRoom() ") Timber.tag(loggerTag.value).v("prepareNewSessionInRoom() ")
val sessionId = olmDevice.createOutboundGroupSessionForRoom(roomId) val sessionId = olmDevice.createOutboundGroupSessionForRoom(roomId)
val keysClaimedMap = HashMap<String, String>() val keysClaimedMap = mapOf(
keysClaimedMap["ed25519"] = olmDevice.deviceEd25519Key!! "ed25519" to olmDevice.deviceEd25519Key!!
)
olmDevice.addInboundGroupSession(sessionId!!, olmDevice.getSessionKey(sessionId)!!, roomId, olmDevice.deviceCurve25519Key!!, olmDevice.addInboundGroupSession(sessionId!!, olmDevice.getSessionKey(sessionId)!!, roomId, olmDevice.deviceCurve25519Key!!,
emptyList(), keysClaimedMap, false) emptyList(), keysClaimedMap, false)
@ -303,11 +306,13 @@ internal class MXMegolmEncryption(
Timber.tag(loggerTag.value).d("sending to device room key for ${session.sessionId} to ${contentMap.toDebugString()}") Timber.tag(loggerTag.value).d("sending to device room key for ${session.sessionId} to ${contentMap.toDebugString()}")
val sendToDeviceParams = SendToDeviceTask.Params(EventType.ENCRYPTED, contentMap) val sendToDeviceParams = SendToDeviceTask.Params(EventType.ENCRYPTED, contentMap)
try { try {
sendToDeviceTask.execute(sendToDeviceParams) withContext(coroutineDispatchers.io) {
sendToDeviceTask.execute(sendToDeviceParams)
}
Timber.tag(loggerTag.value).i("shareUserDevicesKey() : sendToDevice succeeds after ${System.currentTimeMillis() - t0} ms") Timber.tag(loggerTag.value).i("shareUserDevicesKey() : sendToDevice succeeds after ${System.currentTimeMillis() - t0} ms")
} catch (failure: Throwable) { } catch (failure: Throwable) {
// What to do here... // What to do here...
Timber.tag(loggerTag.value).e("shareUserDevicesKey() : Failed to share session <${session.sessionId}> with $devicesByUser ") Timber.tag(loggerTag.value).e("shareUserDevicesKey() : Failed to share <${session.sessionId}>")
} }
} else { } else {
Timber.tag(loggerTag.value).i("shareUserDevicesKey() : no need to share key") Timber.tag(loggerTag.value).i("shareUserDevicesKey() : no need to share key")
@ -346,9 +351,12 @@ internal class MXMegolmEncryption(
} }
) )
try { try {
sendToDeviceTask.execute(params) withContext(coroutineDispatchers.io) {
sendToDeviceTask.execute(params)
}
} catch (failure: Throwable) { } catch (failure: Throwable) {
Timber.tag(loggerTag.value).e("notifyKeyWithHeld() : Failed to notify withheld key for $targets session: $sessionId ") Timber.tag(loggerTag.value)
.e("notifyKeyWithHeld() :$sessionId Failed to send withheld ${targets.map { "${it.userId}|${it.deviceId}" }}")
} }
} }
@ -432,20 +440,20 @@ internal class MXMegolmEncryption(
} }
} }
override suspend fun reshareKey(sessionId: String, override suspend fun reshareKey(groupSessionId: String,
userId: String, userId: String,
deviceId: String, deviceId: String,
senderKey: String): Boolean { senderKey: String): Boolean {
Timber.tag(loggerTag.value).i("process reshareKey for $sessionId to $userId:$deviceId") Timber.tag(loggerTag.value).i("process reshareKey for $groupSessionId to $userId:$deviceId")
val deviceInfo = cryptoStore.getUserDevice(userId, deviceId) ?: return false val deviceInfo = cryptoStore.getUserDevice(userId, deviceId) ?: return false
.also { Timber.tag(loggerTag.value).w("reshareKey: Device not found") } .also { Timber.tag(loggerTag.value).w("reshareKey: Device not found") }
// Get the chain index of the key we previously sent this device // Get the chain index of the key we previously sent this device
val wasSessionSharedWithUser = cryptoStore.getSharedSessionInfo(roomId, sessionId, deviceInfo) val wasSessionSharedWithUser = cryptoStore.getSharedSessionInfo(roomId, groupSessionId, deviceInfo)
if (!wasSessionSharedWithUser.found) { if (!wasSessionSharedWithUser.found) {
// This session was never shared with this user // This session was never shared with this user
// Send a room key with held // Send a room key with held
notifyKeyWithHeld(listOf(UserDevice(userId, deviceId)), sessionId, senderKey, WithHeldCode.UNAUTHORISED) notifyKeyWithHeld(listOf(UserDevice(userId, deviceId)), groupSessionId, senderKey, WithHeldCode.UNAUTHORISED)
Timber.tag(loggerTag.value).w("reshareKey: ERROR : Never shared megolm with this device") Timber.tag(loggerTag.value).w("reshareKey: ERROR : Never shared megolm with this device")
return false return false
} }
@ -456,42 +464,47 @@ internal class MXMegolmEncryption(
} }
val devicesByUser = mapOf(userId to listOf(deviceInfo)) val devicesByUser = mapOf(userId to listOf(deviceInfo))
val usersDeviceMap = ensureOlmSessionsForDevicesAction.handle(devicesByUser) val usersDeviceMap = try {
val olmSessionResult = usersDeviceMap.getObject(userId, deviceId) ensureOlmSessionsForDevicesAction.handle(devicesByUser)
olmSessionResult?.sessionId // no session with this device, probably because there were no one-time keys. } catch (failure: Throwable) {
// ensureOlmSessionsForDevicesAction has already done the logging, so just skip it. null
?: return false.also { }
Timber.tag(loggerTag.value).w("reshareKey: no session with this device, probably because there were no one-time keys") val olmSessionResult = usersDeviceMap?.getObject(userId, deviceId)
} if (olmSessionResult?.sessionId == null) {
Timber.tag(loggerTag.value).w("reshareKey: no session with this device, probably because there were no one-time keys")
return false
}
Timber.tag(loggerTag.value).i(" reshareKey: $groupSessionId:$chainIndex with device $userId:$deviceId using session ${olmSessionResult.sessionId}")
Timber.tag(loggerTag.value).i(" reshareKey: sharing keys for session $senderKey|$sessionId:$chainIndex with device $userId:$deviceId") val sessionHolder = try {
olmDevice.getInboundGroupSession(groupSessionId, senderKey, roomId)
} catch (failure: Throwable) {
Timber.tag(loggerTag.value).e(failure, "shareKeysWithDevice: failed to get session $groupSessionId")
return false
}
val payloadJson = mutableMapOf<String, Any>("type" to EventType.FORWARDED_ROOM_KEY) val export = sessionHolder.mutex.withLock {
sessionHolder.wrapper.exportKeys()
} ?: return false.also {
Timber.tag(loggerTag.value).e("shareKeysWithDevice: failed to export group session $groupSessionId")
}
runCatching { olmDevice.getInboundGroupSession(sessionId, senderKey, roomId) } val payloadJson = mapOf(
.fold( "type" to EventType.FORWARDED_ROOM_KEY,
{ "content" to export
// TODO )
payloadJson["content"] = it.exportKeys(chainIndex.toLong()) ?: ""
},
{
// TODO
Timber.tag(loggerTag.value).e(it, "reshareKey: failed to get session $sessionId|$senderKey|$roomId")
}
)
val encodedPayload = messageEncrypter.encryptMessage(payloadJson, listOf(deviceInfo)) val encodedPayload = messageEncrypter.encryptMessage(payloadJson, listOf(deviceInfo))
val sendToDeviceMap = MXUsersDevicesMap<Any>() val sendToDeviceMap = MXUsersDevicesMap<Any>()
sendToDeviceMap.setObject(userId, deviceId, encodedPayload) sendToDeviceMap.setObject(userId, deviceId, encodedPayload)
Timber.tag(loggerTag.value).i("reshareKey() : sending session $sessionId to $userId:$deviceId") Timber.tag(loggerTag.value).i("reshareKey() : sending session $groupSessionId to $userId:$deviceId")
val sendToDeviceParams = SendToDeviceTask.Params(EventType.ENCRYPTED, sendToDeviceMap) val sendToDeviceParams = SendToDeviceTask.Params(EventType.ENCRYPTED, sendToDeviceMap)
return try { return try {
sendToDeviceTask.execute(sendToDeviceParams) sendToDeviceTask.execute(sendToDeviceParams)
Timber.tag(loggerTag.value).i("reshareKey() : successfully send <$sessionId> to $userId:$deviceId") Timber.tag(loggerTag.value).i("reshareKey() : successfully send <$groupSessionId> to $userId:$deviceId")
true true
} catch (failure: Throwable) { } catch (failure: Throwable) {
Timber.tag(loggerTag.value).e(failure, "reshareKey() : fail to send <$sessionId> to $userId:$deviceId") Timber.tag(loggerTag.value).e(failure, "reshareKey() : fail to send <$groupSessionId> to $userId:$deviceId")
false false
} }
} }

View File

@ -16,6 +16,8 @@
package org.matrix.android.sdk.internal.crypto.algorithms.olm package org.matrix.android.sdk.internal.crypto.algorithms.olm
import kotlinx.coroutines.sync.withLock
import org.matrix.android.sdk.api.logger.LoggerTag
import org.matrix.android.sdk.api.session.crypto.MXCryptoError import org.matrix.android.sdk.api.session.crypto.MXCryptoError
import org.matrix.android.sdk.api.session.events.model.Event import org.matrix.android.sdk.api.session.events.model.Event
import org.matrix.android.sdk.api.session.events.model.toModel import org.matrix.android.sdk.api.session.events.model.toModel
@ -30,6 +32,7 @@ import org.matrix.android.sdk.internal.di.MoshiProvider
import org.matrix.android.sdk.internal.util.convertFromUTF8 import org.matrix.android.sdk.internal.util.convertFromUTF8
import timber.log.Timber import timber.log.Timber
private val loggerTag = LoggerTag("MXOlmDecryption", LoggerTag.CRYPTO)
internal class MXOlmDecryption( internal class MXOlmDecryption(
// The olm device interface // The olm device interface
private val olmDevice: MXOlmDevice, private val olmDevice: MXOlmDevice,
@ -38,27 +41,27 @@ internal class MXOlmDecryption(
IMXDecrypting { IMXDecrypting {
@Throws(MXCryptoError::class) @Throws(MXCryptoError::class)
override fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult { override suspend fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult {
val olmEventContent = event.content.toModel<OlmEventContent>() ?: run { val olmEventContent = event.content.toModel<OlmEventContent>() ?: run {
Timber.e("## decryptEvent() : bad event format") Timber.tag(loggerTag.value).e("## decryptEvent() : bad event format")
throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_EVENT_FORMAT, throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_EVENT_FORMAT,
MXCryptoError.BAD_EVENT_FORMAT_TEXT_REASON) MXCryptoError.BAD_EVENT_FORMAT_TEXT_REASON)
} }
val cipherText = olmEventContent.ciphertext ?: run { val cipherText = olmEventContent.ciphertext ?: run {
Timber.e("## decryptEvent() : missing cipher text") Timber.tag(loggerTag.value).e("## decryptEvent() : missing cipher text")
throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_CIPHER_TEXT, throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_CIPHER_TEXT,
MXCryptoError.MISSING_CIPHER_TEXT_REASON) MXCryptoError.MISSING_CIPHER_TEXT_REASON)
} }
val senderKey = olmEventContent.senderKey ?: run { val senderKey = olmEventContent.senderKey ?: run {
Timber.e("## decryptEvent() : missing sender key") Timber.tag(loggerTag.value).e("## decryptEvent() : missing sender key")
throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_SENDER_KEY, throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_SENDER_KEY,
MXCryptoError.MISSING_SENDER_KEY_TEXT_REASON) MXCryptoError.MISSING_SENDER_KEY_TEXT_REASON)
} }
val messageAny = cipherText[olmDevice.deviceCurve25519Key] ?: run { val messageAny = cipherText[olmDevice.deviceCurve25519Key] ?: run {
Timber.e("## decryptEvent() : our device ${olmDevice.deviceCurve25519Key} is not included in recipients") Timber.tag(loggerTag.value).e("## decryptEvent() : our device ${olmDevice.deviceCurve25519Key} is not included in recipients")
throw MXCryptoError.Base(MXCryptoError.ErrorType.NOT_INCLUDE_IN_RECIPIENTS, MXCryptoError.NOT_INCLUDED_IN_RECIPIENT_REASON) throw MXCryptoError.Base(MXCryptoError.ErrorType.NOT_INCLUDE_IN_RECIPIENTS, MXCryptoError.NOT_INCLUDED_IN_RECIPIENT_REASON)
} }
@ -69,7 +72,7 @@ internal class MXOlmDecryption(
val decryptedPayload = decryptMessage(message, senderKey) val decryptedPayload = decryptMessage(message, senderKey)
if (decryptedPayload == null) { if (decryptedPayload == null) {
Timber.e("## decryptEvent() Failed to decrypt Olm event (id= ${event.eventId} from $senderKey") Timber.tag(loggerTag.value).e("## decryptEvent() Failed to decrypt Olm event (id= ${event.eventId} from $senderKey")
throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_ENCRYPTED_MESSAGE, MXCryptoError.BAD_ENCRYPTED_MESSAGE_REASON) throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_ENCRYPTED_MESSAGE, MXCryptoError.BAD_ENCRYPTED_MESSAGE_REASON)
} }
val payloadString = convertFromUTF8(decryptedPayload) val payloadString = convertFromUTF8(decryptedPayload)
@ -78,30 +81,30 @@ internal class MXOlmDecryption(
val payload = adapter.fromJson(payloadString) val payload = adapter.fromJson(payloadString)
if (payload == null) { if (payload == null) {
Timber.e("## decryptEvent failed : null payload") Timber.tag(loggerTag.value).e("## decryptEvent failed : null payload")
throw MXCryptoError.Base(MXCryptoError.ErrorType.UNABLE_TO_DECRYPT, MXCryptoError.MISSING_CIPHER_TEXT_REASON) throw MXCryptoError.Base(MXCryptoError.ErrorType.UNABLE_TO_DECRYPT, MXCryptoError.MISSING_CIPHER_TEXT_REASON)
} }
val olmPayloadContent = OlmPayloadContent.fromJsonString(payloadString) ?: run { val olmPayloadContent = OlmPayloadContent.fromJsonString(payloadString) ?: run {
Timber.e("## decryptEvent() : bad olmPayloadContent format") Timber.tag(loggerTag.value).e("## decryptEvent() : bad olmPayloadContent format")
throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_DECRYPTED_FORMAT, MXCryptoError.BAD_DECRYPTED_FORMAT_TEXT_REASON) throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_DECRYPTED_FORMAT, MXCryptoError.BAD_DECRYPTED_FORMAT_TEXT_REASON)
} }
if (olmPayloadContent.recipient.isNullOrBlank()) { if (olmPayloadContent.recipient.isNullOrBlank()) {
val reason = String.format(MXCryptoError.ERROR_MISSING_PROPERTY_REASON, "recipient") val reason = String.format(MXCryptoError.ERROR_MISSING_PROPERTY_REASON, "recipient")
Timber.e("## decryptEvent() : $reason") Timber.tag(loggerTag.value).e("## decryptEvent() : $reason")
throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_PROPERTY, reason) throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_PROPERTY, reason)
} }
if (olmPayloadContent.recipient != userId) { if (olmPayloadContent.recipient != userId) {
Timber.e("## decryptEvent() : Event ${event.eventId}:" + Timber.tag(loggerTag.value).e("## decryptEvent() : Event ${event.eventId}:" +
" Intended recipient ${olmPayloadContent.recipient} does not match our id $userId") " Intended recipient ${olmPayloadContent.recipient} does not match our id $userId")
throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_RECIPIENT, throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_RECIPIENT,
String.format(MXCryptoError.BAD_RECIPIENT_REASON, olmPayloadContent.recipient)) String.format(MXCryptoError.BAD_RECIPIENT_REASON, olmPayloadContent.recipient))
} }
val recipientKeys = olmPayloadContent.recipientKeys ?: run { val recipientKeys = olmPayloadContent.recipientKeys ?: run {
Timber.e("## decryptEvent() : Olm event (id=${event.eventId}) contains no 'recipient_keys'" + Timber.tag(loggerTag.value).e("## decryptEvent() : Olm event (id=${event.eventId}) contains no 'recipient_keys'" +
" property; cannot prevent unknown-key attack") " property; cannot prevent unknown-key attack")
throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_PROPERTY, throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_PROPERTY,
String.format(MXCryptoError.ERROR_MISSING_PROPERTY_REASON, "recipient_keys")) String.format(MXCryptoError.ERROR_MISSING_PROPERTY_REASON, "recipient_keys"))
@ -110,31 +113,34 @@ internal class MXOlmDecryption(
val ed25519 = recipientKeys["ed25519"] val ed25519 = recipientKeys["ed25519"]
if (ed25519 != olmDevice.deviceEd25519Key) { if (ed25519 != olmDevice.deviceEd25519Key) {
Timber.e("## decryptEvent() : Event ${event.eventId}: Intended recipient ed25519 key $ed25519 did not match ours") Timber.tag(loggerTag.value).e("## decryptEvent() : Event ${event.eventId}: Intended recipient ed25519 key $ed25519 did not match ours")
throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_RECIPIENT_KEY, throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_RECIPIENT_KEY,
MXCryptoError.BAD_RECIPIENT_KEY_REASON) MXCryptoError.BAD_RECIPIENT_KEY_REASON)
} }
if (olmPayloadContent.sender.isNullOrBlank()) { if (olmPayloadContent.sender.isNullOrBlank()) {
Timber.e("## decryptEvent() : Olm event (id=${event.eventId}) contains no 'sender' property; cannot prevent unknown-key attack") Timber.tag(loggerTag.value)
.e("## decryptEvent() : Olm event (id=${event.eventId}) contains no 'sender' property; cannot prevent unknown-key attack")
throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_PROPERTY, throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_PROPERTY,
String.format(MXCryptoError.ERROR_MISSING_PROPERTY_REASON, "sender")) String.format(MXCryptoError.ERROR_MISSING_PROPERTY_REASON, "sender"))
} }
if (olmPayloadContent.sender != event.senderId) { if (olmPayloadContent.sender != event.senderId) {
Timber.e("Event ${event.eventId}: original sender ${olmPayloadContent.sender} does not match reported sender ${event.senderId}") Timber.tag(loggerTag.value)
.e("Event ${event.eventId}: sender ${olmPayloadContent.sender} does not match reported sender ${event.senderId}")
throw MXCryptoError.Base(MXCryptoError.ErrorType.FORWARDED_MESSAGE, throw MXCryptoError.Base(MXCryptoError.ErrorType.FORWARDED_MESSAGE,
String.format(MXCryptoError.FORWARDED_MESSAGE_REASON, olmPayloadContent.sender)) String.format(MXCryptoError.FORWARDED_MESSAGE_REASON, olmPayloadContent.sender))
} }
if (olmPayloadContent.roomId != event.roomId) { if (olmPayloadContent.roomId != event.roomId) {
Timber.e("## decryptEvent() : Event ${event.eventId}: original room ${olmPayloadContent.roomId} does not match reported room ${event.roomId}") Timber.tag(loggerTag.value)
.e("## decryptEvent() : Event ${event.eventId}: room ${olmPayloadContent.roomId} does not match reported room ${event.roomId}")
throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_ROOM, throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_ROOM,
String.format(MXCryptoError.BAD_ROOM_REASON, olmPayloadContent.roomId)) String.format(MXCryptoError.BAD_ROOM_REASON, olmPayloadContent.roomId))
} }
val keys = olmPayloadContent.keys ?: run { val keys = olmPayloadContent.keys ?: run {
Timber.e("## decryptEvent failed : null keys") Timber.tag(loggerTag.value).e("## decryptEvent failed : null keys")
throw MXCryptoError.Base(MXCryptoError.ErrorType.UNABLE_TO_DECRYPT, throw MXCryptoError.Base(MXCryptoError.ErrorType.UNABLE_TO_DECRYPT,
MXCryptoError.MISSING_CIPHER_TEXT_REASON) MXCryptoError.MISSING_CIPHER_TEXT_REASON)
} }
@ -153,8 +159,8 @@ internal class MXOlmDecryption(
* @param message message object, with 'type' and 'body' fields. * @param message message object, with 'type' and 'body' fields.
* @return payload, if decrypted successfully. * @return payload, if decrypted successfully.
*/ */
private fun decryptMessage(message: JsonDict, theirDeviceIdentityKey: String): String? { private suspend fun decryptMessage(message: JsonDict, theirDeviceIdentityKey: String): String? {
val sessionIds = olmDevice.getSessionIds(theirDeviceIdentityKey).orEmpty() val sessionIds = olmDevice.getSessionIds(theirDeviceIdentityKey)
val messageBody = message["body"] as? String ?: return null val messageBody = message["body"] as? String ?: return null
val messageType = when (val typeAsVoid = message["type"]) { val messageType = when (val typeAsVoid = message["type"]) {
@ -166,11 +172,32 @@ internal class MXOlmDecryption(
// Try each session in turn // Try each session in turn
// decryptionErrors = {}; // decryptionErrors = {};
val isPreKey = messageType == 0
// we want to synchronize on prekey if not we could end up create two olm sessions
// Not very clear but it looks like the js-sdk for consistency
return if (isPreKey) {
olmDevice.mutex.withLock {
reallyDecryptMessage(sessionIds, messageBody, messageType, theirDeviceIdentityKey)
}
} else {
reallyDecryptMessage(sessionIds, messageBody, messageType, theirDeviceIdentityKey)
}
}
private suspend fun reallyDecryptMessage(sessionIds: List<String>, messageBody: String, messageType: Int, theirDeviceIdentityKey: String): String? {
Timber.tag(loggerTag.value).d("decryptMessage() try to decrypt olm message type:$messageType from ${sessionIds.size} known sessions")
for (sessionId in sessionIds) { for (sessionId in sessionIds) {
val payload = olmDevice.decryptMessage(messageBody, messageType, sessionId, theirDeviceIdentityKey) val payload = try {
olmDevice.decryptMessage(messageBody, messageType, sessionId, theirDeviceIdentityKey)
} catch (throwable: Exception) {
// As we are trying one by one, we don't really care of the error here
Timber.tag(loggerTag.value).d("decryptMessage() failed with session $sessionId")
null
}
if (null != payload) { if (null != payload) {
Timber.v("## decryptMessage() : Decrypted Olm message from $theirDeviceIdentityKey with session $sessionId") Timber.tag(loggerTag.value).v("## decryptMessage() : Decrypted Olm message from $theirDeviceIdentityKey with session $sessionId")
return payload return payload
} else { } else {
val foundSession = olmDevice.matchesSession(theirDeviceIdentityKey, sessionId, messageType, messageBody) val foundSession = olmDevice.matchesSession(theirDeviceIdentityKey, sessionId, messageType, messageBody)
@ -178,7 +205,7 @@ internal class MXOlmDecryption(
if (foundSession) { if (foundSession) {
// Decryption failed, but it was a prekey message matching this // Decryption failed, but it was a prekey message matching this
// session, so it should have worked. // session, so it should have worked.
Timber.e("## decryptMessage() : Error decrypting prekey message with existing session id $sessionId:TODO") Timber.tag(loggerTag.value).e("## decryptMessage() : Error decrypting prekey message with existing session id $sessionId:TODO")
return null return null
} }
} }
@ -189,9 +216,9 @@ internal class MXOlmDecryption(
// didn't work. // didn't work.
if (sessionIds.isEmpty()) { if (sessionIds.isEmpty()) {
Timber.e("## decryptMessage() : No existing sessions") Timber.tag(loggerTag.value).e("## decryptMessage() : No existing sessions")
} else { } else {
Timber.e("## decryptMessage() : Error decrypting non-prekey message with existing sessions") Timber.tag(loggerTag.value).e("## decryptMessage() : Error decrypting non-prekey message with existing sessions")
} }
return null return null
@ -199,14 +226,17 @@ internal class MXOlmDecryption(
// prekey message which doesn't match any existing sessions: make a new // prekey message which doesn't match any existing sessions: make a new
// session. // session.
// XXXX Possible races here? if concurrent access for same prekey message, we might create 2 sessions?
Timber.tag(loggerTag.value).d("## decryptMessage() : Create inbound group session from prekey sender:$theirDeviceIdentityKey")
val res = olmDevice.createInboundSession(theirDeviceIdentityKey, messageType, messageBody) val res = olmDevice.createInboundSession(theirDeviceIdentityKey, messageType, messageBody)
if (null == res) { if (null == res) {
Timber.e("## decryptMessage() : Error decrypting non-prekey message with existing sessions") Timber.tag(loggerTag.value).e("## decryptMessage() : Error decrypting non-prekey message with existing sessions")
return null return null
} }
Timber.v("## decryptMessage() : Created new inbound Olm session get id ${res["session_id"]} with $theirDeviceIdentityKey") Timber.tag(loggerTag.value).v("## decryptMessage() : Created new inbound Olm session get id ${res["session_id"]} with $theirDeviceIdentityKey")
return res["payload"] return res["payload"]
} }

View File

@ -96,7 +96,7 @@ internal class UpdateTrustWorker(context: Context, params: WorkerParameters, ses
if (userList.isNotEmpty()) { if (userList.isNotEmpty()) {
// Unfortunately we don't have much info on what did exactly changed (is it the cross signing keys of that user, // Unfortunately we don't have much info on what did exactly changed (is it the cross signing keys of that user,
// or a new device?) So we check all again :/ // or a new device?) So we check all again :/
Timber.d("## CrossSigning - Updating trust for users: ${userList.logLimit()}") Timber.v("## CrossSigning - Updating trust for users: ${userList.logLimit()}")
updateTrust(userList) updateTrust(userList)
} }
@ -148,7 +148,7 @@ internal class UpdateTrustWorker(context: Context, params: WorkerParameters, ses
myUserId -> myTrustResult myUserId -> myTrustResult
else -> { else -> {
crossSigningService.checkOtherMSKTrusted(myCrossSigningInfo, entry.value).also { crossSigningService.checkOtherMSKTrusted(myCrossSigningInfo, entry.value).also {
Timber.d("## CrossSigning - user:${entry.key} result:$it") Timber.v("## CrossSigning - user:${entry.key} result:$it")
} }
} }
} }
@ -178,7 +178,7 @@ internal class UpdateTrustWorker(context: Context, params: WorkerParameters, ses
// Update trust if needed // Update trust if needed
devicesEntities?.forEach { device -> devicesEntities?.forEach { device ->
val crossSignedVerified = trustMap?.get(device)?.isCrossSignedVerified() val crossSignedVerified = trustMap?.get(device)?.isCrossSignedVerified()
Timber.d("## CrossSigning - Trust for ${device.userId}|${device.deviceId} : cross verified: ${trustMap?.get(device)}") Timber.v("## CrossSigning - Trust for ${device.userId}|${device.deviceId} : cross verified: ${trustMap?.get(device)}")
if (device.trustLevelEntity?.crossSignedVerified != crossSignedVerified) { if (device.trustLevelEntity?.crossSignedVerified != crossSignedVerified) {
Timber.d("## CrossSigning - Trust change detected for ${device.userId}|${device.deviceId} : cross verified: $crossSignedVerified") Timber.d("## CrossSigning - Trust change detected for ${device.userId}|${device.deviceId} : cross verified: $crossSignedVerified")
// need to save // need to save
@ -216,7 +216,7 @@ internal class UpdateTrustWorker(context: Context, params: WorkerParameters, ses
.equalTo(RoomSummaryEntityFields.IS_ENCRYPTED, true) .equalTo(RoomSummaryEntityFields.IS_ENCRYPTED, true)
.findFirst() .findFirst()
?.let { roomSummary -> ?.let { roomSummary ->
Timber.d("## CrossSigning - Check shield state for room $roomId") Timber.v("## CrossSigning - Check shield state for room $roomId")
val allActiveRoomMembers = RoomMemberHelper(sessionRealm, roomId).getActiveRoomMemberIds() val allActiveRoomMembers = RoomMemberHelper(sessionRealm, roomId).getActiveRoomMemberIds()
try { try {
val updatedTrust = computeRoomShield( val updatedTrust = computeRoomShield(
@ -277,7 +277,7 @@ internal class UpdateTrustWorker(context: Context, params: WorkerParameters, ses
cryptoRealm: Realm, cryptoRealm: Realm,
activeMemberUserIds: List<String>, activeMemberUserIds: List<String>,
roomSummaryEntity: RoomSummaryEntity): RoomEncryptionTrustLevel { roomSummaryEntity: RoomSummaryEntity): RoomEncryptionTrustLevel {
Timber.d("## CrossSigning - computeRoomShield ${roomSummaryEntity.roomId} -> ${activeMemberUserIds.logLimit()}") Timber.v("## CrossSigning - computeRoomShield ${roomSummaryEntity.roomId} -> ${activeMemberUserIds.logLimit()}")
// The set of “all users” depends on the type of room: // The set of “all users” depends on the type of room:
// For regular / topic rooms which have more than 2 members (including yourself) are considered when decorating a room // For regular / topic rooms which have more than 2 members (including yourself) are considered when decorating a room
// For 1:1 and group DM rooms, all other users (i.e. excluding yourself) are considered when decorating a room // For 1:1 and group DM rooms, all other users (i.e. excluding yourself) are considered when decorating a room

View File

@ -671,7 +671,6 @@ internal class DefaultKeysBackupService @Inject constructor(
Timber.e("restoreKeysWithRecoveryKey: Invalid recovery key for this keys version") Timber.e("restoreKeysWithRecoveryKey: Invalid recovery key for this keys version")
throw InvalidParameterException("Invalid recovery key") throw InvalidParameterException("Invalid recovery key")
} }
// Get a PK decryption instance // Get a PK decryption instance
pkDecryptionFromRecoveryKey(recoveryKey) pkDecryptionFromRecoveryKey(recoveryKey)
} }
@ -681,6 +680,10 @@ internal class DefaultKeysBackupService @Inject constructor(
throw InvalidParameterException("Invalid recovery key") throw InvalidParameterException("Invalid recovery key")
} }
// Save for next time and for gossiping
// Save now as it's valid, don't wait for the import as it could take long.
saveBackupRecoveryKey(recoveryKey, keysVersionResult.version)
stepProgressListener?.onStepProgress(StepProgressListener.Step.DownloadingKey) stepProgressListener?.onStepProgress(StepProgressListener.Step.DownloadingKey)
// Get backed up keys from the homeserver // Get backed up keys from the homeserver
@ -729,8 +732,6 @@ internal class DefaultKeysBackupService @Inject constructor(
if (backUp) { if (backUp) {
maybeBackupKeys() maybeBackupKeys()
} }
// Save for next time and for gossiping
saveBackupRecoveryKey(recoveryKey, keysVersionResult.version)
result result
} }
}.foldToCallback(callback) }.foldToCallback(callback)

View File

@ -70,6 +70,8 @@ data class CryptoDeviceInfo(
keys?.let { map["keys"] = it } keys?.let { map["keys"] = it }
return map return map
} }
fun shortDebugString() = "$userId|$deviceId"
} }
internal fun CryptoDeviceInfo.toRest(): DeviceKeys { internal fun CryptoDeviceInfo.toRest(): DeviceKeys {

View File

@ -16,6 +16,7 @@
package org.matrix.android.sdk.internal.crypto.model package org.matrix.android.sdk.internal.crypto.model
import kotlinx.coroutines.sync.Mutex
import org.matrix.olm.OlmSession import org.matrix.olm.OlmSession
/** /**
@ -25,7 +26,10 @@ data class OlmSessionWrapper(
// The associated olm session. // The associated olm session.
val olmSession: OlmSession, val olmSession: OlmSession,
// Timestamp at which the session last received a message. // Timestamp at which the session last received a message.
var lastReceivedMessageTs: Long = 0) { var lastReceivedMessageTs: Long = 0,
val mutex: Mutex = Mutex()
) {
/** /**
* Notify that a message has been received on this olm session so that it updates `lastReceivedMessageTs` * Notify that a message has been received on this olm session so that it updates `lastReceivedMessageTs`

View File

@ -54,7 +54,7 @@ internal interface IMXCryptoStore {
/** /**
* @return the olm account * @return the olm account
*/ */
fun getOlmAccount(): OlmAccount fun <T> doWithOlmAccount(block: (OlmAccount) -> T): T
fun getOrCreateOlmAccount(): OlmAccount fun getOrCreateOlmAccount(): OlmAccount
@ -261,7 +261,7 @@ internal interface IMXCryptoStore {
fun storeSession(olmSessionWrapper: OlmSessionWrapper, deviceKey: String) fun storeSession(olmSessionWrapper: OlmSessionWrapper, deviceKey: String)
/** /**
* Retrieve the end-to-end session ids between the logged-in user and another * Retrieve all end-to-end session ids between our own device and another
* device. * device.
* *
* @param deviceKey the public key of the other device. * @param deviceKey the public key of the other device.
@ -270,7 +270,7 @@ internal interface IMXCryptoStore {
fun getDeviceSessionIds(deviceKey: String): List<String>? fun getDeviceSessionIds(deviceKey: String): List<String>?
/** /**
* Retrieve an end-to-end session between the logged-in user and another * Retrieve an end-to-end session between our own device and another
* device. * device.
* *
* @param sessionId the session Id. * @param sessionId the session Id.

View File

@ -104,7 +104,6 @@ import timber.log.Timber
import java.util.concurrent.Executors import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import javax.inject.Inject import javax.inject.Inject
import kotlin.collections.set
@SessionScope @SessionScope
internal class RealmCryptoStore @Inject constructor( internal class RealmCryptoStore @Inject constructor(
@ -124,12 +123,6 @@ internal class RealmCryptoStore @Inject constructor(
// The olm account // The olm account
private var olmAccount: OlmAccount? = null private var olmAccount: OlmAccount? = null
// Cache for OlmSession, to release them properly
private val olmSessionsToRelease = HashMap<String, OlmSessionWrapper>()
// Cache for InboundGroupSession, to release them properly
private val inboundGroupSessionToRelease = HashMap<String, OlmInboundGroupSessionWrapper2>()
private val newSessionListeners = ArrayList<NewSessionListener>() private val newSessionListeners = ArrayList<NewSessionListener>()
override fun addNewSessionListener(listener: NewSessionListener) { override fun addNewSessionListener(listener: NewSessionListener) {
@ -213,16 +206,6 @@ internal class RealmCryptoStore @Inject constructor(
monarchyWriteAsyncExecutor.awaitTermination(1, TimeUnit.MINUTES) monarchyWriteAsyncExecutor.awaitTermination(1, TimeUnit.MINUTES)
} }
olmSessionsToRelease.forEach {
it.value.olmSession.releaseSession()
}
olmSessionsToRelease.clear()
inboundGroupSessionToRelease.forEach {
it.value.olmInboundGroupSession?.releaseSession()
}
inboundGroupSessionToRelease.clear()
olmAccount?.releaseAccount() olmAccount?.releaseAccount()
realmLocker?.close() realmLocker?.close()
@ -247,10 +230,18 @@ internal class RealmCryptoStore @Inject constructor(
} }
} }
override fun getOlmAccount(): OlmAccount { /**
return olmAccount!! * Olm account access should be synchronized
*/
override fun <T> doWithOlmAccount(block: (OlmAccount) -> T): T {
return olmAccount!!.let { olmAccount ->
synchronized(olmAccount) {
block.invoke(olmAccount)
}
}
} }
@Synchronized
override fun getOrCreateOlmAccount(): OlmAccount { override fun getOrCreateOlmAccount(): OlmAccount {
doRealmTransaction(realmConfiguration) { doRealmTransaction(realmConfiguration) {
val metaData = it.where<CryptoMetadataEntity>().findFirst() val metaData = it.where<CryptoMetadataEntity>().findFirst()
@ -680,13 +671,6 @@ internal class RealmCryptoStore @Inject constructor(
if (sessionIdentifier != null) { if (sessionIdentifier != null) {
val key = OlmSessionEntity.createPrimaryKey(sessionIdentifier, deviceKey) val key = OlmSessionEntity.createPrimaryKey(sessionIdentifier, deviceKey)
// Release memory of previously known session, if it is not the same one
if (olmSessionsToRelease[key]?.olmSession != olmSessionWrapper.olmSession) {
olmSessionsToRelease[key]?.olmSession?.releaseSession()
}
olmSessionsToRelease[key] = olmSessionWrapper
doRealmTransaction(realmConfiguration) { doRealmTransaction(realmConfiguration) {
val realmOlmSession = OlmSessionEntity().apply { val realmOlmSession = OlmSessionEntity().apply {
primaryKey = key primaryKey = key
@ -703,23 +687,18 @@ internal class RealmCryptoStore @Inject constructor(
override fun getDeviceSession(sessionId: String, deviceKey: String): OlmSessionWrapper? { override fun getDeviceSession(sessionId: String, deviceKey: String): OlmSessionWrapper? {
val key = OlmSessionEntity.createPrimaryKey(sessionId, deviceKey) val key = OlmSessionEntity.createPrimaryKey(sessionId, deviceKey)
return doRealmQueryAndCopy(realmConfiguration) {
// If not in cache (or not found), try to read it from realm it.where<OlmSessionEntity>()
if (olmSessionsToRelease[key] == null) { .equalTo(OlmSessionEntityFields.PRIMARY_KEY, key)
doRealmQueryAndCopy(realmConfiguration) { .findFirst()
it.where<OlmSessionEntity>()
.equalTo(OlmSessionEntityFields.PRIMARY_KEY, key)
.findFirst()
}
?.let {
val olmSession = it.getOlmSession()
if (olmSession != null && it.sessionId != null) {
olmSessionsToRelease[key] = OlmSessionWrapper(olmSession, it.lastReceivedMessageTs)
}
}
} }
?.let {
return olmSessionsToRelease[key] val olmSession = it.getOlmSession()
if (olmSession != null && it.sessionId != null) {
return@let OlmSessionWrapper(olmSession, it.lastReceivedMessageTs)
}
null
}
} }
override fun getLastUsedSessionId(deviceKey: String): String? { override fun getLastUsedSessionId(deviceKey: String): String? {
@ -761,13 +740,6 @@ internal class RealmCryptoStore @Inject constructor(
if (sessionIdentifier != null) { if (sessionIdentifier != null) {
val key = OlmInboundGroupSessionEntity.createPrimaryKey(sessionIdentifier, session.senderKey) val key = OlmInboundGroupSessionEntity.createPrimaryKey(sessionIdentifier, session.senderKey)
// Release memory of previously known session, if it is not the same one
if (inboundGroupSessionToRelease[key] != session) {
inboundGroupSessionToRelease[key]?.olmInboundGroupSession?.releaseSession()
}
inboundGroupSessionToRelease[key] = session
val realmOlmInboundGroupSession = OlmInboundGroupSessionEntity().apply { val realmOlmInboundGroupSession = OlmInboundGroupSessionEntity().apply {
primaryKey = key primaryKey = key
sessionId = sessionIdentifier sessionId = sessionIdentifier
@ -784,20 +756,12 @@ internal class RealmCryptoStore @Inject constructor(
override fun getInboundGroupSession(sessionId: String, senderKey: String): OlmInboundGroupSessionWrapper2? { override fun getInboundGroupSession(sessionId: String, senderKey: String): OlmInboundGroupSessionWrapper2? {
val key = OlmInboundGroupSessionEntity.createPrimaryKey(sessionId, senderKey) val key = OlmInboundGroupSessionEntity.createPrimaryKey(sessionId, senderKey)
// If not in cache (or not found), try to read it from realm return doWithRealm(realmConfiguration) {
if (inboundGroupSessionToRelease[key] == null) { it.where<OlmInboundGroupSessionEntity>()
doWithRealm(realmConfiguration) { .equalTo(OlmInboundGroupSessionEntityFields.PRIMARY_KEY, key)
it.where<OlmInboundGroupSessionEntity>() .findFirst()
.equalTo(OlmInboundGroupSessionEntityFields.PRIMARY_KEY, key) ?.getInboundGroupSession()
.findFirst()
?.getInboundGroupSession()
}
?.let {
inboundGroupSessionToRelease[key] = it
}
} }
return inboundGroupSessionToRelease[key]
} }
override fun getCurrentOutboundGroupSessionForRoom(roomId: String): OutboundGroupSessionWrapper? { override fun getCurrentOutboundGroupSessionForRoom(roomId: String): OutboundGroupSessionWrapper? {
@ -853,10 +817,6 @@ internal class RealmCryptoStore @Inject constructor(
override fun removeInboundGroupSession(sessionId: String, senderKey: String) { override fun removeInboundGroupSession(sessionId: String, senderKey: String) {
val key = OlmInboundGroupSessionEntity.createPrimaryKey(sessionId, senderKey) val key = OlmInboundGroupSessionEntity.createPrimaryKey(sessionId, senderKey)
// Release memory of previously known session
inboundGroupSessionToRelease[key]?.olmInboundGroupSession?.releaseSession()
inboundGroupSessionToRelease.remove(key)
doRealmTransaction(realmConfiguration) { doRealmTransaction(realmConfiguration) {
it.where<OlmInboundGroupSessionEntity>() it.where<OlmInboundGroupSessionEntity>()
.equalTo(OlmInboundGroupSessionEntityFields.PRIMARY_KEY, key) .equalTo(OlmInboundGroupSessionEntityFields.PRIMARY_KEY, key)

View File

@ -156,7 +156,7 @@ internal class DefaultFetchThreadTimelineTask @Inject constructor(
* Invoke the event decryption mechanism for a specific event * Invoke the event decryption mechanism for a specific event
*/ */
private fun decryptIfNeeded(event: Event, roomId: String) { private suspend fun decryptIfNeeded(event: Event, roomId: String) {
try { try {
// Event from sync does not have roomId, so add it to the event first // Event from sync does not have roomId, so add it to the event first
val result = cryptoService.decryptEvent(event.copy(roomId = roomId), "") val result = cryptoService.decryptEvent(event.copy(roomId = roomId), "")

View File

@ -18,6 +18,7 @@ package org.matrix.android.sdk.internal.session.room.summary
import io.realm.Realm import io.realm.Realm
import io.realm.kotlin.createObject import io.realm.kotlin.createObject
import kotlinx.coroutines.runBlocking
import org.matrix.android.sdk.api.extensions.orFalse import org.matrix.android.sdk.api.extensions.orFalse
import org.matrix.android.sdk.api.extensions.tryOrNull import org.matrix.android.sdk.api.extensions.tryOrNull
import org.matrix.android.sdk.api.session.events.model.EventType import org.matrix.android.sdk.api.session.events.model.EventType
@ -165,7 +166,9 @@ internal class RoomSummaryUpdater @Inject constructor(
Timber.v("Should decrypt ${latestPreviewableEvent.eventId}") Timber.v("Should decrypt ${latestPreviewableEvent.eventId}")
// mmm i want to decrypt now or is it ok to do it async? // mmm i want to decrypt now or is it ok to do it async?
tryOrNull { tryOrNull {
eventDecryptor.decryptEvent(root.asDomain(), "") runBlocking {
eventDecryptor.decryptEvent(root.asDomain(), "")
}
} }
?.let { root.setDecryptionResult(it) } ?.let { root.setDecryptionResult(it) }
} }

View File

@ -17,6 +17,7 @@ package org.matrix.android.sdk.internal.session.room.timeline
import io.realm.Realm import io.realm.Realm
import io.realm.RealmConfiguration import io.realm.RealmConfiguration
import kotlinx.coroutines.runBlocking
import org.matrix.android.sdk.api.session.crypto.CryptoService import org.matrix.android.sdk.api.session.crypto.CryptoService
import org.matrix.android.sdk.api.session.crypto.MXCryptoError import org.matrix.android.sdk.api.session.crypto.MXCryptoError
import org.matrix.android.sdk.api.session.events.model.Event import org.matrix.android.sdk.api.session.events.model.Event
@ -99,7 +100,9 @@ internal class TimelineEventDecryptor @Inject constructor(
} }
executor?.execute { executor?.execute {
Realm.getInstance(realmConfiguration).use { realm -> Realm.getInstance(realmConfiguration).use { realm ->
processDecryptRequest(request, realm) runBlocking {
processDecryptRequest(request, realm)
}
} }
} }
} }
@ -115,7 +118,7 @@ internal class TimelineEventDecryptor @Inject constructor(
threadsAwarenessHandler.makeEventThreadAware(realm, event.roomId, decryptedEvent, eventEntity) threadsAwarenessHandler.makeEventThreadAware(realm, event.roomId, decryptedEvent, eventEntity)
} }
} }
private fun processDecryptRequest(request: DecryptionRequest, realm: Realm) { private suspend fun processDecryptRequest(request: DecryptionRequest, realm: Realm) {
val event = request.event val event = request.event
val timelineId = request.timelineId val timelineId = request.timelineId

View File

@ -110,6 +110,7 @@ internal class SyncResponseHandler @Inject constructor(
// Start one big transaction // Start one big transaction
monarchy.awaitTransaction { realm -> monarchy.awaitTransaction { realm ->
// IMPORTANT nothing should be suspend here as we are accessing the realm instance (thread local)
measureTimeMillis { measureTimeMillis {
Timber.v("Handle rooms") Timber.v("Handle rooms")
reportSubtask(reporter, InitSyncStep.ImportingAccountRoom, 1, 0.7f) { reportSubtask(reporter, InitSyncStep.ImportingAccountRoom, 1, 0.7f) {

View File

@ -38,7 +38,7 @@ private val loggerTag = LoggerTag("CryptoSyncHandler", LoggerTag.CRYPTO)
internal class CryptoSyncHandler @Inject constructor(private val cryptoService: DefaultCryptoService, internal class CryptoSyncHandler @Inject constructor(private val cryptoService: DefaultCryptoService,
private val verificationService: DefaultVerificationService) { private val verificationService: DefaultVerificationService) {
fun handleToDevice(toDevice: ToDeviceSyncResponse, progressReporter: ProgressReporter? = null) { suspend fun handleToDevice(toDevice: ToDeviceSyncResponse, progressReporter: ProgressReporter? = null) {
val total = toDevice.events?.size ?: 0 val total = toDevice.events?.size ?: 0
toDevice.events?.forEachIndexed { index, event -> toDevice.events?.forEachIndexed { index, event ->
progressReporter?.reportProgress(index * 100F / total) progressReporter?.reportProgress(index * 100F / total)
@ -66,7 +66,7 @@ internal class CryptoSyncHandler @Inject constructor(private val cryptoService:
* @param timelineId the timeline identifier * @param timelineId the timeline identifier
* @return true if the event has been decrypted * @return true if the event has been decrypted
*/ */
private fun decryptToDeviceEvent(event: Event, timelineId: String?): Boolean { private suspend fun decryptToDeviceEvent(event: Event, timelineId: String?): Boolean {
Timber.v("## CRYPTO | decryptToDeviceEvent") Timber.v("## CRYPTO | decryptToDeviceEvent")
if (event.getClearType() == EventType.ENCRYPTED) { if (event.getClearType() == EventType.ENCRYPTED) {
var result: MXEventDecryptionResult? = null var result: MXEventDecryptionResult? = null
@ -80,6 +80,8 @@ internal class CryptoSyncHandler @Inject constructor(private val cryptoService:
it.identityKey() == senderKey it.identityKey() == senderKey
}?.deviceId ?: senderKey }?.deviceId ?: senderKey
Timber.e("## CRYPTO | Failed to decrypt to device event from ${event.senderId}|$deviceId reason:<${event.mCryptoError ?: exception}>") Timber.e("## CRYPTO | Failed to decrypt to device event from ${event.senderId}|$deviceId reason:<${event.mCryptoError ?: exception}>")
} catch (failure: Throwable) {
Timber.e(failure, "## CRYPTO | Failed to decrypt to device event from ${event.senderId}")
} }
if (null != result) { if (null != result) {
@ -91,7 +93,9 @@ internal class CryptoSyncHandler @Inject constructor(private val cryptoService:
) )
return true return true
} else { } else {
// should not happen // Could happen for to device events
// None of the known session could decrypt the message
// In this case unwedging process might have been started (rate limited)
Timber.e("## CRYPTO | ERROR NULL DECRYPTION RESULT from ${event.senderId}") Timber.e("## CRYPTO | ERROR NULL DECRYPTION RESULT from ${event.senderId}")
} }
} }

View File

@ -19,6 +19,7 @@ package org.matrix.android.sdk.internal.session.sync.handler.room
import dagger.Lazy import dagger.Lazy
import io.realm.Realm import io.realm.Realm
import io.realm.kotlin.createObject import io.realm.kotlin.createObject
import kotlinx.coroutines.runBlocking
import org.matrix.android.sdk.api.session.crypto.MXCryptoError import org.matrix.android.sdk.api.session.crypto.MXCryptoError
import org.matrix.android.sdk.api.session.events.model.Event import org.matrix.android.sdk.api.session.events.model.Event
import org.matrix.android.sdk.api.session.events.model.EventType import org.matrix.android.sdk.api.session.events.model.EventType
@ -379,7 +380,9 @@ internal class RoomSyncHandler @Inject constructor(private val readReceiptHandle
val isInitialSync = insertType == EventInsertType.INITIAL_SYNC val isInitialSync = insertType == EventInsertType.INITIAL_SYNC
if (event.isEncrypted() && !isInitialSync) { if (event.isEncrypted() && !isInitialSync) {
decryptIfNeeded(event, roomId) runBlocking {
decryptIfNeeded(event, roomId)
}
} }
var contentToInject: String? = null var contentToInject: String? = null
if (!isInitialSync) { if (!isInitialSync) {
@ -455,7 +458,7 @@ internal class RoomSyncHandler @Inject constructor(private val readReceiptHandle
return chunkEntity return chunkEntity
} }
private fun decryptIfNeeded(event: Event, roomId: String) { private suspend fun decryptIfNeeded(event: Event, roomId: String) {
try { try {
// Event from sync does not have roomId, so add it to the event first // Event from sync does not have roomId, so add it to the event first
val result = cryptoService.decryptEvent(event.copy(roomId = roomId), "") val result = cryptoService.decryptEvent(event.copy(roomId = roomId), "")

View File

@ -190,7 +190,7 @@ class NotifiableEventResolver @Inject constructor(
} }
} }
private fun TimelineEvent.attemptToDecryptIfNeeded(session: Session) { private suspend fun TimelineEvent.attemptToDecryptIfNeeded(session: Session) {
if (root.isEncrypted() && root.mxDecryptionResult == null) { if (root.isEncrypted() && root.mxDecryptionResult == null) {
// TODO use a global event decryptor? attache to session and that listen to new sessionId? // TODO use a global event decryptor? attache to session and that listen to new sessionId?
// for now decrypt sync // for now decrypt sync