package com.neo4j.gds.arrow.server.export;

import com.neo4j.gds.arrow.core.metrics.FlightMetrics;
import com.neo4j.gds.arrow.core.vectors.ArrowVectorBuffer;
import com.neo4j.gds.arrow.server.GdsServerExceptions;
import com.neo4j.gds.arrow.server.actions.v1.GdsFlightServerCommands;
import com.neo4j.gds.arrow.server.actions.v1.GdsFlightServerCommands.BaseGetCommand;
import com.neo4j.gds.arrow.server.export.resultstore.MetaDataWriter;
import com.neo4j.gds.shaded.org.apache.arrow.flight.FlightProducer;
import com.neo4j.gds.shaded.org.apache.arrow.flight.OutboundStreamListener;
import com.neo4j.gds.shaded.org.apache.arrow.memory.ArrowBuf;
import com.neo4j.gds.shaded.org.apache.arrow.memory.BufferAllocator;
import com.neo4j.gds.shaded.org.apache.arrow.vector.VectorLoader;
import com.neo4j.gds.shaded.org.apache.arrow.vector.VectorSchemaRoot;
import com.neo4j.gds.shaded.org.apache.arrow.vector.dictionary.DictionaryProvider;
import com.neo4j.gds.shaded.org.apache.arrow.vector.ipc.message.IpcOption;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.ResultStore;
import org.neo4j.gds.core.concurrency.BatchSize;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.loading.CatalogRequest;
import org.neo4j.gds.core.loading.GraphStoreCatalog;
import org.neo4j.gds.core.loading.GraphStoreCatalogEntry;
import org.neo4j.gds.core.utils.progress.JobId;
import org.neo4j.gds.core.utils.progress.TaskStore;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.logging.Log;

/* loaded from: input_file:com/neo4j/gds/arrow/server/export/ArrowExporter.class */
/**
 * Streams the data selected by a Flight "get" command to a client as Arrow record batches.
 *
 * <p>For every call to {@link #export}, the exporter looks up the graph in the
 * {@link GraphStoreCatalog}, builds an {@link ExportTask} for the command, and runs an
 * {@link ExportDriver} that produces buffers while a {@link QueueFlusher} writes them to the
 * Flight stream listener.
 *
 * @param <COMMAND> the concrete get-command type handled by this exporter
 */
public class ArrowExporter<COMMAND extends GdsFlightServerCommands.BaseGetCommand> {
    private final Concurrency concurrency;
    private final BatchSize batchSize;
    private final BufferAllocator bufferAllocator;
    private final FlightMetrics metrics;
    // Handed to every ExportDriver; NOTE(review): presumably invoked to emit log/progress updates - confirm.
    private final Runnable logUpdateFunction;
    private final ExportTaskFactory<COMMAND> exportTaskFactory;
    private final ProgressTrackerFactory progressTrackerFactory;

    /**
     * Stream listener decorator that counts the rows sent through {@code putNext} and reports the
     * total to {@link FlightMetrics} when the stream completes. All other calls are forwarded
     * unchanged to the wrapped listener.
     */
    private static final class TrackingFlightStreamListener implements FlightProducer.ServerStreamListener {
        private final FlightMetrics metrics;
        private final FlightProducer.ServerStreamListener delegate;
        private final FlightMetrics.StreamKind streamKind;
        // Captured in start(...); its row count is sampled on every putNext call.
        private VectorSchemaRoot root;
        private long streamedRecords = 0;

        private TrackingFlightStreamListener(FlightMetrics flightMetrics, FlightProducer.ServerStreamListener serverStreamListener, FlightMetrics.StreamKind streamKind) {
            this.metrics = flightMetrics;
            this.delegate = serverStreamListener;
            this.streamKind = streamKind;
        }

        @Override
        public boolean isCancelled() {
            return this.delegate.isCancelled();
        }

        @Override
        public void setOnCancelHandler(Runnable runnable) {
            this.delegate.setOnCancelHandler(runnable);
        }

        @Override
        public boolean isReady() {
            return this.delegate.isReady();
        }

        /** Remembers the root so later {@code putNext} calls can read its row count, then delegates. */
        @Override
        public void start(VectorSchemaRoot vectorSchemaRoot, DictionaryProvider dictionaryProvider, IpcOption ipcOption) {
            this.root = vectorSchemaRoot;
            this.delegate.start(vectorSchemaRoot, dictionaryProvider, ipcOption);
        }

        /** Adds the root's current row count to the running total, then delegates. */
        @Override
        public void putNext() {
            this.streamedRecords += this.root.getRowCount();
            this.delegate.putNext();
        }

        /** Adds the root's current row count to the running total, then delegates. */
        @Override
        public void putNext(ArrowBuf arrowBuf) {
            this.streamedRecords += this.root.getRowCount();
            this.delegate.putNext(arrowBuf);
        }

        @Override
        public void putMetadata(ArrowBuf arrowBuf) {
            this.delegate.putMetadata(arrowBuf);
        }

        /** Forwards the error; the row total is not reported to the metrics in this case. */
        @Override
        public void error(Throwable th) {
            this.delegate.error(th);
        }

        /** Reports the accumulated row total to the metrics, then completes the wrapped listener. */
        @Override
        public void completed() {
            this.metrics.entitiesStreamed(this.streamKind, this.streamedRecords);
            this.delegate.completed();
        }
    }

    /**
     * Creates an exporter.
     *
     * @param concurrency               concurrency used to select the task factory, create progress
     *                                  trackers and run export drivers
     * @param batchSize                 batch size passed to every export driver
     * @param bufferAllocator           Arrow allocator used for the schema root, metadata and buffers
     * @param exportTaskFactorySelector selector that yields the task factory for {@code concurrency}
     * @param taskStore                 task store handed to the progress tracker factory
     * @param log                       log handed to the progress tracker factory
     * @param flightMetrics             metrics sink; {@link FlightMetrics#NOOP} disables row tracking
     * @param runnable                  callback passed to every export driver
     */
    public ArrowExporter(Concurrency concurrency, BatchSize batchSize, BufferAllocator bufferAllocator, ExportTaskFactorySelector<COMMAND> exportTaskFactorySelector, TaskStore taskStore, Log log, FlightMetrics flightMetrics, Runnable runnable) {
        this.concurrency = concurrency;
        this.batchSize = batchSize;
        this.bufferAllocator = bufferAllocator;
        this.metrics = flightMetrics;
        this.logUpdateFunction = runnable;
        this.exportTaskFactory = exportTaskFactorySelector.select(concurrency);
        this.progressTrackerFactory = new ProgressTrackerFactory(taskStore, log, concurrency);
    }

    /**
     * Exports the data described by {@code command} to {@code serverStreamListener}.
     *
     * <p>The schema root is released on every exit path and the export driver is closed exactly
     * once. The result store cleanup runs exactly once after the driver's export step, whether
     * that step succeeded or failed. {@code completed()} is signalled only after all resources
     * have been released.
     *
     * <p>Any failure while exporting or flushing sets the calling thread's interrupt flag,
     * interrupts the flusher, and is rethrown wrapped by
     * {@link GdsServerExceptions#graphExportInterrupted}.
     *
     * @param str                  passed to the catalog lookup, the task factory and the progress
     *                             tracker; NOTE(review): presumably the requesting user name - confirm
     * @param command              the get command (database, graph, job id, partition bounds)
     * @param serverStreamListener the Flight listener that receives the stream
     * @param executionHandle      handle used to initialise and run the driver; the flusher is
     *                             registered with it as a closeable resource
     * @throws Exception if setup, export, flushing or resource release fails
     */
    public void export(String str, COMMAND command, FlightProducer.ServerStreamListener serverStreamListener, ExecutionHandle executionHandle) throws Exception {
        GraphStoreCatalogEntry catalogEntry = GraphStoreCatalog.get(CatalogRequest.of(str, command.databaseName()), command.graphName());
        GraphStore graphStore = catalogEntry.graphStore();
        ResultStore resultStore = catalogEntry.resultStore();
        ExportTask<?> exportTask = this.exportTaskFactory.create(graphStore, resultStore, str, command);
        JobId jobId = command.jobId();
        long partitionOffset = command.partitionOffset();
        long partitionSize = command.partitionSize();
        ProgressTracker progressTracker = this.progressTrackerFactory.create(exportTask.task(), jobId, str);
        ExportDriver exportDriver = new ExportDriver(exportTask, this.concurrency, this.batchSize, partitionOffset, partitionSize, this.logUpdateFunction);

        FlightProducer.ServerStreamListener listener = serverStreamListener;
        try {
            // try-with-resources guarantees the schema root is released on failure as well as on success.
            try (VectorSchemaRoot root = VectorSchemaRoot.create(exportTask.schema(), this.bufferAllocator)) {
                if (this.metrics != FlightMetrics.NOOP) {
                    listener = new TrackingFlightStreamListener(this.metrics, serverStreamListener, exportTask.streamKind());
                }
                VectorLoader vectorLoader = new VectorLoader(root);
                exportDriver.initialize(executionHandle);
                if (exportTask.metaData().isPresent()) {
                    new MetaDataWriter(exportTask.metaData().get()).write(listener, this.bufferAllocator);
                }
                QueueFlusher<?> flusher = createQueueFlusher(vectorLoader, listener, progressTracker, exportDriver.bufferManager(), this.bufferAllocator);
                executionHandle.registerCloseableResource(flusher);
                progressTracker.beginSubTask();
                flusher.start();
                listener.start(root, exportTask.dictionaryProvider(this.bufferAllocator));
                try {
                    try {
                        exportDriver.export(executionHandle, this.bufferAllocator);
                    } finally {
                        // Runs exactly once, after a successful or a failed export step.
                        this.exportTaskFactory.cleanupResultStore(resultStore, jobId);
                    }
                    awaitFlusher(flusher, progressTracker, exportDriver);
                } catch (Throwable failure) {
                    // Failure contract of this method: flag the calling thread, stop the flusher
                    // and surface the problem as a graph-export-interrupted error.
                    // NOTE(review): a failure of the export step itself does not end the progress
                    // sub-task; only flush-phase failures do - confirm this is intended.
                    Thread.currentThread().interrupt();
                    flusher.interrupt();
                    throw GdsServerExceptions.graphExportInterrupted(failure);
                }
            }
        } catch (Throwable failure) {
            // Close the driver on the failure path without masking the primary failure.
            try {
                exportDriver.close();
            } catch (Throwable closeFailure) {
                failure.addSuppressed(closeFailure);
            }
            throw failure;
        }
        exportDriver.close();
        listener.completed();
    }

    /**
     * Tells the flusher that no more buffers will be produced, waits for it to finish and ends the
     * progress sub-task exactly once: successfully when the flusher reported no error and the
     * driver has no buffers left, with failure otherwise.
     */
    private static void awaitFlusher(QueueFlusher<?> flusher, ProgressTracker progressTracker, ExportDriver exportDriver) throws Exception {
        try {
            flusher.signalExportFinished();
            flusher.join();
            Throwable flushError = flusher.getError();
            if (flushError != null) {
                throw GdsServerExceptions.graphExportInterrupted(flushError);
            }
            // Checked before the sub-task is ended so leftover buffers are reported as a failure.
            if (!exportDriver.isEmpty()) {
                throw new IllegalStateException("The flusher finished without an exception but there are still buffers left to flush");
            }
            progressTracker.endSubTask();
        } catch (Exception e) {
            progressTracker.endSubTaskWithFailure();
            throw e;
        }
    }

    /**
     * Builds the flusher. Kept as a generic helper so the buffer type of the given buffer manager
     * is bound once and shared with the returned flusher.
     */
    private static <BUFFER extends ArrowVectorBuffer<?>> QueueFlusher<BUFFER> createQueueFlusher(VectorLoader vectorLoader, OutboundStreamListener outboundStreamListener, ProgressTracker progressTracker, ExportBufferManager<BUFFER> exportBufferManager, BufferAllocator bufferAllocator) {
        return new QueueFlusher<>(vectorLoader, outboundStreamListener, progressTracker, exportBufferManager, bufferAllocator);
    }
}
