package org.neo4j.gds.ml.core.optimizer;

import java.util.List;
import java.util.stream.Collectors;
import org.neo4j.gds.mem.Estimate;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.tensor.Tensor;

/* loaded from: input_file:org/neo4j/gds/ml/core/optimizer/AdamOptimizer.class */
/**
 * Adam optimizer (Kingma and Ba) that updates a fixed list of {@link Weights} in place.
 *
 * <p>For every weight {@code w} with gradient {@code g}, each call to {@link #update(List)} performs
 * <pre>
 *   g     = clip(g, -5, 5)
 *   m     = beta1 * m + (1 - beta1) * g
 *   v     = beta2 * v + (1 - beta2) * g^2
 *   mHat  = m / (1 - beta1^t)
 *   vHat  = v / (1 - beta2^t)
 *   w    += -learningRate * mHat / (sqrt(vHat) + epsilon)
 * </pre>
 * where {@code t} is the number of {@code update} calls so far, starting at 1.
 *
 * <p>This class performs no synchronization.
 */
public class AdamOptimizer implements Updater {
    private static final double CLIP_MAX = 5.0d;
    private static final double CLIP_MIN = -5.0d;

    private final double learningRate;
    private final List<Weights<? extends Tensor<?>>> weights;
    // NOTE(review): package-private, presumably so tests in this package can inspect them — confirm.
    final List<Tensor<?>> momentumTerms;
    final List<Tensor<?>> velocityTerms;

    // Adam hyper-parameters.
    // NOTE(review): kept as instance fields rather than static constants, because
    // Estimate.sizeOfInstance(AdamOptimizer.class) presumably measures instance fields.
    // Making them static would then change sizeInBytes() — confirm before converting.
    private final double beta1 = 0.9d;
    private final double beta2 = 0.999d;
    private final double epsilon = 1.0E-8d;

    // Number of completed update() calls; drives the bias correction.
    private int iteration = 0;

    /**
     * Estimates the memory, in bytes, of an optimizer for a single weight of the given dimensions.
     *
     * @param rows first dimension, forwarded to {@link Weights#sizeInBytes(int, int)}
     * @param cols second dimension, forwarded to {@link Weights#sizeInBytes(int, int)}
     * @return the estimated size in bytes
     */
    public static long sizeInBytes(int rows, int cols) {
        long weightSize = Weights.sizeInBytes(rows, cols);
        // NOTE(review): two weight-sized tensors are budgeted for each of the momentum and the
        // velocity term. This presumably covers the persistent term plus the temporary
        // bias-corrected copy allocated in update() — confirm.
        long momentumTermsSize = 2 * weightSize;
        long velocityTermsSize = 2 * weightSize;
        return Estimate.sizeOfInstance(AdamOptimizer.class) + momentumTermsSize + velocityTermsSize;
    }

    /**
     * Creates an optimizer for the given weights.
     *
     * <p>One momentum tensor and one velocity tensor are allocated per weight. Each is created by
     * calling {@code createWithSameDimensions()} on that weight's data (presumably zero-filled).
     *
     * @param weights      the weights to optimize; their data is modified in place by {@link #update(List)}
     * @param learningRate step size applied to every update
     */
    public AdamOptimizer(List<Weights<? extends Tensor<?>>> weights, double learningRate) {
        this.learningRate = learningRate;
        this.weights = weights;
        this.momentumTerms = createTermsLike(weights);
        this.velocityTerms = createTermsLike(weights);
    }

    /** Allocates one tensor per weight, with the same dimensions as that weight's data. */
    private static List<Tensor<?>> createTermsLike(List<Weights<? extends Tensor<?>>> weights) {
        return weights.stream()
            .<Tensor<?>>map(w -> w.data().createWithSameDimensions())
            .collect(Collectors.toList());
    }

    /**
     * Applies one Adam step to every weight, in place.
     *
     * <p>Side effect: each tensor in {@code gradients} is modified in place. It is first clipped to
     * [-5, 5] and then overwritten with its element-wise square. Callers must not rely on the
     * gradient values after this call.
     *
     * @param gradients one gradient per weight, in the same order as the weights given to the constructor
     * @throws IllegalArgumentException if {@code gradients} does not hold exactly one tensor per weight
     */
    @Override // org.neo4j.gds.ml.core.optimizer.Updater
    public void update(List<? extends Tensor<?>> gradients) {
        // Validate before advancing the step counter, so a rejected call leaves no trace.
        if (gradients.size() != weights.size()) {
            throw new IllegalArgumentException(String.format(
                "Expected %d gradients (one per weight), but got %d",
                weights.size(),
                gradients.size()
            ));
        }

        iteration++;
        // The bias corrections depend only on the step count, so compute them once per call.
        double momentumCorrection = 1.0d / (1.0d - Math.pow(beta1, iteration));
        double velocityCorrection = 1.0d / (1.0d - Math.pow(beta2, iteration));

        for (int i = 0; i < weights.size(); i++) {
            Tensor<?> weight = weights.get(i).data();
            Tensor<?> gradient = gradients.get(i);
            Tensor<?> momentum = momentumTerms.get(i);
            Tensor<?> velocity = velocityTerms.get(i);

            gradient.mapInPlace(AdamOptimizer::clip);

            // m <- beta1 * m + (1 - beta1) * g
            momentum.scalarMultiplyMutate(beta1).addInPlace(gradient.scalarMultiply(1 - beta1));

            // v <- beta2 * v + (1 - beta2) * g^2  (overwrites the caller's gradient with g^2)
            velocity.scalarMultiplyMutate(beta2)
                .addInPlace(gradient.mapInPlace(g -> g * g).scalarMultiplyMutate(1 - beta2));

            // Bias-corrected estimates are built from scalarMultiply (the non-"Mutate" variant),
            // so the persistent m and v terms are not altered below.
            Tensor<?> momentumHat = momentum.scalarMultiply(momentumCorrection);
            Tensor<?> inverseDenominator = velocity.scalarMultiply(velocityCorrection)
                .mapInPlace(v -> 1.0d / (Math.sqrt(v) + epsilon));

            // w <- w - learningRate * mHat / (sqrt(vHat) + epsilon)
            Tensor<?> step = momentumHat.scalarMultiplyMutate(-learningRate)
                .elementwiseProductMutate(inverseDenominator);
            weight.addInPlace(step);
        }
    }

    /** Clamps {@code value} to [CLIP_MIN, CLIP_MAX]; NaN is passed through unchanged. */
    private static double clip(double value) {
        return Math.max(CLIP_MIN, Math.min(CLIP_MAX, value));
    }
}
