Speed up Barlow Twins loss (#2519)
commit 1ccd03fec9 (parent 50b78da53a)
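The speedup comes from reusing the loss function's intermediate tensors between calls instead of allocating them on every invocation. A minimal sketch of that pattern, assuming only dlib's resizable_tensor, have_same_dimensions, and copy_size; the struct and member names here are illustrative, not taken from the patch:

#include <dlib/dnn.h>

// Work buffers live in mutable members, so a const compute function can size
// them on the first call and then reuse the same allocations afterwards.
struct cached_workspace_example
{
    double compute(const dlib::tensor& input) const
    {
        // Only (re)allocate when the input shape actually changes.
        if (!dlib::have_same_dimensions(scratch, input))
            scratch.copy_size(input);
        // ... fill scratch and reduce it to a scalar ...
        return 0;
    }

private:
    mutable dlib::resizable_tensor scratch;  // reused across calls
};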
@@ -4016,9 +4016,6 @@ namespace dlib

     // Normalize both batches independently across the batch dimension
     const double eps = 1e-4;
-    resizable_tensor za_norm, means_a, invstds_a;
-    resizable_tensor zb_norm, means_b, invstds_b;
-    resizable_tensor rms, rvs, g, b;
     g.set_size(1, sample_size);
     g = 1;
     b.set_size(1, sample_size);
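For reference, "normalize across the batch dimension" means each embedding dimension (each column of the n x k batch) is standardized to zero mean and unit variance, which is what the tt::batch_normalize calls below do on tensors. A plain dlib::matrix sketch of the same operation, illustrative only and not the code path used here:

#include <dlib/matrix.h>
#include <cmath>

// Standardize every column of an n x k matrix: subtract the column mean and
// divide by the column standard deviation (with a small eps for stability).
inline dlib::matrix<float> normalize_columns(const dlib::matrix<float>& z, float eps = 1e-4f)
{
    dlib::matrix<float> out = z;
    for (long j = 0; j < z.nc(); ++j)
    {
        const dlib::matrix<float> col = dlib::colm(z, j);
        const float m = dlib::mean(col);
        const float v = dlib::mean(dlib::squared(col - m));  // biased variance, as in batch norm
        dlib::set_colm(out, j) = (col - m) / std::sqrt(v + eps);
    }
    return out;
}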
@@ -4027,21 +4024,29 @@ namespace dlib
     tt::batch_normalize(eps, zb_norm, means_b, invstds_b, 1, rms, rvs, zb, g, b);

     // Compute the empirical cross-correlation matrix
-    resizable_tensor eccm;
     eccm.set_size(sample_size, sample_size);
     tt::gemm(0, eccm, 1, za_norm, true, zb_norm, false);
     eccm /= batch_size;

-    // Compute the loss: MSE between eccm and the identity matrix.
-    // Off-diagonal terms are weighed by lambda.
-    const matrix<float> C = mat(eccm);
-    const double diagonal_loss = sum(squared(diag(C) - 1));
-    const double off_diag_loss = sum(squared(C - diagm(diag(C))));
-    double loss = diagonal_loss + lambda * off_diag_loss;
+    // Set sizes and setup auxiliary tensors
+    if (!have_same_dimensions(eccm, identity))
+        identity = identity_matrix<float>(sample_size);
+    if (!have_same_dimensions(eccm, cdiag))
+        cdiag.copy_size(eccm);
+    if (!have_same_dimensions(eccm, cdiag_1))
+        cdiag_1.copy_size(eccm);
+    if (!have_same_dimensions(eccm, off_mask))
+        off_mask = ones_matrix<float>(sample_size, sample_size) - identity_matrix<float>(sample_size);
+    if (!have_same_dimensions(eccm, off_diag))
+        off_diag.copy_size(eccm);
+    if (!have_same_dimensions(grad, grad_input))
+        grad_input.copy_size(grad);
+    if (!have_same_dimensions(g_grad, g))
+        g_grad.copy_size(g);
+    if (!have_same_dimensions(b_grad, b))
+        b_grad.copy_size(b);

     // Loss gradient, which will be used as the input of the batch normalization gradient
-    resizable_tensor grad_input;
-    grad_input.copy_size(grad);
     auto grad_input_a = split(grad_input);
     auto grad_input_b = split(grad_input, offset);

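The tt::gemm call and the division by batch_size above build the empirical cross-correlation matrix between the two normalized batches. The same quantity with plain dlib matrices, as a sketch (za and zb stand for the two n x k normalized embedding batches; this is not the tensor code path used by the loss):

#include <dlib/matrix.h>

// Empirical cross-correlation matrix: C = za' * zb / n, a k x k matrix whose
// diagonal measures agreement between the two views and whose off-diagonal
// entries measure redundancy between embedding dimensions.
inline dlib::matrix<float> cross_correlation(const dlib::matrix<float>& za, const dlib::matrix<float>& zb)
{
    return dlib::trans(za) * zb / static_cast<float>(za.nr());
}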
@@ -4051,11 +4056,15 @@ namespace dlib
     // C = eccm
     // D = off_mask: a mask that keeps only the elements outside the diagonal

+    // A diagonal matrix containing the diagonal of eccm
+    tt::multiply(false, cdiag, eccm, identity);
+    // The diagonal of eccm minus the identity matrix
+    tt::affine_transform(cdiag_1, cdiag, identity, 1, -1);
+
     // diagonal term: sum((diag(A' * B) - vector(1)).^2)
     // --------------------------------------------
     // => d/dA = 2 * B * diag(diag(A' * B) - vector(1)) = 2 * B * diag(diag(C) - vector(1))
     // => d/dB = 2 * A * diag(diag(A' * B) - vector(1)) = 2 * A * diag(diag(C) - vector(1))
-    resizable_tensor cdiag_1(diagm(diag(mat(eccm) - 1)));
     tt::gemm(0, grad_input_a, 2, zb_norm, false, cdiag_1, false);
     tt::gemm(0, grad_input_b, 2, za_norm, false, cdiag_1, false);

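The step behind the d/dA comment above, written out in the notation of the comments (A and B are the normalized embedding batches and C = A'B):

    \[
      L_{\mathrm{diag}} = \sum_i (C_{ii} - 1)^2, \qquad C_{ii} = \sum_k A_{ki} B_{ki}
      \;\Longrightarrow\;
      \frac{\partial L_{\mathrm{diag}}}{\partial A_{kj}} = 2\,(C_{jj} - 1)\,B_{kj},
      \quad\text{i.e.}\quad
      \frac{\partial L_{\mathrm{diag}}}{\partial A} = 2\,B\,\operatorname{diag}\!\big(\operatorname{diag}(C) - \mathbf{1}\big).
    \]

The d/dB case is symmetric with A and B swapped, which is what the pair of tt::gemm calls above computes, using cdiag_1 as the diagonal factor.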
@@ -4063,22 +4072,21 @@ namespace dlib
     // --------------------------------
     // => d/dA = 2 * B * ((B' * A) .* (D .* D)') = 2 * B * (C .* (D .* D)) = 2 * B * (C .* D)
     // => d/dB = 2 * A * ((A' * B) .* (D .* D)) = 2 * A * (C .* (D .* D)) = 2 * A * (C .* D)
-    resizable_tensor off_mask(ones_matrix<float>(sample_size, sample_size) - identity_matrix<float>(sample_size));
-    resizable_tensor off_diag(sample_size, sample_size);
     tt::multiply(false, off_diag, eccm, off_mask);
     tt::gemm(1, grad_input_a, 2 * lambda, zb_norm, false, off_diag, false);
     tt::gemm(1, grad_input_b, 2 * lambda, za_norm, false, off_diag, false);

     // Compute the batch norm gradients, g and b grads are not used
-    resizable_tensor g_grad, b_grad;
-    g_grad.copy_size(g);
-    b_grad.copy_size(b);
     auto gza = split(grad);
     auto gzb = split(grad, offset);
     tt::batch_normalize_gradient(eps, grad_input_a, means_a, invstds_a, za, g, gza, g_grad, b_grad);
     tt::batch_normalize_gradient(eps, grad_input_b, means_b, invstds_b, zb, g, gzb, g_grad, b_grad);

-    return loss;
+    // Compute the loss: MSE between eccm and the identity matrix.
+    // Off-diagonal terms are weighed by lambda.
+    const double diagonal_loss = sum(squared(mat(cdiag_1)));
+    const double off_diag_loss = sum(squared(mat(off_diag)));
+    return diagonal_loss + lambda * off_diag_loss;
 }

 float get_lambda() const { return lambda; }
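The returned value is now assembled from the cdiag_1 and off_diag tensors that the gradient computation already needed, instead of rebuilding it from mat(eccm). As a cross-check, the same loss written directly against the cross-correlation matrix with plain dlib matrices; this mirrors the removed lines above and is illustrative only:

#include <dlib/matrix.h>

// Barlow Twins loss from the k x k cross-correlation matrix C: squared error of
// the diagonal against 1, plus lambda times the squared off-diagonal entries.
inline double barlow_twins_loss(const dlib::matrix<float>& C, double lambda)
{
    const double diagonal_loss = dlib::sum(dlib::squared(dlib::diag(C) - 1));
    const double off_diag_loss = dlib::sum(dlib::squared(C - dlib::diagm(dlib::diag(C))));
    return diagonal_loss + lambda * off_diag_loss;
}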
@@ -4116,6 +4124,11 @@ namespace dlib

 private:
     float lambda = 0.0051;
+    mutable resizable_tensor za_norm, means_a, invstds_a;
+    mutable resizable_tensor zb_norm, means_b, invstds_b;
+    mutable resizable_tensor rms, rvs, g, b;
+    mutable resizable_tensor eccm, grad_input, g_grad, b_grad;
+    mutable resizable_tensor cdiag, cdiag_1, identity, off_mask, off_diag;
 };

 template <typename SUBNET>