package org.neo4j.gds.ml.core;

import com.neo4j.gds.shaded.org.apache.commons.lang3.mutable.MutableInt;
import com.neo4j.gds.shaded.org.jetbrains.annotations.TestOnly;
import java.util.ArrayDeque;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.stream.Collectors;
import org.neo4j.gds.ml.core.functions.SingleParentVariable;
import org.neo4j.gds.ml.core.tensor.Tensor;

/* loaded from: input_file:org/neo4j/gds/ml/core/ComputationContext.class */
/**
 * Evaluation state for a graph of {@link Variable}s.
 *
 * <p>{@link #forward(Variable)} computes and caches the tensor of a variable (after recursively
 * forwarding its parents). {@link #backward(Variable)} back-propagates from a scalar root and
 * accumulates a gradient for every reachable variable that requires one.
 *
 * <p>Both caches are plain {@link HashMap}s and this class performs no synchronization.
 */
public class ComputationContext {
    /** Tensors computed by {@link #forward(Variable)}, keyed by variable. */
    private final Map<Variable<?>, Tensor<?>> data = new HashMap<>();
    /** Gradients accumulated by the most recent {@link #backward(Variable)} call. */
    private final Map<Variable<?>, Tensor<?>> gradients = new HashMap<>();

    /**
     * A pending back-propagation step: add {@code child}'s gradient contribution to
     * {@code variable}.
     */
    public static class BackPropTask {
        final Variable<?> variable;
        final Variable<?> child;

        BackPropTask(Variable<?> variable, Variable<?> child) {
            this.variable = variable;
            this.child = child;
        }
    }

    /**
     * Artificial child of the back-propagation root, used to seed the gradient.
     *
     * <p>Its value is the root's cached value. Its gradient with respect to the root is a tensor
     * derived from the root's value with every entry set to {@code 1.0}. Wrapping another
     * pass-through variable is rejected.
     */
    private static final class PassThroughVariable<T extends Tensor<T>> extends SingleParentVariable<T, T> {
        private PassThroughVariable(Variable<T> root) {
            super(root, root.dimensions());
            if (root instanceof PassThroughVariable) {
                throw new IllegalArgumentException("Redundant use of PassthroughVariables. Chaining does not make sense.");
            }
        }

        @Override
        @SuppressWarnings("unchecked")
        public T apply(ComputationContext ctx) {
            return (T) ctx.data(this.parent);
        }

        @Override
        @SuppressWarnings("unchecked")
        public T gradientForParent(ComputationContext ctx) {
            // Requires the parent's value to be cached already (see ComputationContext#backward).
            return (T) ctx.data(this.parent).map(ignored -> 1.0d);
        }
    }

    /**
     * Computes, or returns the cached, tensor of {@code variable}.
     *
     * <p>All parents are forwarded first, so their values are cached before
     * {@code variable.apply(this)} runs. The result is cached in this context.
     *
     * @param variable the variable to evaluate
     * @param <T> the tensor type produced by {@code variable}
     * @return the tensor for {@code variable}
     */
    public <T extends Tensor<T>> T forward(Variable<T> variable) {
        T cached = data(variable);
        if (cached != null) {
            return cached;
        }
        for (Variable<?> parent : variable.parents()) {
            forwardErased(parent);
        }
        T result = variable.apply(this);
        this.data.put(variable, result);
        return result;
    }

    /**
     * Forwards a parent whose tensor type is only known as a wildcard.
     *
     * <p>The raw cast bridges {@code Variable<?>} to the generic {@code forward} signature. The
     * returned tensor is not needed here: {@code forward} has already cached it.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private void forwardErased(Variable<?> variable) {
        forward((Variable) variable);
    }

    /**
     * Returns the tensor cached for {@code variable} by {@link #forward(Variable)}.
     *
     * @return the cached tensor, or {@code null} if none is cached
     */
    @SuppressWarnings("unchecked")
    public <T extends Tensor<T>> T data(Variable<T> variable) {
        return (T) this.data.get(variable);
    }

    /**
     * Returns the gradient accumulated for {@code variable} by the last {@link #backward(Variable)}.
     *
     * @return the gradient, or {@code null} if back-propagation did not reach {@code variable}
     */
    @SuppressWarnings("unchecked")
    public <T extends Tensor<T>> T gradient(Variable<T> variable) {
        return (T) this.gradients.get(variable);
    }

    /**
     * Back-propagates from {@code variable}, accumulating gradients for every reachable variable
     * that requires one. Gradients stored by a previous call are discarded first.
     *
     * <p>The root's value must already be cached (e.g. via {@link #forward(Variable)}), because the
     * seed gradient is built from {@code data(variable)}.
     *
     * @param variable the root; must be scalar and require a gradient (checked only when
     *     assertions are enabled)
     */
    public void backward(Variable<?> variable) {
        assert Dimensions.isScalar(variable.dimensions()) : "Root variable must be scalar.";
        assert variable.requireGradient() : "Root variable must have requireGradient==true";

        this.gradients.clear();

        // The pass-through sink acts as the root's only child and seeds it with a gradient of ones.
        @SuppressWarnings({"unchecked", "rawtypes"})
        PassThroughVariable<?> sink = new PassThroughVariable(variable);

        Queue<BackPropTask> pending = new ArrayDeque<>();
        pending.add(new BackPropTask(variable, sink));

        Map<Variable<?>, MutableInt> pendingChildren = new HashMap<>();
        initUpstream(sink, pendingChildren);
        backward(pending, pendingChildren);
    }

    /**
     * Drains {@code queue}. Each task adds one child's gradient contribution to its target.
     *
     * <p>Once the target's counter (set up by {@link #initUpstream}) reaches zero, every recorded
     * child edge has contributed. Tasks for the target's gradient-requiring parents are then
     * enqueued.
     */
    private void backward(Queue<BackPropTask> queue, Map<Variable<?>, MutableInt> pendingChildren) {
        while (!queue.isEmpty()) {
            BackPropTask task = queue.poll();
            Variable<?> target = task.variable;
            updateGradient(target, task.child.gradient(target, this));

            // Propagate further only once the target's gradient is complete.
            if (pendingChildren.get(target).decrementAndGet() != 0) {
                continue;
            }
            for (Variable<?> parent : target.parents()) {
                if (parent.requireGradient()) {
                    queue.offer(new BackPropTask(parent, target));
                }
            }
        }
    }

    /**
     * Counts incoming child edges for each gradient-requiring variable reachable from
     * {@code variable}. Traversal only continues through parents that require a gradient.
     *
     * <p>The counter is inserted after the recursive call returns. This is not written with
     * {@code computeIfAbsent} because the recursion inserts into the same map, which
     * {@link HashMap#computeIfAbsent} does not allow.
     */
    private void initUpstream(Variable<?> variable, Map<Variable<?>, MutableInt> pendingChildren) {
        for (Variable<?> parent : variable.parents()) {
            if (!parent.requireGradient()) {
                continue;
            }
            if (!pendingChildren.containsKey(parent)) {
                initUpstream(parent, pendingChildren);
                pendingChildren.put(parent, new MutableInt(0));
            }
            pendingChildren.get(parent).increment();
        }
    }

    /**
     * Adds {@code tensor} to the gradient accumulated for {@code variable}.
     *
     * <p>The first contribution is stored as a private copy, because later contributions are
     * added in place. The tensor returned by a child's {@code gradient(...)} may be one it does not
     * own (for example another variable's gradient or cached data). Storing it directly would let
     * the in-place accumulation silently corrupt that tensor.
     */
    private void updateGradient(Variable<?> variable, Tensor<?> tensor) {
        Tensor<?> accumulated = this.gradients.get(variable);
        if (accumulated == null) {
            // map(...) yields a new tensor (PassThroughVariable relies on it leaving the parent's
            // data untouched), so the identity mapping is an independent copy.
            this.gradients.put(variable, tensor.map(value -> value));
        } else {
            accumulated.addInPlace(tensor);
        }
    }

    /**
     * Renders every cached variable with its data and gradient ("None" when absent). Any gradients
     * whose variable has no cached data are listed afterwards.
     */
    public String render() {
        StringBuilder sb = new StringBuilder();
        this.data.forEach((variable, tensor) -> {
            String gradientText = Optional.ofNullable(this.gradients.get(variable))
                .map(Object::toString)
                .orElse("None");
            sb.append(variable).append(System.lineSeparator())
                .append("\t data: ").append(tensor).append(System.lineSeparator());
            sb.append("\t gradient: ").append(gradientText).append(System.lineSeparator());
        });
        renderOrphanGradients(sb);
        return sb.toString();
    }

    /** Returns a live view of the variables whose tensors are cached in this context. */
    @TestOnly
    public Set<Variable<?>> computedVariables() {
        return this.data.keySet();
    }

    /** Appends every gradient whose variable has no cached data; appends nothing if none exist. */
    private void renderOrphanGradients(StringBuilder sb) {
        Set<Variable<?>> computed = this.data.keySet();
        List<Map.Entry<Variable<?>, Tensor<?>>> orphans = this.gradients.entrySet().stream()
            .filter(entry -> !computed.contains(entry.getKey()))
            .collect(Collectors.toList());
        if (orphans.isEmpty()) {
            return;
        }
        sb.append("Found gradients but no data for: ");
        for (Map.Entry<Variable<?>, Tensor<?>> orphan : orphans) {
            sb.append(System.lineSeparator()).append(orphan.getKey()).append(orphan.getValue());
        }
    }
}
