package org.neo4j.gds.ml.pipeline.nodePipeline;

import com.neo4j.gds.shaded.org.immutables.value.Value;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import org.neo4j.gds.annotation.Configuration;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.config.ToMapConvertible;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.ml.pipeline.NonEmptySetValidation;

/**
 * Split settings for a node property prediction pipeline.
 *
 * <p>The sizing helpers below treat the input as divided into a test set holding a
 * {@link #testFraction()} share of the examples and a train set holding the rest. The train set is
 * further divided into {@link #validationFolds()} folds. All sizes are truncated toward zero
 * (plain {@code long} casts / integer division).
 */
@Configuration
public interface NodePropertyPredictionSplitConfig extends ToMapConvertible {

    /**
     * Instance built from an empty parameter map.
     *
     * <p>NOTE(review): presumably every option takes its {@code @Value.Default}. Confirm against the
     * generated {@code NodePropertyPredictionSplitConfigImpl}.
     */
    NodePropertyPredictionSplitConfig DEFAULT_CONFIG = of(CypherMapWrapper.empty());

    /**
     * Share of the examples placed in the test set; defaults to {@code 0.3} and is restricted to
     * the range 0..1 by the {@code DoubleRange} constraint.
     */
    @Value.Default
    @Configuration.DoubleRange(min = 0.0d, max = 1.0d)
    default double testFraction() {
        return 0.3d;
    }

    /** Number of folds the train set is divided into; defaults to {@code 3}, minimum {@code 2}. */
    @Configuration.IntegerRange(min = 2)
    @Value.Default
    default int validationFolds() {
        return 3;
    }

    /**
     * Creates a config from user-supplied parameters.
     *
     * @param cypherMapWrapper the raw configuration map
     * @return the generated implementation wrapping {@code cypherMapWrapper}
     */
    static NodePropertyPredictionSplitConfig of(CypherMapWrapper cypherMapWrapper) {
        return new NodePropertyPredictionSplitConfigImpl(cypherMapWrapper);
    }

    /** Map representation of this config; the body is supplied by the {@code @Configuration} processor. */
    @Override
    @Configuration.ToMap
    Map<String, Object> toMap();

    /**
     * Keys of this configuration. The default returns an empty list; presumably the generated
     * implementation overrides it via {@code @Configuration.CollectKeys} (TODO confirm).
     */
    @Configuration.CollectKeys
    default Collection<String> configKeys() {
        return Collections.emptyList();
    }

    /**
     * Checks that splitting {@code graph}'s nodes with this config leaves every set large enough:
     * at least 1 test node, at least 2 train nodes, and at least 1 node per validation fold.
     * The checks are delegated to {@link NonEmptySetValidation#validateNodeSetSize}, which is
     * presumably what throws when a set is too small.
     *
     * @param graph the graph whose node count is split
     */
    @Value.Derived
    @Configuration.Ignore
    default void validateMinNumNodesInSplitSets(Graph graph) {
        long testNodes = testSetSize(graph.nodeCount());
        // The train set is the exact remainder here (not trainSetSize), so no node is lost to rounding.
        long trainNodes = graph.nodeCount() - testNodes;
        long validationNodes = trainNodes / validationFolds();

        NonEmptySetValidation.validateNodeSetSize(testNodes, 1L, "test", "`testFraction` is too low");
        NonEmptySetValidation.validateNodeSetSize(trainNodes, 2L, "train", "`testFraction` is too high");
        NonEmptySetValidation.validateNodeSetSize(validationNodes, 1L, "validation", "`validationFolds` or `testFraction` is too high");
    }

    /**
     * @param numberOfExamples total number of examples being split
     * @return {@code testFraction() * numberOfExamples}, truncated toward zero
     */
    @Value.Auxiliary
    @Value.Derived
    @Configuration.Ignore
    default long testSetSize(long numberOfExamples) {
        return (long) (testFraction() * numberOfExamples);
    }

    /**
     * @param numberOfExamples total number of examples being split
     * @return {@code numberOfExamples * (1 - testFraction())}, truncated toward zero
     */
    @Value.Auxiliary
    @Value.Derived
    @Configuration.Ignore
    default long trainSetSize(long numberOfExamples) {
        return (long) (numberOfExamples * (1.0d - testFraction()));
    }

    /**
     * @param numberOfExamples total number of examples being split
     * @return size of the training portion of one fold: {@code (folds - 1) / folds} of the train set
     */
    @Value.Auxiliary
    @Value.Derived
    @Configuration.Ignore
    default long foldTrainSetSize(long numberOfExamples) {
        return (trainSetSize(numberOfExamples) * (validationFolds() - 1)) / validationFolds();
    }

    /**
     * @param numberOfExamples total number of examples being split
     * @return size of the held-out portion of one fold: {@code 1 / folds} of the train set
     */
    @Value.Auxiliary
    @Value.Derived
    @Configuration.Ignore
    default long foldTestSetSize(long numberOfExamples) {
        // Divide the long directly: the former `* (1 / validationFolds())` was int division and
        // always 0 because validationFolds() >= 2.
        return trainSetSize(numberOfExamples) / validationFolds();
    }
}
