package com.neo4j.gds.arrow.server;

import com.neo4j.gds.arrow.core.exceptions.FlightServerException;
import com.neo4j.gds.arrow.core.memory.MemoryManager;
import com.neo4j.gds.arrow.core.metrics.ConnectionInfoMiddleware;
import com.neo4j.gds.arrow.core.metrics.FlightMetrics;
import com.neo4j.gds.arrow.core.monitoring.AutoAborter;
import com.neo4j.gds.arrow.core.process.ProcessRegistry;
import com.neo4j.gds.arrow.server.GdsFlightServerConfig;
import com.neo4j.gds.arrow.server.auth.AuthUtil;
import com.neo4j.gds.arrow.server.auth.AuthenticationStrategy;
import com.neo4j.gds.arrow.server.auth.BearerTokenGenerator;
import com.neo4j.gds.arrow.server.auth.TlsKeyAndCertificate;
import com.neo4j.gds.arrow.server.auth.UUIDBearerTokenGenerator;
import com.neo4j.gds.arrow.server.auth.UserPrivileges;
import com.neo4j.gds.arrow.server.handlers.HandlersRegistry;
import com.neo4j.gds.arrow.server.handlers.v1.V1HandlersBuilder;
import com.neo4j.gds.shaded.com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.neo4j.gds.shaded.org.apache.arrow.flight.FlightProducer;
import com.neo4j.gds.shaded.org.apache.arrow.flight.FlightServer;
import com.neo4j.gds.shaded.org.apache.arrow.flight.Location;
import com.neo4j.gds.shaded.org.apache.arrow.flight.auth2.CallHeaderAuthenticator;
import com.neo4j.gds.shaded.org.apache.arrow.memory.RootAllocator;
import com.neo4j.gds.shaded.org.jetbrains.annotations.Nullable;
import com.neo4j.gds.shaded.org.jetbrains.annotations.TestOnly;
import java.io.IOException;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.IntSupplier;
import org.neo4j.dbms.api.DatabaseManagementService;
import org.neo4j.gds.core.utils.ClockService;
import org.neo4j.gds.core.utils.progress.TaskStoreService;
import org.neo4j.gds.logging.Log;
import org.neo4j.gds.logging.LogAdapter;
import org.neo4j.gds.utils.ExceptionUtil;

/* loaded from: input_file:com/neo4j/gds/arrow/server/GdsFlightServer.class */
/**
 * Owns the lifecycle of the GDS Arrow Flight server.
 *
 * <p>At construction time the shared allocator, executor, connection-metrics middleware,
 * authentication, TLS and listen location are wired into a {@link FlightServer.Builder}. The actual
 * {@link FlightServer} (with its producer) is only built in {@link #start()}. {@link #start()} and
 * {@link #stop()} are {@code synchronized}; {@link #isRunning()} reads a {@code volatile} flag.
 */
public final class GdsFlightServer {
    private final UserPrivileges userPrivileges;
    private final Log log;
    private final DatabaseManagementService dbms;
    private final FlightMetrics flightMetrics;
    private final RootAllocator allocator;
    private final AutoAborter autoAborter;
    private final FlightServer.Builder flightServerBuilder;
    private final IntSupplier batchSizeSupplier;
    private final ProcessRegistry processRegistry = new ProcessRegistry();
    private final ExecutorService executorService = Executors.newCachedThreadPool(
        new ThreadFactoryBuilder().setNameFormat("gds-arrow-server-executor-%d").build());
    private final TaskStoreService taskStoreService;
    private final MemoryManager memoryManager;
    // volatile: assigned inside the synchronized start(), but read by the unsynchronized
    // location() and awaitTermination().
    private volatile FlightServer flightServer;
    private volatile boolean runFlag;

    /**
     * Creates a server whose dependencies are resolved from the given DBMS, using UUID bearer tokens.
     *
     * @param databaseManagementService the DBMS to resolve dependencies from
     * @return a new, not yet started server
     */
    public static GdsFlightServer fromDbms(DatabaseManagementService databaseManagementService) {
        return create(
            FlightServerDependencyProvider.fromDbms(databaseManagementService),
            UUIDBearerTokenGenerator.INSTANCE
        );
    }

    /**
     * Creates a server from an explicit dependency provider and bearer-token generator.
     *
     * @param flightServerDependencyProvider source of config, auth, TLS material, logging, metrics and memory
     * @param bearerTokenGenerator generator for bearer tokens handed out after credential authentication
     * @return a new, not yet started server
     */
    public static GdsFlightServer create(
        FlightServerDependencyProvider flightServerDependencyProvider,
        BearerTokenGenerator bearerTokenGenerator
    ) {
        org.neo4j.logging.Log log = flightServerDependencyProvider.log(GdsFlightServer.class);
        UserPrivileges userPrivileges = flightServerDependencyProvider.userPrivileges();
        CallHeaderAuthenticator callHeaderAuthenticator = AuthUtil.credentialCallHeaderAuthenticator(
            flightServerDependencyProvider.authenticator(userPrivileges, log),
            bearerTokenGenerator,
            userPrivileges,
            log
        );
        return new GdsFlightServer(
            flightServerDependencyProvider.flightServerConfig(),
            flightServerDependencyProvider.authenticationStrategy(),
            callHeaderAuthenticator,
            userPrivileges,
            flightServerDependencyProvider.tlsKeyAndCertificate(log),
            new LogAdapter(log),
            flightServerDependencyProvider.dbms(),
            flightServerDependencyProvider.flightMetrics(),
            flightServerDependencyProvider.memoryManager()
        );
    }

    private GdsFlightServer(
        GdsFlightServerConfig gdsFlightServerConfig,
        AuthenticationStrategy authenticationStrategy,
        CallHeaderAuthenticator callHeaderAuthenticator,
        UserPrivileges userPrivileges,
        @Nullable TlsKeyAndCertificate tlsKeyAndCertificate,
        Log log,
        DatabaseManagementService databaseManagementService,
        FlightMetrics flightMetrics,
        MemoryManager memoryManager
    ) {
        this.userPrivileges = userPrivileges;
        this.log = log;
        this.dbms = databaseManagementService;
        this.flightMetrics = flightMetrics;
        this.memoryManager = memoryManager;
        this.allocator = memoryManager.allocator(log, gdsFlightServerConfig.logAllocations());
        this.autoAborter = new AutoAborter(
            gdsFlightServerConfig.processTimeoutDuration(),
            this.processRegistry,
            log,
            ClockService.clock()
        );
        this.flightServerBuilder = FlightServer.builder()
            .allocator(this.allocator)
            .executor(this.executorService)
            .middleware(ConnectionInfoMiddleware.KEY, new ConnectionInfoMiddleware.Factory(flightMetrics));
        // Authentication/TLS must be configured first: whether TLS got enabled decides the location scheme.
        boolean tlsEnabled = withConfiguredAuthentication(
            authenticationStrategy,
            callHeaderAuthenticator,
            Optional.ofNullable(tlsKeyAndCertificate),
            this.flightServerBuilder,
            log
        );
        setLocation(this.flightServerBuilder, tlsEnabled, gdsFlightServerConfig);
        this.batchSizeSupplier = gdsFlightServerConfig.batchSizeSupplier();
        this.runFlag = false;
        this.taskStoreService = new TaskStoreService(true);
    }

    /**
     * Applies header authentication and, if requested and possible, TLS to the builder.
     *
     * @return {@code true} iff TLS was configured on the builder
     */
    private static boolean withConfiguredAuthentication(
        AuthenticationStrategy authenticationStrategy,
        CallHeaderAuthenticator callHeaderAuthenticator,
        Optional<TlsKeyAndCertificate> maybeTls,
        FlightServer.Builder builder,
        Log log
    ) {
        if (authenticationStrategy.authenticated()) {
            builder.headerAuthenticator(callHeaderAuthenticator);
        }
        if (!authenticationStrategy.encrypted()) {
            log.info("GDS Flight server encryption is explicitly disabled.");
            return false;
        }
        // Encryption was requested, but it can only be enabled when key/certificate material exists.
        if (!maybeTls.isPresent()) {
            log.info("GDS Flight server encryption is not enabled.");
            return false;
        }
        log.info("GDS Flight server encryption is enabled based on %s.", maybeTls.get().name());
        setupTls(maybeTls, builder);
        return true;
    }

    /** Configures TLS on the builder from the given key/certificate, if present. */
    private static void setupTls(Optional<TlsKeyAndCertificate> maybeTls, FlightServer.Builder builder) {
        // ExceptionUtil.consumer adapts the (checked-exception-throwing) useTls call to a Consumer.
        maybeTls.ifPresent(ExceptionUtil.consumer(tlsKeyAndCertificate ->
            builder.useTls(tlsKeyAndCertificate.publicCertificate(), tlsKeyAndCertificate.privateKey())));
    }

    /** Sets the gRPC listen location from config, using the TLS scheme iff {@code useTls}. */
    private static void setLocation(FlightServer.Builder builder, boolean useTls, GdsFlightServerConfig gdsFlightServerConfig) {
        GdsFlightServerConfig.SocketAddress socketAddress = gdsFlightServerConfig.socketAddress();
        Location location = useTls
            ? Location.forGrpcTls(socketAddress.hostname(), socketAddress.port())
            : Location.forGrpcInsecure(socketAddress.hostname(), socketAddress.port());
        builder.location(location);
    }

    /**
     * Builds and starts the Flight server and the process auto-aborter.
     *
     * @throws FlightServerException if the server is already running
     * @throws IOException if the underlying Flight server fails to start
     */
    public synchronized void start() throws IOException, FlightServerException {
        if (isRunning()) {
            throw new FlightServerException("Flight server is already running");
        }
        this.flightServer = finalizeFlightServer();
        this.flightServer.start();
        this.autoAborter.start();
        this.runFlag = true;
        this.log.info("GDS Flight server running at %s", this.flightServer.getLocation().getUri());
    }

    /**
     * Stops the server. It logs the allocation summary, closes the memory manager, aborts all
     * registered processes, closes the auto-aborter and the Flight server, and shuts down the executor.
     *
     * <p>The allocator is closed in every case, including when this method throws because the
     * server is not running.
     *
     * @throws FlightServerException if the server is not running
     * @throws RuntimeException wrapping {@link InterruptedException} if interrupted while stopping
     *     (the thread's interrupt flag is restored)
     */
    public synchronized void stop() throws FlightServerException {
        try {
            if (!isRunning()) {
                throw new FlightServerException("Flight server is not running");
            }
            // Clear the flag before teardown, so a failure in any step below cannot leave isRunning()
            // reporting true for a server whose allocator is closed in the finally block.
            this.runFlag = false;
            this.log.info("Stopping GDS Flight Server...");
            // Pass the summary as an argument, never as the format string: it is arbitrary text and
            // could contain '%'.
            this.log.info("%s", this.memoryManager.allocationSummary());
            this.memoryManager.close();
            this.processRegistry.forEach(arrowProcess ->
                arrowProcess.abort(new IllegalStateException("Aborting process due to server shutdown")));
            this.autoAborter.close();
            this.flightServer.close();
            shutdownExecutorService();
            this.log.info("GDS Flight Server stopped.");
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Interrupted while stopping GDS Flight server", e);
        } finally {
            this.allocator.close();
        }
    }

    /** Initiates executor shutdown; forces it with {@code shutdownNow} if tasks linger past 10 seconds. */
    private void shutdownExecutorService() throws InterruptedException {
        this.executorService.shutdown();
        if (!this.executorService.awaitTermination(10L, TimeUnit.SECONDS)) {
            this.executorService.shutdownNow();
        }
    }

    /**
     * Returns the location the server was built with.
     *
     * @throws IllegalStateException if {@link #start()} has never been called
     */
    public Location location() {
        return requireFlightServer().getLocation();
    }

    /**
     * Blocks until the underlying Flight server terminates.
     *
     * @throws IllegalStateException if {@link #start()} has never been called
     */
    public void awaitTermination() throws InterruptedException {
        requireFlightServer().awaitTermination();
    }

    /** Returns whether {@link #start()} completed and {@link #stop()} has not yet begun. */
    public boolean isRunning() {
        return this.runFlag;
    }

    @TestOnly
    public ProcessRegistry processRegistry() {
        return this.processRegistry;
    }

    /** Returns the built Flight server, failing descriptively instead of with an NPE before start(). */
    private FlightServer requireFlightServer() {
        FlightServer server = this.flightServer;
        if (server == null) {
            throw new IllegalStateException("GDS Flight server has not been started");
        }
        return server;
    }

    /** Attaches a freshly initialized producer to the pre-configured builder and builds the server. */
    private FlightServer finalizeFlightServer() {
        return this.flightServerBuilder.producer(initializeFlightProducer()).build();
    }

    /** Creates the Flight producer backed by a registry containing the v1 handlers. */
    private FlightProducer initializeFlightProducer() {
        HandlersRegistry handlersRegistry = new HandlersRegistry();
        handlersRegistry.register(new V1HandlersBuilder()
            .userPrivileges(this.userPrivileges)
            .batchSizeSupplier(this.batchSizeSupplier)
            .dbms(this.dbms)
            .allocator(this.allocator)
            .processRegistry(this.processRegistry)
            .log(this.log)
            .taskStoreService(this.taskStoreService)
            .executorService(this.executorService)
            .metrics(this.flightMetrics)
            .build());
        return new GdsFlightProducer(handlersRegistry, this.log, this.flightMetrics);
    }
}
