package ml.dmlc.xgboost4j.scala.spark;

import java.nio.file.Files;
import java.nio.file.attribute.FileAttribute;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.IRabitTracker;
import ml.dmlc.xgboost4j.java.Rabit;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.scala.Booster;
import ml.dmlc.xgboost4j.scala.DMatrix;
import ml.dmlc.xgboost4j.scala.EvalTrait;
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager;
import ml.dmlc.xgboost4j.scala.ExternalCheckpointParams;
import ml.dmlc.xgboost4j.scala.ObjectiveTrait;
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker;
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker$;
import ml.dmlc.xgboost4j.scala.spark.XGBoost;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileSystem;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkParallelismTracker;
import org.apache.spark.TaskContext$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.storage.StorageLevel$;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Serializable;
import scala.Some;
import scala.Tuple2;
import scala.collection.GenTraversableOnce;
import scala.collection.Iterable;
import scala.collection.Iterable$;
import scala.collection.Iterator;
import scala.collection.TraversableOnce;
import scala.collection.immutable.Map;
import scala.collection.immutable.Map$;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayBuilder;
import scala.collection.mutable.ArrayOps;
import scala.math.Ordering$Int$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;
import scala.util.Either;

/* compiled from: XGBoost.scala */
/* loaded from: input_file:ml/dmlc/xgboost4j/scala/spark/XGBoost$.class */
/**
 * Scala {@code object XGBoost} singleton (module class) holding the logic for distributed
 * XGBoost training on Spark.
 *
 * <p>What this class does:
 * <ul>
 *   <li>missing-value preprocessing of {@link LabeledPoint}s;</li>
 *   <li>per-partition booster training under Rabit, plus Rabit tracker startup;</li>
 *   <li>co-partitioning of the training RDD with named evaluation RDDs;</li>
 *   <li>regrouping and repartitioning of ranking data by group;</li>
 *   <li>external-checkpoint load and cleanup around the training job.</li>
 * </ul>
 *
 * <p>NOTE(review): this source is decompiler output of Scala bytecode (see the
 * "compiled from: XGBoost.scala" header), not hand-written Java. Several constructs are not
 * valid Java as written:
 * <ul>
 *   <li>lambda parameters are redeclared as locals in the two co-partitioning folds;</li>
 *   <li>a nested lambda in {@code aggByGroupInfo} reuses its enclosing lambda's parameter
 *       name;</li>
 *   <li>a bare lambda expression is used as a statement in the training thread;</li>
 *   <li>{@code catch (InterruptedException)} wraps {@code Thread.interrupt()}, which never
 *       throws it;</li>
 *   <li>the checked {@code InterruptedException} from {@code Thread.join()} is
 *       undeclared.</li>
 * </ul>
 * Spark serializes the closures passed to mapPartitions/zipPartitions, so a Java
 * recompilation would presumably also need serializable lambdas. Read this file as a
 * reference for the Scala source's behavior.
 */
public final class XGBoost$ implements Serializable {
    /** The singleton instance; assigned by the private constructor. */
    public static XGBoost$ MODULE$;
    /** Commons-logging logger named "XGBoostSpark". */
    private final Log logger;

    // Eagerly create the singleton; the constructor stores itself into MODULE$.
    static {
        new XGBoost$();
    }

    /** Returns the module logger. */
    private Log logger() {
        return this.logger;
    }

    /**
     * Guards against a non-zero missing value being combined with sparse-format points.
     *
     * <p>The check is skipped entirely when {@code f == 0.0f} or {@code z} is true. Otherwise
     * every point carrying a non-null {@code indices()} array causes a {@link RuntimeException}.
     * The check happens lazily, as that point is pulled from the returned iterator.
     *
     * <p>{@code NaN == 0.0f} is false, so a NaN missing value is also rejected for such points
     * unless {@code z} is set.
     *
     * <p>NOTE(review): the message refers to the 'allow_non_zero_missing parameter', while the
     * execution-params accessor used by callers is {@code allowNonZeroForMissing()} — confirm the
     * user-facing parameter name.
     *
     * @param iterator the points to check
     * @param f the configured missing value
     * @param z the allow-non-zero-for-missing override; when true no check is performed
     * @return {@code iterator} itself, or a mapped iterator that performs the check
     */
    private Iterator<LabeledPoint> verifyMissingSetting(Iterator<LabeledPoint> iterator, float f, boolean z) {
        return (f == 0.0f || z) ? iterator : iterator.map(labeledPoint -> {
            if (labeledPoint.indices() != null) {
                throw new RuntimeException(new StringBuilder(328).append("you can only specify missing value as 0.0 (the currently").append(new StringBuilder(71).append(" set value ").append(f).append(") when you have SparseVector or Empty vector as your feature").toString()).append(" format. If you didn't use Spark's VectorAssembler class to build your feature ").append("vector but instead did so in a way that preserves zeros in your feature vector ").append("you can avoid this check by using the 'allow_non_zero_missing parameter'").append(" (only use if you know what you are doing)").toString());
            }
            return labeledPoint;
        });
    }

    /**
     * Drops feature entries for which {@code function1} returns false, lazily per point.
     *
     * <p>For each point, {@code values()} are paired with their positions. Kept entries keep
     * their original feature index: {@code indices()[pos]} when the point has an indices array,
     * otherwise the position itself. The point is copied with the rebuilt indices and values and
     * all other fields unchanged.
     *
     * <p>Dense input (null {@code indices()}) therefore comes out with an explicit indices
     * array. {@code f} is not used in the body; the predicate carries the missing-value
     * semantics.
     *
     * @param iterator the points to filter
     * @param f the missing value (unused here)
     * @param function1 a predicate on a feature value; true means "keep"
     * @return an iterator over the filtered copies
     */
    private Iterator<LabeledPoint> removeMissingValues(Iterator<LabeledPoint> iterator, float f, Function1<Object, Object> function1) {
        return iterator.map(labeledPoint -> {
            ArrayBuilder.ofInt ofint = new ArrayBuilder.ofInt();
            ArrayBuilder.ofFloat offloat = new ArrayBuilder.ofFloat();
            new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofFloat(Predef$.MODULE$.floatArrayOps(labeledPoint.values())).zipWithIndex(Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class))))).withFilter(tuple2 -> {
                return BoxesRunTime.boxToBoolean($anonfun$removeMissingValues$2(tuple2));
            }).withFilter(tuple22 -> {
                return BoxesRunTime.boxToBoolean($anonfun$removeMissingValues$3(function1, tuple22));
            }).foreach(tuple23 -> {
                if (tuple23 == null) {
                    throw new MatchError(tuple23);
                }
                float unboxToFloat = BoxesRunTime.unboxToFloat(tuple23._1());
                int _2$mcI$sp = tuple23._2$mcI$sp();
                // Sparse points keep their original feature index; dense points use the position.
                ofint.$plus$eq(labeledPoint.indices() == null ? _2$mcI$sp : labeledPoint.indices()[_2$mcI$sp]);
                return offloat.$plus$eq(unboxToFloat);
            });
            // copy(...) with replaced 2nd (indices) and 3rd (values) arguments; the rest keep their defaults.
            return labeledPoint.copy(labeledPoint.copy$default$1(), ofint.result(), offloat.result(), labeledPoint.copy$default$4(), labeledPoint.copy$default$5(), labeledPoint.copy$default$6());
        });
    }

    /**
     * Verifies the missing-value setting and then strips missing entries from each point.
     *
     * <p>For a non-NaN {@code f}, entries whose value equals {@code f} are removed. For a NaN
     * {@code f}, NaN entries are removed. A separate isNaN test is needed there because
     * {@code NaN != NaN} always holds.
     *
     * @param iterator the points to process
     * @param f the configured missing value
     * @param z the allow-non-zero-for-missing override passed to verifyMissingSetting
     * @return a lazily processed iterator
     */
    public Iterator<LabeledPoint> processMissingValues(Iterator<LabeledPoint> iterator, float f, boolean z) {
        return !Predef$.MODULE$.float2Float(f).isNaN() ? removeMissingValues(verifyMissingSetting(iterator, f, z), f, f2 -> {
            return f2 != f;
        }) : removeMissingValues(verifyMissingSetting(iterator, f, z), f, f3 -> {
            return !Predef$.MODULE$.float2Float(f3).isNaN();
        });
    }

    /**
     * Group-wise variant of processMissingValues for ranking data.
     *
     * <p>For a non-NaN {@code f}, each group array is run through processMissingValues. For a NaN
     * {@code f}, the groups are returned untouched.
     *
     * <p>NOTE(review): with a NaN missing value this path neither calls verifyMissingSetting nor
     * strips NaN entries, unlike processMissingValues. Confirm this asymmetry is intended.
     *
     * @param iterator the groups of points
     * @param f the configured missing value
     * @param z the allow-non-zero-for-missing override
     * @return the processed groups, or {@code iterator} unchanged when {@code f} is NaN
     */
    private Iterator<LabeledPoint[]> processMissingValuesWithGroup(Iterator<LabeledPoint[]> iterator, float f, boolean z) {
        return !Predef$.MODULE$.float2Float(f).isNaN() ? iterator.map(labeledPointArr -> {
            return (LabeledPoint[]) MODULE$.processMissingValues(new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(labeledPointArr)).iterator(), f, z).toArray(ClassTag$.MODULE$.apply(LabeledPoint.class));
        }) : iterator;
    }

    /**
     * Optionally creates a per-task temporary cache directory for external-memory training.
     *
     * <p>When {@code z} is true, a temp directory is created in the default temp location. Its
     * name starts with {@code "<stageId>-cache-<partitionId>"}, and its absolute path is
     * returned in a {@code Some}. Calls {@code TaskContext.get()}, so this presumably must run
     * inside a Spark task.
     *
     * @param z whether external memory is in use
     * @return {@code Some(path)} when {@code z} is true, otherwise {@code None}
     */
    private Option<String> getCacheDirName(boolean z) {
        return z ? new Some(Files.createTempDirectory(new StringBuilder(7).append(TaskContext$.MODULE$.get().stageId()).append("-cache-").append(BoxesRunTime.boxToInteger(TaskContext$.MODULE$.getPartitionId()).toString()).toString(), new FileAttribute[0]).toAbsolutePath().toString()) : None$.MODULE$;
    }

    /**
     * Trains one Rabit worker's booster on the watches of the current partition.
     *
     * <p>Runs on an executor inside a task:
     * <ul>
     *   <li>rejects an empty "train" matrix;</li>
     *   <li>adds DMLC_TASK_ID (partition id), DMLC_NUM_ATTEMPT (task attempt number) and
     *       DMLC_WORKER_STOP_PROCESS_ON_ERROR="false" to {@code map};</li>
     *   <li>initializes Rabit and trains for {@code numRounds}. On partition 0 it uses
     *       trainAndSaveCheckpoint when a checkpoint is configured, otherwise train.</li>
     * </ul>
     * It always shuts Rabit down and deletes the watches afterwards.
     *
     * @param watches the named DMatrix set for this partition; must contain "train"
     * @param xGBoostExecutionParams the resolved training parameters
     * @param map the Rabit worker environment; mutated by this method
     * @param objectiveTrait the custom objective, possibly null
     * @param evalTrait the custom evaluation, possibly null
     * @param booster the booster to continue from (e.g. a loaded checkpoint), possibly null
     * @return a single-element iterator of (booster, metrics). Metrics are keyed by watch name,
     *     with one float[numRounds] per watch.
     * @throws XGBoostError if the training partition is empty or training fails (logged first)
     */
    /* JADX INFO: Access modifiers changed from: private */
    public Iterator<Tuple2<Booster, Map<String, float[]>>> buildDistributedBooster(Watches watches, XGBoostExecutionParams xGBoostExecutionParams, java.util.Map<String, String> map, ObjectiveTrait objectiveTrait, EvalTrait evalTrait, Booster booster) {
        // NOTE(review): this throw happens before the try/finally below, so watches.delete() is not
        // reached on this path. Confirm whether that leaks the watch matrices.
        if (((DMatrix) watches.toMap().apply("train")).rowNum() == 0) {
            throw new XGBoostError(new StringBuilder(63).append("detected an empty partition in the training data, partition ID:").append(new StringBuilder(1).append(" ").append(TaskContext$.MODULE$.getPartitionId()).toString()).toString());
        }
        String obj = BoxesRunTime.boxToInteger(TaskContext$.MODULE$.getPartitionId()).toString();
        String obj2 = BoxesRunTime.boxToInteger(TaskContext$.MODULE$.get().attemptNumber()).toString();
        // Per-task Rabit environment; note this writes into the caller-supplied map.
        map.put("DMLC_TASK_ID", obj);
        map.put("DMLC_NUM_ATTEMPT", obj2);
        map.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false");
        int numRounds = xGBoostExecutionParams.numRounds();
        // Only partition 0 writes checkpoints, and only when checkpointing is configured.
        boolean z = xGBoostExecutionParams.checkpointParam().isDefined() && new StringOps(Predef$.MODULE$.augmentString(obj)).toInt() == 0;
        try {
            try {
                Rabit.init(map);
                int numEarlyStoppingRounds = xGBoostExecutionParams.earlyStoppingParams().numEarlyStoppingRounds();
                // One float[numRounds] buffer per watch. These are passed to training and then
                // returned zipped with the watch names.
                float[][] fArr = (float[][]) Array$.MODULE$.tabulate(watches.size(), obj3 -> {
                    return $anonfun$buildDistributedBooster$1(numRounds, BoxesRunTime.unboxToInt(obj3));
                }, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Float.TYPE)));
                // Training runs eagerly here, while the single result element is being built.
                return scala.package$.MODULE$.Iterator().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(z ? ml.dmlc.xgboost4j.scala.XGBoost$.MODULE$.trainAndSaveCheckpoint((DMatrix) watches.toMap().apply("train"), xGBoostExecutionParams.toMap(), numRounds, watches.toMap(), fArr, objectiveTrait, evalTrait, numEarlyStoppingRounds, booster, xGBoostExecutionParams.checkpointParam()) : ml.dmlc.xgboost4j.scala.XGBoost$.MODULE$.train((DMatrix) watches.toMap().apply("train"), xGBoostExecutionParams.toMap(), numRounds, watches.toMap(), fArr, objectiveTrait, evalTrait, numEarlyStoppingRounds, booster)), ((TraversableOnce) watches.toMap().keys().zip(Predef$.MODULE$.wrapRefArray(fArr), Iterable$.MODULE$.canBuildFrom())).toMap(Predef$.MODULE$.$conforms()))}));
            } catch (XGBoostError e) {
                logger().error(new StringBuilder(43).append("XGBooster worker ").append(obj).append(" has failed ").append(obj2).append(" times due to ").toString(), e);
                throw e;
            }
        } finally {
            // Runs on success and on failure once Rabit.init has been attempted.
            Rabit.shutdown();
            watches.delete();
        }
    }

    /**
     * Creates and starts the Rabit tracker that coordinates {@code i} workers.
     *
     * <p>trackerImpl "scala" selects the Scala RabitTracker with its default constructor
     * arguments. "python", or any other value, selects the Java RabitTracker.
     *
     * @param i the number of workers
     * @param trackerConf the tracker implementation name and worker connection timeout
     * @return the started tracker
     * @throws IllegalArgumentException via {@code Predef.require} if {@code start(...)} returns
     *     false
     */
    private IRabitTracker startTracker(int i, TrackerConf trackerConf) {
        String trackerImpl = trackerConf.trackerImpl();
        RabitTracker rabitTracker = "scala".equals(trackerImpl) ? new RabitTracker(i, RabitTracker$.MODULE$.$lessinit$greater$default$2(), RabitTracker$.MODULE$.$lessinit$greater$default$3()) : "python".equals(trackerImpl) ? new ml.dmlc.xgboost4j.java.RabitTracker(i) : new ml.dmlc.xgboost4j.java.RabitTracker(i);
        Predef$.MODULE$.require(rabitTracker.start(trackerConf.workerConnectionTimeout()), () -> {
            return "FAULT: Failed to start tracker";
        });
        return rabitTracker;
    }

    /**
     * Zips the training RDD and all evaluation RDDs partition-by-partition into one RDD.
     *
     * <p>Each resulting partition yields (set name, point iterator) pairs, with "train" plus
     * every key of {@code map}. Any of these RDDs whose partition count differs from {@code i}
     * is repartitioned to {@code i} first.
     *
     * <p>The fold works as follows:
     * <ul>
     *   <li>It is seeded with an RDD of {@code i} null placeholders in {@code i} partitions.</li>
     *   <li>The first zip replaces the null placeholder with a single pair; later zips append
     *       to the accumulated array.</li>
     *   <li>The result is wrapped in {@code XGBoost.IteratorWrapper}, which presumably
     *       iterates the pairs.</li>
     * </ul>
     *
     * @param rdd the training points
     * @param map the evaluation sets by name
     * @param i the number of partitions (workers)
     * @return the co-partitioned RDD of named iterators
     * @throws Exception at task runtime, with "too few elements in evaluation sets", if any set
     *     has an empty partition (this applies to "train" as well)
     */
    private RDD<Tuple2<String, Iterator<LabeledPoint>>> coPartitionNoGroupSets(RDD<LabeledPoint> rdd, Map<String, RDD<LabeledPoint>> map, int i) {
        return (RDD) ((Map) Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("train"), rdd)})).$plus$plus(map).map(tuple2 -> {
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            String str = (String) tuple2._1();
            RDD rdd2 = (RDD) tuple2._2();
            return rdd2.getNumPartitions() != i ? new Tuple2(str, rdd2.repartition(i, rdd2.repartition$default$2(i))) : new Tuple2(str, rdd2);
        }, Map$.MODULE$.canBuildFrom())).foldLeft(rdd.sparkContext().parallelize(Predef$.MODULE$.wrapRefArray((Object[]) Array$.MODULE$.fill(i, () -> {
            return null;
        }, ClassTag$.MODULE$.apply(Tuple2.class))), i, ClassTag$.MODULE$.apply(Tuple2.class)), (rdd2, tuple22) -> {
            // Decompiler pattern-match residue: these locals redeclare the lambda parameters
            // (not valid Java as written). Semantics: acc = rdd2, (str, setRdd) = tuple22.
            Tuple2 tuple22 = new Tuple2(rdd2, tuple22);
            if (tuple22 != null) {
                RDD rdd2 = (RDD) tuple22._1();
                Tuple2 tuple23 = (Tuple2) tuple22._2();
                if (tuple23 != null) {
                    String str = (String) tuple23._1();
                    return rdd2.zipPartitions((RDD) tuple23._2(), (iterator, iterator2) -> {
                        if (iterator2.hasNext()) {
                            Tuple2[] tuple2Arr = (Tuple2[]) iterator.toArray(ClassTag$.MODULE$.apply(Tuple2.class));
                            // A null head means only the seed placeholder has been seen so far.
                            return new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(tuple2Arr)).head() != null ? new XGBoost.IteratorWrapper((Tuple2[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(tuple2Arr)).$colon$plus(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(str), iterator2), ClassTag$.MODULE$.apply(Tuple2.class))) : new XGBoost.IteratorWrapper(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(str), iterator2)});
                        }
                        MODULE$.logger().error("when specifying eval sets as dataframes, you have to ensure that the number of elements in each dataframe is larger than the number of workers");
                        throw new Exception("too few elements in evaluation sets");
                    }, ClassTag$.MODULE$.apply(LabeledPoint.class), ClassTag$.MODULE$.apply(Tuple2.class));
                }
            }
            throw new MatchError(tuple22);
        });
    }

    /**
     * Builds the cached per-partition (booster, metrics) RDD for non-ranking training.
     *
     * <p>Without eval sets, each training partition is processed for missing values and built
     * into watches directly. With eval sets, the training and evaluation RDDs are first
     * co-partitioned into {@code numWorkers} partitions, and every named iterator is processed.
     *
     * @param rdd the training points
     * @param xGBoostExecutionParams the resolved training parameters
     * @param map the Rabit worker environment
     * @param booster the booster to continue from, possibly null
     * @param map2 the evaluation sets by name, possibly empty
     * @return the cached result RDD (lazily computed)
     */
    private RDD<Tuple2<Booster, Map<String, float[]>>> trainForNonRanking(RDD<LabeledPoint> rdd, XGBoostExecutionParams xGBoostExecutionParams, java.util.Map<String, String> map, Booster booster, Map<String, RDD<LabeledPoint>> map2) {
        if (map2.isEmpty()) {
            return rdd.mapPartitions(iterator -> {
                return MODULE$.buildDistributedBooster(Watches$.MODULE$.buildWatches(xGBoostExecutionParams, MODULE$.processMissingValues(iterator, xGBoostExecutionParams.missing(), xGBoostExecutionParams.allowNonZeroForMissing()), MODULE$.getCacheDirName(xGBoostExecutionParams.useExternalMemory())), xGBoostExecutionParams, map, xGBoostExecutionParams.obj(), xGBoostExecutionParams.eval(), booster);
            }, rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class)).cache();
        }
        RDD<Tuple2<String, Iterator<LabeledPoint>>> coPartitionNoGroupSets = coPartitionNoGroupSets(rdd, map2, xGBoostExecutionParams.numWorkers());
        return coPartitionNoGroupSets.mapPartitions(iterator2 -> {
            return MODULE$.buildDistributedBooster(Watches$.MODULE$.buildWatches(iterator2.map(tuple2 -> {
                if (tuple2 == null) {
                    throw new MatchError(tuple2);
                }
                return new Tuple2((String) tuple2._1(), MODULE$.processMissingValues((Iterator) tuple2._2(), xGBoostExecutionParams.missing(), xGBoostExecutionParams.allowNonZeroForMissing()));
            }), MODULE$.getCacheDirName(xGBoostExecutionParams.useExternalMemory())), xGBoostExecutionParams, map, xGBoostExecutionParams.obj(), xGBoostExecutionParams.eval(), booster);
        }, coPartitionNoGroupSets.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class)).cache();
    }

    /**
     * Ranking counterpart of trainForNonRanking.
     *
     * <p>Works on group arrays and builds watches with group information. Eval sets are
     * regrouped and co-partitioned by coPartitionGroupSets.
     *
     * @param rdd the training groups (already repartitioned by composeInputData)
     * @param xGBoostExecutionParams the resolved training parameters
     * @param map the Rabit worker environment
     * @param booster the booster to continue from, possibly null
     * @param map2 the evaluation sets by name, possibly empty
     * @return the cached result RDD (lazily computed)
     */
    private RDD<Tuple2<Booster, Map<String, float[]>>> trainForRanking(RDD<LabeledPoint[]> rdd, XGBoostExecutionParams xGBoostExecutionParams, java.util.Map<String, String> map, Booster booster, Map<String, RDD<LabeledPoint>> map2) {
        if (map2.isEmpty()) {
            return rdd.mapPartitions(iterator -> {
                return MODULE$.buildDistributedBooster(Watches$.MODULE$.buildWatchesWithGroup(xGBoostExecutionParams, MODULE$.processMissingValuesWithGroup(iterator, xGBoostExecutionParams.missing(), xGBoostExecutionParams.allowNonZeroForMissing()), MODULE$.getCacheDirName(xGBoostExecutionParams.useExternalMemory())), xGBoostExecutionParams, map, xGBoostExecutionParams.obj(), xGBoostExecutionParams.eval(), booster);
            }, rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class)).cache();
        }
        RDD<Tuple2<String, Iterator<LabeledPoint[]>>> coPartitionGroupSets = coPartitionGroupSets(rdd, map2, xGBoostExecutionParams.numWorkers());
        return coPartitionGroupSets.mapPartitions(iterator2 -> {
            return MODULE$.buildDistributedBooster(Watches$.MODULE$.buildWatchesWithGroup(iterator2.map(tuple2 -> {
                if (tuple2 == null) {
                    throw new MatchError(tuple2);
                }
                return new Tuple2((String) tuple2._1(), MODULE$.processMissingValuesWithGroup((Iterator) tuple2._2(), xGBoostExecutionParams.missing(), xGBoostExecutionParams.allowNonZeroForMissing()));
            }), MODULE$.getCacheDirName(xGBoostExecutionParams.useExternalMemory())), xGBoostExecutionParams, map, xGBoostExecutionParams.obj(), xGBoostExecutionParams.eval(), booster);
        }, coPartitionGroupSets.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class)).cache();
    }

    /** Persists {@code rdd} with MEMORY_AND_DISK when {@code z} is true; otherwise returns it as is. */
    private RDD<?> cacheData(boolean z, RDD<?> rdd) {
        return z ? rdd.persist(StorageLevel$.MODULE$.MEMORY_AND_DISK()) : rdd;
    }

    /**
     * Prepares the training input.
     *
     * <p>For non-ranking ({@code z2} false), the result is {@code Right(rdd)}. For ranking, it
     * is {@code Left} of the rdd regrouped and repartitioned into {@code i} partitions. Either
     * way the data is optionally persisted.
     *
     * @param rdd the training points
     * @param z whether to cache the training set
     * @param z2 whether the data has ranking groups
     * @param i the number of workers
     * @return the prepared input
     */
    private Either<RDD<LabeledPoint[]>, RDD<LabeledPoint>> composeInputData(RDD<LabeledPoint> rdd, boolean z, boolean z2, int i) {
        if (!z2) {
            return scala.package$.MODULE$.Right().apply(cacheData(z, rdd));
        }
        return scala.package$.MODULE$.Left().apply(cacheData(z, repartitionForTrainingGroup(rdd, i)));
    }

    /**
     * Driver entry point: trains a booster across {@code numWorkers} Spark tasks coordinated by
     * a Rabit tracker.
     *
     * <p>Steps:
     * <ol>
     *   <li>Resolve the execution parameters.</li>
     *   <li>Prepare (and optionally cache) the input.</li>
     *   <li>If checkpointing is configured, clean up checkpoints above numRounds and load the
     *       latest one as the starting booster.</li>
     *   <li>Start the tracker and launch the Spark job on a background thread.</li>
     *   <li>Block on the tracker through SparkParallelismTracker.</li>
     *   <li>Collect the result.</li>
     *   <li>Remove the checkpoint path unless skipCleanCheckpoint is set.</li>
     * </ol>
     *
     * <p>On any Throwable the error is logged, the whole SparkContext is stopped, and the error
     * is rethrown. The tracker is always stopped and the cached input always uncached.
     *
     * @param rdd the training points
     * @param map the raw XGBoost parameters
     * @param z whether this is ranking data (default false)
     * @param map2 the evaluation sets by name (default empty)
     * @return (trained booster, metrics keyed by watch name)
     * @throws XGBoostError if training fails
     */
    public Tuple2<Booster, Map<String, float[]>> trainDistributed(RDD<LabeledPoint> rdd, Map<String, Object> map, boolean z, Map<String, RDD<LabeledPoint>> map2) throws XGBoostError {
        logger().info(new StringBuilder(34).append("Running XGBoost ").append(package$.MODULE$.VERSION()).append(" with parameters:\n").append(map.mkString("\n")).toString());
        XGBoostExecutionParams buildXGBRuntimeParams = new XGBoostExecutionParamsFactory(map, rdd.sparkContext()).buildXGBRuntimeParams();
        SparkContext sparkContext = rdd.sparkContext();
        Either<RDD<LabeledPoint[]>, RDD<LabeledPoint>> composeInputData = composeInputData(rdd, buildXGBRuntimeParams.cacheTrainingSet(), z, buildXGBRuntimeParams.numWorkers());
        // Starting booster: the loaded external checkpoint when configured, otherwise null.
        Booster booster = (Booster) buildXGBRuntimeParams.checkpointParam().map(externalCheckpointParams -> {
            ExternalCheckpointManager externalCheckpointManager = new ExternalCheckpointManager(externalCheckpointParams.checkpointPath(), FileSystem.get(sparkContext.hadoopConfiguration()));
            externalCheckpointManager.cleanUpHigherVersions(buildXGBRuntimeParams.numRounds());
            return externalCheckpointManager.loadCheckpointAsScalaBooster();
        }).orNull(Predef$.MODULE$.$conforms());
        try {
            try {
                IRabitTracker startTracker = startTracker(buildXGBRuntimeParams.numWorkers(), buildXGBRuntimeParams.trackerConf());
                try {
                    SparkParallelismTracker sparkParallelismTracker = new SparkParallelismTracker(sparkContext, buildXGBRuntimeParams.timeoutRequestWorkers(), buildXGBRuntimeParams.numWorkers());
                    java.util.Map<String, String> workerEnvs = startTracker.getWorkerEnvs();
                    final RDD<Tuple2<Booster, Map<String, float[]>>> trainForRanking = z ? trainForRanking((RDD) composeInputData.left().get(), buildXGBRuntimeParams, workerEnvs, booster, map2) : trainForNonRanking((RDD) composeInputData.right().get(), buildXGBRuntimeParams, workerEnvs, booster, map2);
                    // Background thread that runs the Spark job which performs the training.
                    Thread thread = new Thread(trainForRanking) { // from class: ml.dmlc.xgboost4j.scala.spark.XGBoost$$anon$1
                        private final RDD boostersAndMetrics$1;

                        @Override // java.lang.Thread, java.lang.Runnable
                        public void run() {
                            // Forces every partition to be computed. The mapPartitions function
                            // trains eagerly, so the iterator itself is not consumed. The per-partition
                            // function only builds and discards a closure; it is decompiled from Scala
                            // `() => _`, and the bare lambda statement is not valid Java as written.
                            this.boostersAndMetrics$1.foreachPartition(iterator -> {
                                () -> {
                                    return iterator;
                                };
                                return BoxedUnit.UNIT;
                            });
                        }

                        {
                            this.boostersAndMetrics$1 = trainForRanking;
                        }
                    };
                    // Uncaught failures in the job thread are reported to the tracker.
                    thread.setUncaughtExceptionHandler(startTracker);
                    thread.start();
                    // NOTE(review): semantics of waitFor(0L) live in the tracker implementation; confirm
                    // whether 0 means "no timeout".
                    int unboxToInt = BoxesRunTime.unboxToInt(sparkParallelismTracker.execute(() -> {
                        return startTracker.waitFor(0L);
                    }));
                    logger().info(new StringBuilder(29).append("Rabit returns with exit code ").append(unboxToInt).toString());
                    Tuple2<Booster, Map<String, float[]>> postTrackerReturnProcessing = postTrackerReturnProcessing(unboxToInt, trainForRanking, thread);
                    // The Tuple2 re-wrapping below is decompiled pattern-match residue; it just
                    // destructures (booster2, map3).
                    if (postTrackerReturnProcessing == null) {
                        throw new MatchError(postTrackerReturnProcessing);
                    }
                    Tuple2 tuple2 = new Tuple2((Booster) postTrackerReturnProcessing._1(), (Map) postTrackerReturnProcessing._2());
                    Tuple2 tuple22 = new Tuple2((Booster) tuple2._1(), (Map) tuple2._2());
                    if (tuple22 == null) {
                        throw new MatchError(tuple22);
                    }
                    Tuple2 tuple23 = new Tuple2((Booster) tuple22._1(), (Map) tuple22._2());
                    Booster booster2 = (Booster) tuple23._1();
                    Map map3 = (Map) tuple23._2();
                    buildXGBRuntimeParams.checkpointParam().foreach(externalCheckpointParams2 -> {
                        $anonfun$trainDistributed$3(buildXGBRuntimeParams, sparkContext, externalCheckpointParams2);
                        return BoxedUnit.UNIT;
                    });
                    return new Tuple2<>(booster2, map3);
                } finally {
                    startTracker.stop();
                }
            } catch (Throwable th) {
                // Any failure (including tracker startup) stops the entire SparkContext before rethrowing.
                logger().error("the job was aborted due to ", th);
                rdd.sparkContext().stop();
                throw th;
            }
        } finally {
            uncacheTrainingData(buildXGBRuntimeParams.cacheTrainingSet(), composeInputData);
        }
    }

    /** Scala default-argument accessor for the ranking flag of trainDistributed: {@code false}. */
    public boolean trainDistributed$default$3() {
        return false;
    }

    /** Scala default-argument accessor for the eval-sets map of trainDistributed: an empty Map. */
    public Map<String, RDD<LabeledPoint>> trainDistributed$default$4() {
        return Predef$.MODULE$.Map().apply(Nil$.MODULE$);
    }

    /**
     * Unpersists whichever side of {@code either} holds the training RDD, but only when caching
     * was enabled ({@code z}). Uses the default blocking flag of {@code unpersist}.
     */
    private void uncacheTrainingData(boolean z, Either<RDD<LabeledPoint[]>, RDD<LabeledPoint>> either) {
        if (z) {
            if (either.isLeft()) {
                RDD rdd = (RDD) either.left().get();
                rdd.unpersist(rdd.unpersist$default$1());
            } else {
                RDD rdd2 = (RDD) either.right().get();
                rdd2.unpersist(rdd2.unpersist$default$1());
            }
        }
    }

    /**
     * Converts a point RDD into an RDD of ranking groups.
     *
     * <p>Each partition is scanned with LabeledPointGroupIterator:
     * <ul>
     *   <li>Non-edge groups are emitted directly as point arrays.</li>
     *   <li>Edge groups are tagged with their partition id and grouped by groupId across
     *       partitions.</li>
     *   <li>The tagged pieces are ordered by partition id and concatenated into one array per
     *       group.</li>
     * </ul>
     * The two parts are unioned.
     *
     * <p>NOTE(review): {@code rdd.mapPartitions(new LabeledPointGroupIterator...)} appears in two
     * separate lineages, so the input is scanned twice unless it is cached upstream.
     *
     * @param rdd the points, presumably ordered so that a group's points are contiguous within a
     *     partition (TODO confirm against LabeledPointGroupIterator)
     * @return one LabeledPoint[] per group
     */
    private RDD<LabeledPoint[]> aggByGroupInfo(RDD<LabeledPoint> rdd) {
        return rdd.mapPartitions(iterator -> {
            return new LabeledPointGroupIterator(iterator);
        }, rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(XGBLabeledPointGroup.class)).filter(xGBLabeledPointGroup -> {
            return BoxesRunTime.boxToBoolean($anonfun$aggByGroupInfo$2(xGBLabeledPointGroup));
        }).map(xGBLabeledPointGroup2 -> {
            return xGBLabeledPointGroup2.points();
        }, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(LabeledPoint.class))).union(rdd.mapPartitions(iterator2 -> {
            return new LabeledPointGroupIterator(iterator2);
        }, rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(XGBLabeledPointGroup.class)).filter(xGBLabeledPointGroup3 -> {
            return BoxesRunTime.boxToBoolean(xGBLabeledPointGroup3.isEdgeGroup());
        }).map(xGBLabeledPointGroup4 -> {
            return new Tuple2(BoxesRunTime.boxToInteger(TaskContext$.MODULE$.getPartitionId()), xGBLabeledPointGroup4);
        }, ClassTag$.MODULE$.apply(Tuple2.class)).groupBy(tuple2 -> {
            return BoxesRunTime.boxToInteger($anonfun$aggByGroupInfo$7(tuple2));
        }, ClassTag$.MODULE$.Int()).map(tuple22 -> {
            // Sort the (partitionId, group) pieces by partition id, then concatenate their points.
            // (The inner sortBy lambda reuses the name tuple22; decompiler residue, not valid Java.)
            return (LabeledPoint[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) ((Iterable) tuple22._2()).toArray(ClassTag$.MODULE$.apply(Tuple2.class)))).sortBy(tuple22 -> {
                return BoxesRunTime.boxToInteger(tuple22._1$mcI$sp());
            }, Ordering$Int$.MODULE$))).flatMap(tuple23 -> {
                return new ArrayOps.ofRef($anonfun$aggByGroupInfo$10(tuple23));
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(LabeledPoint.class)));
        }, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(LabeledPoint.class))));
    }

    /**
     * Groups {@code rdd} into ranking groups (see aggByGroupInfo) and repartitions the groups
     * into {@code i} partitions, logging the target partition count.
     *
     * @param rdd the training points
     * @param i the target number of partitions
     * @return the grouped, repartitioned RDD
     */
    public RDD<LabeledPoint[]> repartitionForTrainingGroup(RDD<LabeledPoint> rdd, int i) {
        RDD<LabeledPoint[]> aggByGroupInfo = aggByGroupInfo(rdd);
        logger().info(new StringBuilder(48).append("repartitioning training group set to ").append(i).append(" partitions").toString());
        return aggByGroupInfo.repartition(i, aggByGroupInfo.repartition$default$2(i));
    }

    /**
     * Group-wise counterpart of coPartitionNoGroupSets.
     *
     * <p>Each evaluation RDD is converted to groups with aggByGroupInfo and repartitioned to
     * {@code i} partitions if needed. Unlike the non-group variant, the training RDD
     * ("train") is not repartitioned here. Callers pass the output of
     * repartitionForTrainingGroup, which already has numWorkers partitions.
     *
     * @param rdd the training groups
     * @param map the evaluation point sets by name
     * @param i the number of partitions (workers)
     * @return an RDD whose partitions yield (set name, group iterator) pairs
     * @throws Exception at task runtime, with "too few elements in evaluation sets", if any set
     *     has an empty partition
     */
    private RDD<Tuple2<String, Iterator<LabeledPoint[]>>> coPartitionGroupSets(RDD<LabeledPoint[]> rdd, Map<String, RDD<LabeledPoint>> map, int i) {
        return (RDD) Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("train"), rdd)})).$plus$plus((GenTraversableOnce) map.map(tuple2 -> {
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            String str = (String) tuple2._1();
            RDD<LabeledPoint[]> aggByGroupInfo = MODULE$.aggByGroupInfo((RDD) tuple2._2());
            return aggByGroupInfo.getNumPartitions() != i ? Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(str), aggByGroupInfo.repartition(i, aggByGroupInfo.repartition$default$2(i))) : Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(str), aggByGroupInfo);
        }, Map$.MODULE$.canBuildFrom())).foldLeft(rdd.sparkContext().parallelize(Predef$.MODULE$.wrapRefArray((Object[]) Array$.MODULE$.fill(i, () -> {
            return null;
        }, ClassTag$.MODULE$.apply(Tuple2.class))), i, ClassTag$.MODULE$.apply(Tuple2.class)), (rdd2, tuple22) -> {
            // Decompiler pattern-match residue: these locals redeclare the lambda parameters
            // (not valid Java as written). Semantics: acc = rdd2, (str, setRdd) = tuple22.
            Tuple2 tuple22 = new Tuple2(rdd2, tuple22);
            if (tuple22 != null) {
                RDD rdd2 = (RDD) tuple22._1();
                Tuple2 tuple23 = (Tuple2) tuple22._2();
                if (tuple23 != null) {
                    String str = (String) tuple23._1();
                    return rdd2.zipPartitions((RDD) tuple23._2(), (iterator, iterator2) -> {
                        if (iterator2.hasNext()) {
                            Tuple2[] tuple2Arr = (Tuple2[]) iterator.toArray(ClassTag$.MODULE$.apply(Tuple2.class));
                            // A null head means only the seed placeholder has been seen so far.
                            return new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(tuple2Arr)).head() != null ? new XGBoost.IteratorWrapper((Tuple2[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(tuple2Arr)).$colon$plus(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(str), iterator2), ClassTag$.MODULE$.apply(Tuple2.class))) : new XGBoost.IteratorWrapper(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(str), iterator2)});
                        }
                        MODULE$.logger().error("when specifying eval sets as dataframes, you have to ensure that the number of elements in each dataframe is larger than the number of workers");
                        throw new Exception("too few elements in evaluation sets");
                    }, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(LabeledPoint.class)), ClassTag$.MODULE$.apply(Tuple2.class));
                }
            }
            throw new MatchError(tuple22);
        });
    }

    /**
     * Turns the tracker's exit code into the training result.
     *
     * <p>On a non-zero code, the Spark job thread is interrupted if it is still alive, and an
     * XGBoostError is thrown. On zero, the method waits for the job thread and takes the first
     * (booster, metrics) element of {@code rdd}. It then unpersists {@code rdd} non-blockingly.
     *
     * <p>NOTE(review): the {@code catch (InterruptedException)} below is unreachable, because
     * {@code Thread.interrupt()} never throws it. Meanwhile the checked InterruptedException from
     * {@code thread.join()} is undeclared. Both are decompiled from Scala, which has no checked
     * exceptions.
     *
     * @param i the tracker exit code
     * @param rdd the cached per-partition results
     * @param thread the thread running the Spark job
     * @return (booster, metrics)
     * @throws XGBoostError with "XGBoostModel training failed" when {@code i != 0}
     */
    private Tuple2<Booster, Map<String, float[]>> postTrackerReturnProcessing(int i, RDD<Tuple2<Booster, Map<String, float[]>>> rdd, Thread thread) {
        if (i != 0) {
            try {
                if (thread.isAlive()) {
                    thread.interrupt();
                }
            } catch (InterruptedException unused) {
                logger().info("spark job thread is interrupted");
            }
            throw new XGBoostError("XGBoostModel training failed");
        }
        thread.join();
        Tuple2 tuple2 = (Tuple2) rdd.first();
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        Tuple2 tuple22 = new Tuple2((Booster) tuple2._1(), (Map) tuple2._2());
        Booster booster = (Booster) tuple22._1();
        Map map = (Map) tuple22._2();
        rdd.unpersist(false);
        return new Tuple2<>(booster, map);
    }

    /** Keeps the singleton unique across Java deserialization. */
    private Object readResolve() {
        return MODULE$;
    }

    /** Synthetic helper for removeMissingValues: keeps only non-null (value, index) pairs. */
    public static final /* synthetic */ boolean $anonfun$removeMissingValues$2(Tuple2 tuple2) {
        return tuple2 != null;
    }

    /** Synthetic helper for removeMissingValues: applies the keep-predicate to the pair's float value. */
    public static final /* synthetic */ boolean $anonfun$removeMissingValues$3(Function1 function1, Tuple2 tuple2) {
        if (tuple2 != null) {
            return function1.apply$mcZF$sp(BoxesRunTime.unboxToFloat(tuple2._1()));
        }
        throw new MatchError(tuple2);
    }

    /** Synthetic helper for buildDistributedBooster: allocates a float[i]; the watch index i2 is ignored. */
    public static final /* synthetic */ float[] $anonfun$buildDistributedBooster$1(int i, int i2) {
        return (float[]) Array$.MODULE$.ofDim(i, ClassTag$.MODULE$.Float());
    }

    /**
     * Synthetic helper for trainDistributed: after successful training, deletes the checkpoint
     * path. Skipped when the configured checkpoint params set skipCleanCheckpoint.
     */
    public static final /* synthetic */ void $anonfun$trainDistributed$3(XGBoostExecutionParams xGBoostExecutionParams, SparkContext sparkContext, ExternalCheckpointParams externalCheckpointParams) {
        if (((ExternalCheckpointParams) xGBoostExecutionParams.checkpointParam().get()).skipCleanCheckpoint()) {
            return;
        }
        new ExternalCheckpointManager(externalCheckpointParams.checkpointPath(), FileSystem.get(sparkContext.hadoopConfiguration())).cleanPath();
    }

    /** Synthetic helper for aggByGroupInfo: selects groups that are not edge groups. */
    public static final /* synthetic */ boolean $anonfun$aggByGroupInfo$2(XGBLabeledPointGroup xGBLabeledPointGroup) {
        return !xGBLabeledPointGroup.isEdgeGroup();
    }

    /** Synthetic helper for aggByGroupInfo: groupBy key, the groupId of a (partitionId, group) pair. */
    public static final /* synthetic */ int $anonfun$aggByGroupInfo$7(Tuple2 tuple2) {
        return ((XGBLabeledPointGroup) tuple2._2()).groupId();
    }

    /** Synthetic helper for aggByGroupInfo: the points of a (partitionId, group) pair as ArrayOps. */
    public static final /* synthetic */ Object[] $anonfun$aggByGroupInfo$10(Tuple2 tuple2) {
        return Predef$.MODULE$.refArrayOps(((XGBLabeledPointGroup) tuple2._2()).points());
    }

    /** Registers this instance as MODULE$ and creates the "XGBoostSpark" logger. */
    private XGBoost$() {
        MODULE$ = this;
        this.logger = LogFactory.getLog("XGBoostSpark");
    }
}
