package org.neo4j.gds.embeddings.graphsage.algo;

import java.util.function.Function;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.core.GraphDimensions;
import org.neo4j.gds.embeddings.graphsage.GraphSageHelper;
import org.neo4j.gds.mem.Estimate;
import org.neo4j.gds.mem.MemoryEstimateDefinition;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.mem.MemoryEstimations;

/**
 * Memory estimate definition for the {@code GraphSage} algorithm.
 *
 * <p>The estimate is derived from the node count found in the {@link GraphDimensions} passed to the
 * setup function. The per-node result features are always part of the estimate. Which field they
 * appear under depends on {@code mutating}:
 * <ul>
 *   <li>{@code true}: they are listed under resident memory.</li>
 *   <li>{@code false}: they are listed under temporary memory.</li>
 * </ul>
 */
public class GraphSageMemoryEstimateDefinition implements MemoryEstimateDefinition {
    // Supplies the embedding dimension, estimated feature dimension and batch size used below.
    private final GraphSageTrainMemoryEstimateParameters trainEstimationParameters;
    // Selects whether result features are accounted as resident (true) or temporary (false) memory.
    private final boolean mutating;

    /**
     * @param trainEstimationParameters dimensions and batch size used to size the estimate
     * @param mutating                  whether result features are counted as resident memory
     */
    public GraphSageMemoryEstimateDefinition(
        GraphSageTrainMemoryEstimateParameters trainEstimationParameters,
        boolean mutating
    ) {
        this.trainEstimationParameters = trainEstimationParameters;
        this.mutating = mutating;
    }

    @Override // org.neo4j.gds.mem.MemoryEstimateDefinition
    public MemoryEstimation memoryEstimation() {
        // The estimate needs the node count, which is only available from the GraphDimensions given
        // to the setup function. The explicitly typed local pins the Function overload of setup.
        Function<GraphDimensions, MemoryEstimation> fromDimensions =
            dimensions -> estimationForNodeCount(dimensions.nodeCount());
        return MemoryEstimations.setup("", fromDimensions);
    }

    /**
     * Builds the "GraphSage" estimation tree for a graph with {@code nodeCount} nodes.
     *
     * <p>The temporary-memory field always contains the {@code GraphSage} instance, the per-node
     * initial features and the per-thread concurrent batches.
     *
     * @param nodeCount number of nodes in the graph
     * @return the assembled memory estimation
     */
    private MemoryEstimation estimationForNodeCount(long nodeCount) {
        MemoryEstimations.Builder estimation = MemoryEstimations.builder("GraphSage");

        if (mutating) {
            // Mutating: result features are accounted as resident memory.
            estimation = estimation
                .startField(MemoryEstimations.RESIDENT_MEMORY)
                .perNode("resultFeatures", nodes -> resultFeaturesMemory(nodes))
                .endField();
        }

        MemoryEstimations.Builder temporaryField = estimation
            .startField(MemoryEstimations.TEMPORARY_MEMORY)
            .field("this.instance", GraphSage.class)
            .perNode("initialFeatures", nodes -> HugeObjectArray.memoryEstimation(
                nodes,
                Estimate.sizeOfDoubleArray(trainEstimationParameters.estimationFeatureDimension())
            ))
            .perThread("concurrentBatches", concurrentBatchesEstimation(nodeCount));

        if (!mutating) {
            // Not mutating: result features are accounted as temporary memory instead.
            temporaryField = temporaryField.perNode("resultFeatures", nodes -> resultFeaturesMemory(nodes));
        }

        return temporaryField.endField().build();
    }

    /**
     * Estimated memory of a HugeObjectArray holding {@code nodeCount} double arrays,
     * each of length {@code embeddingDimension}.
     */
    private long resultFeaturesMemory(long nodeCount) {
        long perNodeArray = Estimate.sizeOfDoubleArray(trainEstimationParameters.embeddingDimension());
        return HugeObjectArray.memoryEstimation(nodeCount, perNodeArray);
    }

    /**
     * Per-thread estimation of one batch of embeddings.
     *
     * <p>Uses the configured batch size; the {@code 0} and {@code false} arguments are forwarded
     * verbatim to {@code GraphSageHelper.embeddingsEstimation}. NOTE(review): their meaning is
     * defined by that helper, which is not visible here.
     */
    private MemoryEstimation concurrentBatchesEstimation(long nodeCount) {
        MemoryEstimation batchEmbeddings = GraphSageHelper.embeddingsEstimation(
            trainEstimationParameters,
            trainEstimationParameters.batchSize(),
            nodeCount,
            0,
            false
        );
        return MemoryEstimations.builder().add(batchEmbeddings).build();
    }
}
