package com.neo4j.gds.arrow.server.export;

import com.neo4j.gds.arrow.core.vectors.ArrowVectorBuffer;
import com.neo4j.gds.shaded.org.apache.arrow.flight.OutboundStreamListener;
import com.neo4j.gds.shaded.org.apache.arrow.memory.ArrowBuf;
import com.neo4j.gds.shaded.org.apache.arrow.memory.BufferAllocator;
import com.neo4j.gds.shaded.org.apache.arrow.vector.FieldVector;
import com.neo4j.gds.shaded.org.apache.arrow.vector.TypeLayout;
import com.neo4j.gds.shaded.org.apache.arrow.vector.ValueVector;
import com.neo4j.gds.shaded.org.apache.arrow.vector.VectorLoader;
import com.neo4j.gds.shaded.org.apache.arrow.vector.ipc.message.ArrowFieldNode;
import com.neo4j.gds.shaded.org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.utils.ExceptionUtil;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:com/neo4j/gds/arrow/server/export/QueueFlusher.class */
/**
 * Background thread that drains full {@link ArrowVectorBuffer}s from an {@link ExportBufferManager}
 * and sends their contents to an Arrow Flight {@link OutboundStreamListener}.
 *
 * <p>For every non-empty full buffer, the buffer's vectors are prepared and transferred using
 * {@code allocator}. The buffer is then handed back to the manager, or closed if the export has
 * already been signalled as finished. The transferred vectors are loaded through the
 * {@link VectorLoader} and sent with {@link OutboundStreamListener#putNext()}.
 *
 * <p>The thread terminates once {@link #signalExportFinished()} has been called and no full
 * buffers remain, when it is interrupted (e.g. by {@link #close()}), or when flushing fails. A
 * failure is recorded and can be retrieved through {@link #getError()}.
 */
class QueueFlusher<BUFFER extends ArrowVectorBuffer<?>> extends Thread implements AutoCloseable {
    private static final AtomicLong NEXT_THREAD_ID = new AtomicLong();

    private final VectorLoader loader;
    private final OutboundStreamListener listener;
    private final ProgressTracker progressTracker;
    private final ExportBufferManager<BUFFER> bufferManager;
    private final BufferAllocator allocator;
    // Cleared by signalExportFinished() or when the flush loop terminates abnormally.
    private final AtomicBoolean running;
    private final AtomicReference<Exception> error;
    // Most recently transferred vectors. Volatile because close() reads it from another thread.
    private volatile ValueVector[] transferVectors;

    QueueFlusher(
        VectorLoader vectorLoader,
        OutboundStreamListener outboundStreamListener,
        ProgressTracker progressTracker,
        ExportBufferManager<BUFFER> exportBufferManager,
        BufferAllocator bufferAllocator
    ) {
        super("QueueFlusher-" + NEXT_THREAD_ID.getAndIncrement());
        this.loader = vectorLoader;
        this.listener = outboundStreamListener;
        this.progressTracker = progressTracker;
        this.bufferManager = exportBufferManager;
        this.allocator = bufferAllocator;
        this.running = new AtomicBoolean(true);
        this.error = new AtomicReference<>();
    }

    /**
     * Runs the flush loop. An interruption stops the loop and restores the interrupt flag. Any
     * other exception stops the loop and is stored for {@link #getError()}.
     */
    @Override
    public void run() {
        try {
            flushAll();
        } catch (InterruptedException e) {
            this.running.set(false);
            // Restore the interrupt status so the owner of this thread can observe it.
            Thread.currentThread().interrupt();
        } catch (Exception e) {
            this.running.set(false);
            this.error.set(e);
        }
    }

    /**
     * Closes the most recently transferred vectors (if any) and interrupts this thread.
     * The interrupt is issued even if closing a vector fails.
     */
    @Override
    public void close() throws Exception {
        ValueVector[] vectors = this.transferVectors;
        try {
            if (vectors != null) {
                for (ValueVector vector : vectors) {
                    vector.close();
                }
            }
        } finally {
            interrupt();
        }
    }

    /**
     * Signals that no more buffers will be produced. The flusher still drains the full buffers
     * that are already queued, then terminates. Buffers it takes from then on are closed instead
     * of being returned to the free queue.
     */
    void signalExportFinished() {
        this.running.set(false);
    }

    /** Returns the exception that terminated the flush loop, or {@code null} if there was none. */
    public Throwable getError() {
        return this.error.get();
    }

    /**
     * Polls full buffers until the export is finished and the full-buffer queue is empty.
     *
     * @throws InterruptedException if the thread is interrupted while polling or waiting
     *                              for the listener to become ready
     */
    private void flushAll() throws Exception {
        while (this.running.get() || !this.bufferManager.fullBuffers().isEmpty()) {
            // Timed poll so the loop re-checks the termination condition at least once per second.
            BUFFER buffer = this.bufferManager.fullBuffers().poll(1L, TimeUnit.SECONDS);
            if (buffer == null) {
                continue;
            }

            int batchSize = buffer.batchPosition();
            if (batchSize <= 0) {
                addBufferToFreeQueue(buffer);
                continue;
            }

            ValueVector[] vectors = Arrays.stream(buffer.arrowVectors())
                .map(vector -> {
                    vector.prepareForFlush(batchSize);
                    return vector.transfer(this.allocator);
                })
                .toArray(ValueVector[]::new);
            // Publish the vectors before handing the buffer back so close() can always reach them.
            this.transferVectors = vectors;
            // The data has been transferred out, so the buffer can be reused while this batch is sent.
            addBufferToFreeQueue(buffer);
            flush(batchSize, this.loader, this.listener, vectors);
        }
    }

    /**
     * Returns the buffer to the manager's free queue while the export is running. Once the
     * export is finished, closes the buffer instead.
     */
    private void addBufferToFreeQueue(BUFFER buffer) {
        if (this.running.get()) {
            this.bufferManager.addFreeBuffer(buffer);
        } else {
            buffer.close();
        }
    }

    /**
     * Loads {@code vectors} as one record batch of {@code batchSize} rows and sends it to the
     * listener. The record batch is closed on every path, and so are {@code vectors}.
     *
     * @throws InterruptedException if interrupted while waiting for the listener to become ready
     */
    private void flush(
        int batchSize,
        VectorLoader vectorLoader,
        OutboundStreamListener outboundStreamListener,
        ValueVector... vectors
    ) throws InterruptedException {
        try {
            // try-with-resources: the batch must also be released when load/putNext throws.
            try (ArrowRecordBatch recordBatch = getRecordBatch(batchSize, Arrays.asList(vectors))) {
                vectorLoader.load(recordBatch);
                awaitReady(outboundStreamListener);
                outboundStreamListener.putNext();
                this.progressTracker.logProgress(batchSize);
            }
        } finally {
            // Single close point for the transferred vectors (success and failure alike).
            ExceptionUtil.closeAll(ExceptionUtil.RETHROW_UNCHECKED, vectors);
        }
    }

    /**
     * Busy-waits until the listener reports it is ready. The wait stops if this thread is
     * interrupted, so that {@link #close()} can terminate a flusher whose client stopped reading.
     */
    private static void awaitReady(OutboundStreamListener outboundStreamListener) throws InterruptedException {
        while (!outboundStreamListener.isReady()) {
            if (Thread.interrupted()) {
                throw new InterruptedException("Interrupted while waiting for the stream listener to become ready");
            }
            Thread.onSpinWait();
        }
    }

    /**
     * Builds a record batch of {@code rowCount} rows from the field nodes and buffers of
     * {@code vectors}, including their child vectors.
     */
    private ArrowRecordBatch getRecordBatch(int rowCount, Iterable<ValueVector> vectors) {
        List<ArrowFieldNode> nodes = new ArrayList<>();
        List<ArrowBuf> buffers = new ArrayList<>();
        for (ValueVector vector : vectors) {
            appendNodes((FieldVector) vector, nodes, buffers);
        }
        return new ArrowRecordBatch(rowCount, nodes, buffers);
    }

    /**
     * Appends the field node and buffers of {@code vector} to the given lists, then recurses into
     * its children.
     *
     * @throws IllegalArgumentException if the vector exposes a different number of buffers than
     *                                  its type layout expects
     */
    private void appendNodes(FieldVector vector, List<ArrowFieldNode> nodes, List<ArrowBuf> buffers) {
        nodes.add(new ArrowFieldNode(vector.getValueCount(), vector.getNullCount()));
        List<ArrowBuf> fieldBuffers = vector.getFieldBuffers();
        int expectedBufferCount = TypeLayout.getTypeBufferCount(vector.getField().getType());
        if (fieldBuffers.size() != expectedBufferCount) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                "wrong number of buffers for field %s in vector %s. found: %s",
                vector.getField(),
                vector.getClass().getSimpleName(),
                fieldBuffers
            ));
        }
        buffers.addAll(fieldBuffers);
        for (FieldVector child : vector.getChildrenFromFields()) {
            appendNodes(child, nodes, buffers);
        }
    }
}
