package com.neo4j.gds.core.model;

import com.neo4j.gds.ModelInfoSerializer;
import com.neo4j.gds.core.model.proto.GraphSageProto;
import com.neo4j.gds.shaded.com.google.protobuf.GeneratedMessageV3;
import java.util.List;
import java.util.stream.Collectors;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.embeddings.graphsage.GraphSageModelTrainer;
import org.neo4j.gds.embeddings.graphsage.ImmutableGraphSageTrainMetrics;

/* loaded from: input_file:com/neo4j/gds/core/model/GraphSageTrainModelInfoSerializer.class */
/**
 * Converts {@link GraphSageModelTrainer.GraphSageTrainMetrics} to and from the protobuf message
 * {@link GraphSageProto.GraphSageMetrics}.
 */
public class GraphSageTrainModelInfoSerializer
    implements ModelInfoSerializer<GraphSageModelTrainer.GraphSageTrainMetrics> {

    /**
     * Serializes GraphSage training metrics into a {@code GraphSageMetrics} message.
     *
     * <p>Each inner list of {@code iterationLossPerEpoch()} becomes one {@code LossesForEpochs}
     * entry, in the same order. The {@code epochLosses} proto field is not written.
     *
     * @param customInfo the metrics to serialize; it is cast to
     *     {@link GraphSageModelTrainer.GraphSageTrainMetrics}, so any other type fails with a
     *     {@link ClassCastException}
     * @return a message holding the per-epoch iteration losses and the convergence flag
     */
    @Override
    public GraphSageProto.GraphSageMetrics serialize(Model.CustomInfo customInfo) {
        GraphSageModelTrainer.GraphSageTrainMetrics metrics =
            (GraphSageModelTrainer.GraphSageTrainMetrics) customInfo;

        // Typed list instead of a raw (Iterable) cast, so the compiler checks the element type.
        List<GraphSageProto.LossesForEpochs> lossesPerEpoch = metrics.iterationLossPerEpoch()
            .stream()
            .map(epochLosses -> GraphSageProto.LossesForEpochs.newBuilder().addAllLoss(epochLosses).build())
            .collect(Collectors.toList());

        return GraphSageProto.GraphSageMetrics.newBuilder()
            .addAllIterationLossPerEpoch(lossesPerEpoch)
            .setDidConverge(metrics.didConverge())
            .build();
    }

    /**
     * Reconstructs GraphSage training metrics from a {@code GraphSageMetrics} message.
     *
     * <p>The losses are decoded by {@link #readLossesPerEpoch}, which accepts either the
     * {@code epochLosses} field or the {@code iterationLossPerEpoch} field.
     *
     * @param generatedMessageV3 the message to read; it is cast to
     *     {@link GraphSageProto.GraphSageMetrics}, so any other type fails with a
     *     {@link ClassCastException}
     * @return immutable metrics carrying the convergence flag and per-epoch iteration losses
     */
    @Override
    public GraphSageModelTrainer.GraphSageTrainMetrics deserialize(GeneratedMessageV3 generatedMessageV3) {
        GraphSageProto.GraphSageMetrics proto = (GraphSageProto.GraphSageMetrics) generatedMessageV3;

        ImmutableGraphSageTrainMetrics.Builder builder =
            ImmutableGraphSageTrainMetrics.builder().didConverge(proto.getDidConverge());
        builder.iterationLossPerEpoch(readLossesPerEpoch(proto));
        return builder.build();
    }

    /**
     * Decodes the per-epoch loss lists from either of the two encodings a message may carry.
     *
     * <p>If {@code epochLosses} is non-empty, each of its values becomes a single-element
     * per-epoch list. NOTE(review): presumably this is an older format that stored one loss per
     * epoch; confirm. Otherwise the lists are taken from {@code iterationLossPerEpoch}, which is
     * the field {@link #serialize} writes.
     */
    private static List<List<Double>> readLossesPerEpoch(GraphSageProto.GraphSageMetrics proto) {
        List<Double> epochLosses = proto.getEpochLossesList();
        if (epochLosses.isEmpty()) {
            return proto.getIterationLossPerEpochList()
                .stream()
                .map(GraphSageProto.LossesForEpochs::getLossList)
                .collect(Collectors.toList());
        }
        return epochLosses.stream()
            .map(loss -> List.of(loss))
            .collect(Collectors.toList());
    }

    /**
     * Returns the protobuf message type this serializer reads and writes.
     *
     * @return {@code GraphSageProto.GraphSageMetrics.class}
     */
    @Override
    public Class<GraphSageProto.GraphSageMetrics> serializableClass() {
        return GraphSageProto.GraphSageMetrics.class;
    }
}
