package com.neo4j.gds.ml.pipeline;

import com.neo4j.gds.ModelDataSerializer;
import com.neo4j.gds.ml.model.proto.RandomForestDataProto;
import com.neo4j.gds.shaded.com.google.protobuf.GeneratedMessageV3;
import com.neo4j.gds.shaded.com.google.protobuf.InvalidProtocolBufferException;
import com.neo4j.gds.shaded.com.google.protobuf.Parser;
import com.neo4j.gds.shaded.org.jetbrains.annotations.NotNull;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import org.neo4j.gds.ml.decisiontree.DecisionTreePredictor;
import org.neo4j.gds.ml.decisiontree.TreeNode;
import org.neo4j.gds.ml.models.randomforest.RandomForestClassifierData;

/* loaded from: input_file:com/neo4j/gds/ml/pipeline/RandomForestSerializer.class */
/**
 * {@link ModelDataSerializer} for random forest classifier model data.
 *
 * <p>Converts {@link RandomForestClassifierData} to and from
 * {@link RandomForestDataProto.RandomForestData}. Each decision tree is stored as a flat list of
 * {@link RandomForestDataProto.TreeNodeData} entries in pre-order: a node is written first,
 * followed by its complete left subtree and then its complete right subtree. Leaf nodes carry a
 * prediction; split nodes carry a feature index and a threshold value.
 */
public class RandomForestSerializer implements ModelDataSerializer {

    /**
     * Serializes random forest model data into its protobuf representation.
     *
     * @param obj the model data; must be a {@link RandomForestClassifierData}
     * @return a message holding every decision tree, the number of classes and the feature
     *     dimension
     * @throws ClassCastException if {@code obj} is not a {@link RandomForestClassifierData}
     */
    @Override // com.neo4j.gds.ModelDataSerializer
    public RandomForestDataProto.RandomForestData serialize(Object obj) {
        RandomForestClassifierData randomForestClassifierData = (RandomForestClassifierData) obj;
        List<RandomForestDataProto.DecisionTreeData> trees = randomForestClassifierData.decisionTrees()
            .stream()
            .map(decisionTreePredictor -> serializeDecisionTree(decisionTreePredictor.root))
            .collect(Collectors.toList());
        return RandomForestDataProto.RandomForestData.newBuilder()
            .addAllTrees(trees)
            .setNumberOfClasses(randomForestClassifierData.numberOfClasses())
            .setFeatureDimension(randomForestClassifierData.featureDimension())
            .build();
    }

    /**
     * Flattens a decision tree into its pre-order node list.
     *
     * <p>Uses an explicit stack instead of recursion. The right child is pushed before the left
     * child so that the whole left subtree is emitted before the right subtree.
     *
     * @param root the root of the tree to serialize
     * @return the serialized tree
     */
    private RandomForestDataProto.DecisionTreeData serializeDecisionTree(TreeNode<Integer> root) {
        List<RandomForestDataProto.TreeNodeData> nodes = new ArrayList<>();
        ArrayDeque<TreeNode<Integer>> pending = new ArrayDeque<>();
        pending.push(root);
        while (!pending.isEmpty()) {
            TreeNode<Integer> node = pending.pop();
            if (node.prediction() != null) {
                // Leaf: only the prediction is stored.
                nodes.add(RandomForestDataProto.TreeNodeData.newBuilder()
                    .setPrediction(node.prediction().intValue())
                    .build());
            } else {
                // Split: store the split criterion, then visit left before right (LIFO order).
                nodes.add(RandomForestDataProto.TreeNodeData.newBuilder()
                    .setFeatureIndex(node.featureIndex())
                    .setThresholdValue(node.thresholdValue())
                    .build());
                pending.push(node.rightChild());
                pending.push(node.leftChild());
            }
        }
        return RandomForestDataProto.DecisionTreeData.newBuilder().addAllNodes(nodes).build();
    }

    /**
     * Rebuilds a decision tree from its pre-order node list.
     *
     * <p>Every node read after the root becomes a child of the innermost split node that still
     * lacks a child. It is the left child if that split has none yet, otherwise the right child.
     *
     * @param decisionTreeData the serialized tree
     * @param treeIndex position of the tree in the forest, used only in error messages
     * @return the root of the rebuilt tree
     * @throws InvalidProtocolBufferException if the node list is empty, contains nodes after the
     *     tree is already complete, or ends before every split node has both children
     */
    private TreeNode<Integer> deserializeDecisionTree(
        RandomForestDataProto.DecisionTreeData decisionTreeData,
        int treeIndex
    ) throws InvalidProtocolBufferException {
        TreeNode<Integer> root = null;
        // Split nodes that are still missing at least one child; the innermost one is on top.
        ArrayDeque<TreeNode<Integer>> openSplits = new ArrayDeque<>();

        for (RandomForestDataProto.TreeNodeData nodeData : decisionTreeData.getNodesList()) {
            TreeNode<Integer> node;
            if (nodeData.hasPrediction()) {
                node = new TreeNode<>(Integer.valueOf(nodeData.getPrediction()));
            } else {
                node = new TreeNode<>(nodeData.getFeatureIndex(), nodeData.getThresholdValue());
            }

            if (root == null) {
                root = node;
            } else {
                // ArrayDeque never holds null, so null here means there is no open split left.
                TreeNode<Integer> parent = openSplits.peek();
                if (parent == null) {
                    throw new InvalidProtocolBufferException(
                        "Could not parse serializedData because decision tree " + treeIndex
                        + " contains nodes after its root subtree is complete."
                    );
                }
                if (parent.hasLeftChild()) {
                    parent.setRightChild(node);
                    // The parent now has both children and is closed.
                    openSplits.pop();
                } else {
                    parent.setLeftChild(node);
                }
            }

            if (!nodeData.hasPrediction()) {
                // The subtrees of a split node follow it directly in pre-order.
                openSplits.push(node);
            }
        }

        if (root == null) {
            throw new InvalidProtocolBufferException(
                "Could not parse serializedData because decision tree " + treeIndex + " has no nodes."
            );
        }
        if (!openSplits.isEmpty()) {
            throw new InvalidProtocolBufferException(
                "Could not parse serializedData because decision tree " + treeIndex
                + " is incomplete: " + openSplits.size() + " split node(s) are missing children."
            );
        }
        return root;
    }

    /**
     * Rebuilds random forest model data from its protobuf representation.
     *
     * @param generatedMessageV3 the message; must be a {@link RandomForestDataProto.RandomForestData}
     * @return the deserialized model data
     * @throws InvalidProtocolBufferException if the message contains no trees or any tree is
     *     malformed
     * @throws ClassCastException if the message is not a {@link RandomForestDataProto.RandomForestData}
     */
    @Override // com.neo4j.gds.ModelDataSerializer
    @NotNull
    public RandomForestClassifierData deserialize(GeneratedMessageV3 generatedMessageV3) throws InvalidProtocolBufferException {
        RandomForestDataProto.RandomForestData randomForestData = (RandomForestDataProto.RandomForestData) generatedMessageV3;
        if (randomForestData.getTreesCount() == 0) {
            throw new InvalidProtocolBufferException("Could not parse serializedData because decision tree count is zero.");
        }

        // A plain loop (not a stream) so the checked parse exception can propagate.
        List<RandomForestDataProto.DecisionTreeData> serializedTrees = randomForestData.getTreesList();
        List<DecisionTreePredictor<Integer>> decisionTrees = new ArrayList<>(serializedTrees.size());
        for (int treeIndex = 0; treeIndex < serializedTrees.size(); treeIndex++) {
            TreeNode<Integer> root = deserializeDecisionTree(serializedTrees.get(treeIndex), treeIndex);
            decisionTrees.add(new DecisionTreePredictor<>(root));
        }

        return RandomForestClassifierData.builder()
            .addAllDecisionTrees(decisionTrees)
            .numberOfClasses(randomForestData.getNumberOfClasses())
            .featureDimension(randomForestData.getFeatureDimension())
            .build();
    }

    /**
     * Returns the protobuf parser for {@link RandomForestDataProto.RandomForestData} messages.
     *
     * @return the generated message parser
     */
    @Override // com.neo4j.gds.ModelDataSerializer
    public Parser<RandomForestDataProto.RandomForestData> parser() {
        return RandomForestDataProto.RandomForestData.parser();
    }
}
