crypto: Introduce some locks for some of our e2ee operations

This commit is contained in:
Damir Jelić 2021-04-09 12:42:22 +02:00
parent 8692f05e34
commit 74a1c226a4

View File

@ -24,6 +24,7 @@ import com.squareup.moshi.Types
import dagger.Lazy import dagger.Lazy
import java.io.File import java.io.File
import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.ConcurrentHashMap
import javax.inject.Inject import javax.inject.Inject
import kotlin.jvm.Throws import kotlin.jvm.Throws
import kotlin.math.max import kotlin.math.max
@ -32,6 +33,8 @@ import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.cancelChildren import kotlinx.coroutines.cancelChildren
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
import org.matrix.android.sdk.api.MatrixCallback import org.matrix.android.sdk.api.MatrixCallback
import org.matrix.android.sdk.api.NoOpMatrixCallback import org.matrix.android.sdk.api.NoOpMatrixCallback
@ -152,6 +155,11 @@ internal class DefaultCryptoService @Inject constructor(
private var olmMachine: OlmMachine? = null private var olmMachine: OlmMachine? = null
private val deviceObserver: DeviceUpdateObserver = DeviceUpdateObserver() private val deviceObserver: DeviceUpdateObserver = DeviceUpdateObserver()
// Locks for some of our operations
private val keyClaimLock: Mutex = Mutex()
private val outgointRequestsLock: Mutex = Mutex()
private val roomKeyShareLocks: ConcurrentHashMap<String, Mutex> = ConcurrentHashMap()
suspend fun onStateEvent(roomId: String, event: Event) { suspend fun onStateEvent(roomId: String, event: Event) {
when (event.getClearType()) { when (event.getClearType()) {
EventType.STATE_ROOM_ENCRYPTION -> onRoomEncryptionEvent(roomId, event) EventType.STATE_ROOM_ENCRYPTION -> onRoomEncryptionEvent(roomId, event)
@ -650,23 +658,28 @@ internal class DefaultCryptoService @Inject constructor(
} }
private suspend fun preshareGroupSession(roomId: String, roomMembers: List<String>) { private suspend fun preshareGroupSession(roomId: String, roomMembers: List<String>) {
// TODO this needs to be locked per room keyClaimLock.withLock {
val request = olmMachine!!.getMissingSessions(roomMembers) val request = olmMachine!!.getMissingSessions(roomMembers)
if (request != null) {
if (request != null) { // This request can only be a keys claim request.
// This request can only be a keys claim request. when (request) {
when (request) { is Request.KeysClaim -> {
is Request.KeysClaim -> { claimKeys(request)
claimKeys(request) }
} }
} }
} }
for (toDeviceRequest in olmMachine!!.shareGroupSession(roomId, roomMembers)) { val keyShareLock = roomKeyShareLocks.getOrDefault(roomId, Mutex())
// This request can only be a to-device request.
when (toDeviceRequest) { keyShareLock.withLock {
is Request.ToDevice -> { for (toDeviceRequest in olmMachine!!.shareGroupSession(roomId, roomMembers)) {
sendToDevice(toDeviceRequest) // TODO these requests should be sent out in parallel
// This request can only be a to-device request.
when (toDeviceRequest) {
is Request.ToDevice -> {
sendToDevice(toDeviceRequest)
}
} }
} }
} }
@ -699,7 +712,6 @@ internal class DefaultCryptoService @Inject constructor(
} }
private suspend fun queryKeys(outgoingRequest: Request.KeysQuery) { private suspend fun queryKeys(outgoingRequest: Request.KeysQuery) {
Timber.v("HELLO KEYS QUERY REQUEST ${outgoingRequest.users}")
val params = DownloadKeysForUsersTask.Params(outgoingRequest.users, null) val params = DownloadKeysForUsersTask.Params(outgoingRequest.users, null)
try { try {
@ -729,7 +741,6 @@ internal class DefaultCryptoService @Inject constructor(
} }
private suspend fun claimKeys(request: Request.KeysClaim) { private suspend fun claimKeys(request: Request.KeysClaim) {
// TODO this needs to be locked per call
val claimParams = ClaimOneTimeKeysForUsersDeviceTask.Params(request.oneTimeKeys) val claimParams = ClaimOneTimeKeysForUsersDeviceTask.Params(request.oneTimeKeys)
val response = oneTimeKeysForUsersDeviceTask.execute(claimParams) val response = oneTimeKeysForUsersDeviceTask.execute(claimParams)
val adapter = MoshiProvider val adapter = MoshiProvider
@ -741,18 +752,19 @@ internal class DefaultCryptoService @Inject constructor(
} }
private suspend fun sendOutgoingRequests() { private suspend fun sendOutgoingRequests() {
// TODO this needs to be locked per call outgointRequestsLock.withLock {
// TODO these requests should be sent out in parallel // TODO these requests should be sent out in parallel
for (outgoingRequest in olmMachine!!.outgoingRequests()) { for (outgoingRequest in olmMachine!!.outgoingRequests()) {
when (outgoingRequest) { when (outgoingRequest) {
is Request.KeysUpload -> { is Request.KeysUpload -> {
uploadKeys(outgoingRequest) uploadKeys(outgoingRequest)
} }
is Request.KeysQuery -> { is Request.KeysQuery -> {
queryKeys(outgoingRequest) queryKeys(outgoingRequest)
} }
is Request.ToDevice -> { is Request.ToDevice -> {
// Timber.v("HELLO TO DEVICE REQUEST ${outgoingRequest.body}") // Timber.v("HELLO TO DEVICE REQUEST ${outgoingRequest.body}")
}
} }
} }
} }