package apoc.export.arrow;

import apoc.Extended;
import apoc.Pools;
import apoc.diff.DiffFull;
import apoc.export.util.BatchTransaction;
import apoc.export.util.ExportConfig;
import apoc.export.util.ProgressReporter;
import apoc.result.ProgressInfo;
import apoc.util.FileUtils;
import apoc.util.JsonUtil;
import apoc.util.Util;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.json.JsonWriteFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.net.HttpHeaders;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.time.OffsetTime;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowFileReader;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.commons.configuration2.tree.DefaultExpressionEngineSymbols;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.math3.analysis.interpolation.MicrosphereInterpolator;
import org.apache.commons.text.lookup.StringLookupFactory;
import org.neo4j.graphdb.Entity;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.Label;
import org.neo4j.graphdb.Node;
import org.neo4j.graphdb.RelationshipType;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;
import org.neo4j.values.storable.DateTimeValue;
import org.neo4j.values.storable.DateValue;
import org.neo4j.values.storable.DurationValue;
import org.neo4j.values.storable.LocalDateTimeValue;
import org.neo4j.values.storable.LocalTimeValue;
import org.neo4j.values.storable.PointValue;
import org.neo4j.values.storable.TimeValue;
import org.neo4j.values.storable.Values;

@Extended
/* loaded from: input_file:apoc/export/arrow/ImportArrow.class */
public class ImportArrow {

    @Context
    public Pools pools;

    @Context
    public GraphDatabaseService db;

    /* loaded from: input_file:apoc/export/arrow/ImportArrow$ArrowConfig.class */
    public static class ArrowConfig {
        private final int batchSize;
        private final Map<String, Object> mapping;

        public ArrowConfig(Map<String, Object> map) {
            map = map == null ? Collections.emptyMap() : map;
            this.mapping = (Map) map.getOrDefault("mapping", Map.of());
            this.batchSize = Util.toInteger(map.getOrDefault("batchSize", Integer.valueOf(MicrosphereInterpolator.DEFAULT_MICROSPHERE_ELEMENTS))).intValue();
        }

        public int getBatchSize() {
            return this.batchSize;
        }

        public Map<String, Object> getMapping() {
            return this.mapping;
        }
    }

    @Procedure(name = "apoc.import.arrow", mode = Mode.WRITE)
    @Description("Imports arrow from the provided arrow file or byte array")
    public Stream<ProgressInfo> importFile(@Name("input") Object obj, @Name(value = "config", defaultValue = "{}") Map<String, Object> map) throws Exception {
        return Stream.of((ProgressInfo) Util.inThread(this.pools, () -> {
            String str = null;
            String str2 = "binary";
            if (obj instanceof String) {
                str = (String) obj;
                str2 = StringLookupFactory.KEY_FILE;
            }
            ArrowConfig arrowConfig = new ArrowConfig(map);
            HashMap hashMap = new HashMap();
            AtomicInteger atomicInteger = new AtomicInteger();
            ArrowReader reader = getReader(obj);
            try {
                VectorSchemaRoot vectorSchemaRoot = reader.getVectorSchemaRoot();
                try {
                    ProgressReporter progressReporter = new ProgressReporter(null, null, new ProgressInfo(str, str2, "arrow"));
                    BatchTransaction batchTransaction = new BatchTransaction(this.db, arrowConfig.getBatchSize(), progressReporter);
                    while (hasElements(atomicInteger, reader, vectorSchemaRoot)) {
                        try {
                            try {
                                Map<String, Object> map2 = (Map) vectorSchemaRoot.getFieldVectors().stream().collect(HashMap::new, (hashMap2, fieldVector) -> {
                                    Object read = read(fieldVector, atomicInteger.get(), arrowConfig);
                                    if (read == null) {
                                        return;
                                    }
                                    hashMap2.put(fieldVector.getName(), read);
                                }, (v0, v1) -> {
                                    v0.putAll(v1);
                                });
                                String str3 = (String) map2.remove(ArrowUtils.FIELD_TYPE.getName());
                                if (str3 == null) {
                                    Node createNode = batchTransaction.getTransaction().createNode((Label[]) Optional.ofNullable((String[]) map2.remove(ArrowUtils.FIELD_LABELS.getName())).map(strArr -> {
                                        return (Label[]) Arrays.stream(strArr).map(Label::label).toArray(i -> {
                                            return new Label[i];
                                        });
                                    }).orElse(new Label[0]));
                                    hashMap.put(Long.valueOf(((Long) map2.remove(ArrowUtils.FIELD_ID.getName())).longValue()), Long.valueOf(createNode.getId()));
                                    addProps(map2, createNode);
                                    progressReporter.update(1L, 0L, map2.size());
                                } else {
                                    addProps(map2, batchTransaction.getTransaction().getNodeById(((Long) hashMap.get(Long.valueOf(((Long) map2.remove(ArrowUtils.FIELD_SOURCE_ID.getName())).longValue()))).longValue()).createRelationshipTo(batchTransaction.getTransaction().getNodeById(((Long) hashMap.get(Long.valueOf(((Long) map2.remove(ArrowUtils.FIELD_TARGET_ID.getName())).longValue()))).longValue()), RelationshipType.withName(str3)));
                                    progressReporter.update(0L, 1L, map2.size());
                                }
                                atomicInteger.incrementAndGet();
                                batchTransaction.increment();
                            } catch (RuntimeException e) {
                                batchTransaction.rollback();
                                throw e;
                            }
                        } catch (Throwable th) {
                            batchTransaction.close();
                            throw th;
                        }
                    }
                    batchTransaction.commit();
                    batchTransaction.close();
                    ProgressInfo total = progressReporter.getTotal();
                    if (vectorSchemaRoot != null) {
                        vectorSchemaRoot.close();
                    }
                    if (reader != null) {
                        reader.close();
                    }
                    return total;
                } finally {
                }
            } catch (Throwable th2) {
                if (reader != null) {
                    try {
                        reader.close();
                    } catch (Throwable th3) {
                        th2.addSuppressed(th3);
                    }
                }
                throw th2;
            }
        }));
    }

    private ArrowReader getReader(Object obj) throws IOException {
        RootAllocator rootAllocator = new RootAllocator();
        return obj instanceof String ? new ArrowFileReader(FileUtils.inputStreamFor(obj, null, null, null).asChannel(), rootAllocator) : new ArrowStreamReader(new ByteArrayInputStream((byte[]) obj), rootAllocator);
    }

    private static boolean hasElements(AtomicInteger atomicInteger, ArrowReader arrowReader, VectorSchemaRoot vectorSchemaRoot) throws IOException {
        if (atomicInteger.get() < vectorSchemaRoot.getRowCount()) {
            return true;
        }
        if (!arrowReader.loadNextBatch()) {
            return false;
        }
        atomicInteger.set(0);
        return true;
    }

    private static Object read(FieldVector fieldVector, int i, ArrowConfig arrowConfig) {
        if (fieldVector.isNull(i)) {
            return null;
        }
        if (fieldVector instanceof BitVector) {
            return Boolean.valueOf(((BitVector) fieldVector).get(i) == 1);
        }
        Object object = fieldVector.getObject(i);
        if ((object instanceof Collection) && ((Collection) object).isEmpty()) {
            return null;
        }
        return toValidValue(object, fieldVector.getName(), arrowConfig.getMapping());
    }

    private void addProps(Map<String, Object> map, Entity entity) {
        Objects.requireNonNull(entity);
        map.forEach(entity::setProperty);
    }

    public static Object toValidValue(Object obj, String str, Map<String, Object> map) {
        Object obj2 = map.get(str);
        if (obj != null && obj2 != null) {
            return convertValue(obj.toString(), obj2.toString());
        }
        if (obj instanceof Collection) {
            return ((List) ((Collection) obj).stream().map(obj3 -> {
                return toValidValue(obj3, str, map);
            }).collect(Collectors.toList())).toArray(new String[0]);
        }
        if (obj instanceof Map) {
            return ((Map) obj).entrySet().stream().collect(Collectors.toMap((v0) -> {
                return v0.getKey();
            }, entry -> {
                return toValidValue(entry.getValue(), str, map);
            }));
        }
        try {
            Values.of(obj);
            return obj;
        } catch (Exception e) {
            return obj.toString();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static Object convertValue(String str, String str2) {
        boolean z = -1;
        switch (str2.hashCode()) {
            case -1927368268:
                if (str2.equals("Duration")) {
                    z = 6;
                    break;
                }
                break;
            case -1600196269:
                if (str2.equals("NO_VALUE")) {
                    z = 16;
                    break;
                }
                break;
            case -97531304:
                if (str2.equals(DiffFull.RELATIONSHIP)) {
                    z = 15;
                    break;
                }
                break;
            case 73679:
                if (str2.equals("Int")) {
                    z = 12;
                    break;
                }
                break;
            case 2086184:
                if (str2.equals("Byte")) {
                    z = 8;
                    break;
                }
                break;
            case 2099062:
                if (str2.equals("Char")) {
                    z = 7;
                    break;
                }
                break;
            case 2122702:
                if (str2.equals(HttpHeaders.DATE)) {
                    z = 5;
                    break;
                }
                break;
            case 2374300:
                if (str2.equals("Long")) {
                    z = 13;
                    break;
                }
                break;
            case 2433570:
                if (str2.equals(DiffFull.NODE)) {
                    z = 14;
                    break;
                }
                break;
            case 2606829:
                if (str2.equals("Time")) {
                    z = 4;
                    break;
                }
                break;
            case 67973692:
                if (str2.equals("Float")) {
                    z = 10;
                    break;
                }
                break;
            case 77292912:
                if (str2.equals("Point")) {
                    z = false;
                    break;
                }
                break;
            case 79860828:
                if (str2.equals("Short")) {
                    z = 11;
                    break;
                }
                break;
            case 798759096:
                if (str2.equals("LocalTime")) {
                    z = 2;
                    break;
                }
                break;
            case 1153828870:
                if (str2.equals("LocalDateTime")) {
                    z = true;
                    break;
                }
                break;
            case 1857393595:
                if (str2.equals("DateTime")) {
                    z = 3;
                    break;
                }
                break;
            case 2052876273:
                if (str2.equals("Double")) {
                    z = 9;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                return getPointValue(str);
            case true:
                return LocalDateTimeValue.parse(str).asObjectCopy();
            case true:
                return LocalTimeValue.parse(str).asObjectCopy();
            case true:
                return DateTimeValue.parse(str, () -> {
                    return ZoneId.of("Z");
                }).asObjectCopy();
            case true:
                return TimeValue.parse(str, () -> {
                    return ZoneId.of("Z");
                }).asObjectCopy();
            case true:
                return DateValue.parse(str).asObjectCopy();
            case true:
                return DurationValue.parse(str);
            case true:
                return Character.valueOf(str.charAt(0));
            case true:
                return str.getBytes();
            case true:
                return Double.valueOf(Double.parseDouble(str));
            case true:
                return Float.valueOf(Float.parseFloat(str));
            case true:
                return Short.valueOf(Short.parseShort(str));
            case true:
                return Integer.valueOf(Integer.parseInt(str));
            case true:
                return Long.valueOf(Long.parseLong(str));
            case true:
            case true:
                return JsonUtil.parse(str, null, Map.class);
            case true:
                return null;
            default:
                if (!str2.endsWith("Array")) {
                    return str;
                }
                String removeEnd = StringUtils.removeEnd(StringUtils.removeStart(str, "["), DefaultExpressionEngineSymbols.DEFAULT_ATTRIBUTE_END);
                String replace = str2.replace("Array", "");
                return ((List) Arrays.stream(removeEnd.split(ExportConfig.DEFAULT_DELIM)).map(str3 -> {
                    return convertValue(StringUtils.trim(str3), replace);
                }).collect(Collectors.toList())).toArray(getPrototypeFor(replace));
        }
    }

    private static PointValue getPointValue(String str) {
        try {
            return PointValue.parse(str);
        } catch (RuntimeException e) {
            ObjectMapper disable = new ObjectMapper().disable(new JsonGenerator.Feature[]{JsonWriteFeature.QUOTE_FIELD_NAMES.mappedFeature()});
            try {
                return PointValue.parse(disable.writeValueAsString((Map) disable.readValue(str, Map.class)));
            } catch (JsonProcessingException e2) {
                throw new RuntimeException((Throwable) e2);
            }
        }
    }

    public static Object[] getPrototypeFor(String str) {
        boolean z = -1;
        switch (str.hashCode()) {
            case -1927368268:
                if (str.equals("Duration")) {
                    z = 15;
                    break;
                }
                break;
            case -1808118735:
                if (str.equals("String")) {
                    z = 8;
                    break;
                }
                break;
            case -672261858:
                if (str.equals("Integer")) {
                    z = true;
                    break;
                }
                break;
            case 2086184:
                if (str.equals("Byte")) {
                    z = 5;
                    break;
                }
                break;
            case 2099062:
                if (str.equals("Char")) {
                    z = 7;
                    break;
                }
                break;
            case 2122702:
                if (str.equals(HttpHeaders.DATE)) {
                    z = 14;
                    break;
                }
                break;
            case 2374300:
                if (str.equals("Long")) {
                    z = false;
                    break;
                }
                break;
            case 2606829:
                if (str.equals("Time")) {
                    z = 13;
                    break;
                }
                break;
            case 67973692:
                if (str.equals("Float")) {
                    z = 3;
                    break;
                }
                break;
            case 77292912:
                if (str.equals("Point")) {
                    z = 12;
                    break;
                }
                break;
            case 79860828:
                if (str.equals("Short")) {
                    z = 6;
                    break;
                }
                break;
            case 798759096:
                if (str.equals("LocalTime")) {
                    z = 10;
                    break;
                }
                break;
            case 1153828870:
                if (str.equals("LocalDateTime")) {
                    z = 11;
                    break;
                }
                break;
            case 1729365000:
                if (str.equals("Boolean")) {
                    z = 4;
                    break;
                }
                break;
            case 1857393595:
                if (str.equals("DateTime")) {
                    z = 9;
                    break;
                }
                break;
            case 2052876273:
                if (str.equals("Double")) {
                    z = 2;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                return new Long[0];
            case true:
                return new Integer[0];
            case true:
                return new Double[0];
            case true:
                return new Float[0];
            case true:
                return new Boolean[0];
            case true:
                return new Byte[0];
            case true:
                return new Short[0];
            case true:
                return new Character[0];
            case true:
                return new String[0];
            case true:
                return new ZonedDateTime[0];
            case true:
                return new LocalTime[0];
            case true:
                return new LocalDateTime[0];
            case true:
                return new PointValue[0];
            case true:
                return new OffsetTime[0];
            case true:
                return new LocalDate[0];
            case true:
                return new DurationValue[0];
            default:
                throw new IllegalStateException("Type " + str + " not supported.");
        }
    }
}
