package org.apache.mahout.sparkbindings.blas;

import org.apache.log4j.Logger;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.drm.logical.OpAtB;
import org.apache.mahout.sparkbindings.drm.DrmRddInput;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.RDD$;
import scala.Predef$;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.IndexedSeq;
import scala.collection.immutable.Range;
import scala.math.Ordering;
import scala.math.Ordering$Int$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichDouble$;
import scala.runtime.RichInt$;

/* compiled from: AtB.scala */
/* loaded from: input_file:org/apache/mahout/sparkbindings/blas/AtB$.class */
/**
 * Java form of the Scala object {@code AtB} (compiled from AtB.scala): distributed computation of
 * the product A'B for the Mahout Spark bindings.
 *
 * <p>Two physical strategies are provided:
 * <ul>
 *   <li>{@link #atb_nograph_mmul}: pairs up blocks of A and B, either by zip or by cogroup. This is
 *       the strategy {@link #atb} uses.</li>
 *   <li>{@link #atb_nograph}: pairs up rows of A and B, either by zip or by inner join, then combines
 *       outer products.</li>
 * </ul>
 */
public final class AtB$ {
    /** Singleton instance (Scala {@code object} encoding). */
    public static final AtB$ MODULE$ = new AtB$();

    private final Logger log;

    private AtB$() {
        this.log = org.apache.mahout.logging.package$.MODULE$.getLog(getClass());
    }

    private final Logger log() {
        return this.log;
    }

    /**
     * Computes A'B using the block-wise strategy ({@link #atb_nograph_mmul}).
     *
     * <p>The operands are treated as zippable when A's and B's partitioning tags are equal.
     *
     * @param opAtB        logical A'B operator (supplies geometry and operand metadata)
     * @param drmRddInput  physical input for A
     * @param drmRddInput2 physical input for B
     * @param classTag     class tag of the row key type {@code A}
     * @return the blockified product as a {@code DrmRddInput}
     */
    public <A> DrmRddInput<Object> atb(OpAtB<A> opAtB, DrmRddInput<A> drmRddInput, DrmRddInput<A> drmRddInput2, ClassTag<A> classTag) {
        // Equal partitioning tags mean the operands are identically partitioned, so zip can be used.
        boolean zippable = opAtB.A().partitioningTag() == opAtB.B().partitioningTag();
        return atb_nograph_mmul(opAtB, drmRddInput, drmRddInput2, zippable, classTag);
    }

    /**
     * Computes A'B by pairing rows of A and B, then summing per-row contributions via
     * {@link #computeAtBZipped2}.
     *
     * @param opAtB        logical A'B operator
     * @param drmRddInput  physical input for A
     * @param drmRddInput2 physical input for B
     * @param z            {@code true} if A and B are identically partitioned, so their rows can be
     *                     zipped; otherwise rows are matched by an inner join on the row key
     * @param classTag     class tag of the row key type {@code A}
     * @return the blockified product as a {@code DrmRddInput}
     */
    public <A> DrmRddInput<Object> atb_nograph(OpAtB<A> opAtB, DrmRddInput<A> drmRddInput, DrmRddInput<A> drmRddInput2, boolean z, ClassTag<A> classTag) {
        RDD<Tuple2<A, Vector>> rddA = drmRddInput.asRowWise();
        RDD<Tuple2<A, Vector>> rddB = drmRddInput2.asRowWise();
        int prodNCol = opAtB.ncol();
        long prodNRow = opAtB.nrow();
        long aNRow = opAtB.A().nrow();

        int numProductPartitions = estimateProductPartitions(prodNRow, prodNCol, aNRow,
                rddA.partitions().length, rddB.partitions().length);
        if (log().isDebugEnabled()) {
            log().debug("AtB: #parts " + numProductPartitions + " for " + prodNRow + " x " + prodNCol + " geometry.");
        }

        RDD<Tuple2<Tuple2<A, Vector>, Tuple2<A, Vector>>> rowPairs;
        if (z) {
            log().debug("A and B for A'B are identically distributed, performing row-wise zip.");
            rowPairs = rddA.zip(rddB, ClassTag$.MODULE$.apply(Tuple2.class));
        } else {
            log().debug("A and B for A'B are not identically partitioned, performing inner join.");
            ClassTag vectorTag = ClassTag$.MODULE$.apply(Vector.class);
            // null is Spark's default for the implicit key Ordering of rddToPairRDDFunctions.
            rowPairs = RDD$.MODULE$.rddToPairRDDFunctions(rddA, classTag, vectorTag, (Ordering) null)
                    .join(rddB, numProductPartitions)
                    .map(new AtB$$anonfun$6(), ClassTag$.MODULE$.apply(Tuple2.class));
        }
        return org.apache.mahout.sparkbindings.drm.package$.MODULE$.blockifiedRdd2drmRddInput(
                computeAtBZipped2(rowPairs, opAtB.nrow(), opAtB.A().ncol(), opAtB.B().ncol(), numProductPartitions, classTag),
                ClassTag$.MODULE$.Int());
    }

    /** Default value of the {@code z} (zippable) argument of {@link #atb_nograph}. */
    public <A> boolean atb_nograph$default$4() {
        return false;
    }

    /**
     * Computes A'B by pairing blocks of A and B, then reducing the partial products via
     * {@link #computeAtBZipped3}.
     *
     * @param opAtB        logical A'B operator
     * @param drmRddInput  physical input for A
     * @param drmRddInput2 physical input for B
     * @param z            {@code true} if A and B are identically partitioned, so their blocks can be
     *                     zipped; otherwise rows are cogrouped by key first
     * @param classTag     class tag of the row key type {@code A}
     * @return the blockified product as a {@code DrmRddInput}
     */
    public <A> DrmRddInput<Object> atb_nograph_mmul(OpAtB<A> opAtB, DrmRddInput<A> drmRddInput, DrmRddInput<A> drmRddInput2, boolean z, ClassTag<A> classTag) {
        org.apache.mahout.logging.package$.MODULE$.debug(new AtB$$anonfun$atb_nograph_mmul$1(), log());
        int prodNCol = opAtB.ncol();
        int prodNRow = org.apache.mahout.math.drm.package$.MODULE$.safeToNonNegInt(opAtB.nrow());
        int aNRow = org.apache.mahout.math.drm.package$.MODULE$.safeToNonNegInt(opAtB.A().nrow());
        RDD<Tuple2<A, Vector>> rddA = drmRddInput.asRowWise();
        RDD<Tuple2<A, Vector>> rddB = drmRddInput2.asRowWise();

        // Never plan more product partitions than there are product rows.
        int numProductPartitions = Math.min(
                estimateProductPartitions(prodNRow, prodNCol, aNRow, rddA.partitions().length, rddB.partitions().length),
                prodNRow);
        if (log().isDebugEnabled()) {
            log().debug("AtB mmul: #parts " + numProductPartitions + " for " + prodNRow + " x " + prodNCol + " geometry.");
        }

        RDD<Tuple2<Matrix, Matrix>> blockPairs;
        if (z) {
            org.apache.mahout.logging.package$.MODULE$.debug(new AtB$$anonfun$7(), log());
            // The anonfun arguments presumably supply each operand's column count for blockification; TODO confirm.
            blockPairs = drmRddInput.asBlockified(new AtB$$anonfun$1(opAtB))
                    .zip(drmRddInput2.asBlockified(new AtB$$anonfun$2(opAtB)), ClassTag$.MODULE$.apply(Tuple2.class))
                    .map(new AtB$$anonfun$8(), ClassTag$.MODULE$.apply(Tuple2.class));
        } else {
            org.apache.mahout.logging.package$.MODULE$.debug(new AtB$$anonfun$9(), log());
            ClassTag vectorTag = ClassTag$.MODULE$.apply(Vector.class);
            // null is Spark's default for the implicit key Ordering of rddToPairRDDFunctions.
            RDD cogroup = RDD$.MODULE$.rddToPairRDDFunctions(rddA, classTag, vectorTag, (Ordering) null)
                    .cogroup(rddB, Math.max(rddA.partitions().length, rddB.partitions().length));
            blockPairs = cogroup.mapPartitions(new AtB$$anonfun$10(prodNCol, prodNRow),
                    cogroup.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class));
        }
        return org.apache.mahout.sparkbindings.drm.package$.MODULE$.blockifiedRdd2drmRddInput(
                computeAtBZipped3(blockPairs, prodNRow, prodNRow, aNRow, numProductPartitions, classTag),
                ClassTag$.MODULE$.Int());
    }

    /** Default value of the {@code z} (zippable) argument of {@link #atb_nograph_mmul}. */
    public <A> boolean atb_nograph_mmul$default$4() {
        return false;
    }

    /**
     * Groups an int-keyed RDD with {@code combineByKey} into {@code i} partitions, producing one
     * {@code Matrix} per key.
     *
     * <p>The create/merge/combine functions are the anonymous classes passed below. Presumably they
     * accumulate outer products of the vector pairs; TODO confirm.
     */
    public RDD<Tuple2<Object, Matrix>> combineOuterProducts(RDD<Tuple2<Object, Tuple2<Vector, Vector>>> rdd, int i) {
        return RDD$.MODULE$.rddToPairRDDFunctions(rdd, ClassTag$.MODULE$.Int(), ClassTag$.MODULE$.apply(Tuple2.class), Ordering$Int$.MODULE$)
                .combineByKey(new AtB$$anonfun$combineOuterProducts$1(), new AtB$$anonfun$combineOuterProducts$2(),
                        new AtB$$anonfun$combineOuterProducts$3(), i);
    }

    /**
     * Reduces (A-block, B-block) pairs into {@code i4} int-keyed product blocks.
     *
     * <p>Row ranges come from {@code computeEvenSplits(i, i4)}. The parameters {@code i2} and
     * {@code i3} are not read by this method.
     *
     * @param rdd  pairs of operand blocks
     * @param i    number of rows of the product
     * @param i4   target number of product partitions
     * @return (row keys, block) pairs of the product
     */
    public <A> RDD<Tuple2<int[], Matrix>> computeAtBZipped3(RDD<Tuple2<Matrix, Matrix>> rdd, int i, int i2, int i3, int i4, ClassTag<A> classTag) {
        IndexedSeq<Range> splits = package$.MODULE$.computeEvenSplits(i, i4);
        RDD<Tuple2<int[], Matrix>> result = RDD$.MODULE$
                .rddToPairRDDFunctions(rdd.flatMap(new AtB$$anonfun$11(i4, splits), ClassTag$.MODULE$.apply(Tuple2.class)),
                        ClassTag$.MODULE$.Int(), ClassTag$.MODULE$.apply(Matrix.class), Ordering$Int$.MODULE$)
                .reduceByKey(new AtB$$anonfun$12(), i4)
                .map(new AtB$$anonfun$13(splits), ClassTag$.MODULE$.apply(Tuple2.class));
        org.apache.mahout.logging.package$.MODULE$.debug(new AtB$$anonfun$computeAtBZipped3$1(result), log());
        return result;
    }

    /**
     * Folds (A-row, B-row) pairs into {@code i3} int-keyed product blocks via
     * {@link #combineOuterProducts}.
     *
     * @param rdd  pairs of matching rows of A and B
     * @param j    number of rows of the product
     * @param i    number of columns of A
     * @param i2   number of columns of B (not read here)
     * @param i3   target number of product partitions
     */
    public <A> RDD<Tuple2<int[], Matrix>> computeAtBZipped2(RDD<Tuple2<Tuple2<A, Vector>, Tuple2<A, Vector>>> rdd, long j, int i, int i2, int i3, ClassTag<A> classTag) {
        int blockHeight = rowsPerBlock(j, i3);
        RDD<Tuple2<Object, Tuple2<Vector, Vector>>> keyed =
                rdd.mapPartitions(new AtB$$anonfun$14(i, i3, blockHeight), rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class));
        return combineOuterProducts(keyed, i3).map(new AtB$$anonfun$16(blockHeight), ClassTag$.MODULE$.apply(Tuple2.class));
    }

    /**
     * Older {@code flatMap}/{@code reduceByKey} variant of {@link #computeAtBZipped2}.
     *
     * <p>It is not called from this class. Parameters have the same meaning as in
     * {@link #computeAtBZipped2}.
     */
    public <A> RDD<Tuple2<int[], Matrix>> computeAtBZipped(RDD<Tuple2<Tuple2<A, Vector>, Tuple2<A, Vector>>> rdd, long j, int i, int i2, int i3, ClassTag<A> classTag) {
        if (log().isDebugEnabled()) {
            log().debug("AtBZipped:zipped #parts " + rdd.partitions().length);
            log().debug("AtBZipped:Targeted #parts " + i3);
        }
        int blockHeight = rowsPerBlock(j, i3);
        RDD<Tuple2<int[], Matrix>> result = RDD$.MODULE$
                .rddToPairRDDFunctions(rdd.flatMap(new AtB$$anonfun$17(i, i3, blockHeight), ClassTag$.MODULE$.apply(Tuple2.class)),
                        ClassTag$.MODULE$.Int(), ClassTag$.MODULE$.apply(Matrix.class), Ordering$Int$.MODULE$)
                .reduceByKey(new AtB$$anonfun$18(), i3)
                .map(new AtB$$anonfun$19(blockHeight), ClassTag$.MODULE$.apply(Tuple2.class));
        if (log().isDebugEnabled()) {
            log().debug("AtBZipped #parts " + result.partitions().length);
        }
        return result;
    }

    /**
     * Estimates the number of partitions for the product of A'B.
     *
     * <p>The elements-per-partition figure is the larger of the two operands' values,
     * {@code aNRow * prodNRow / aParts} and {@code aNRow * prodNCol / bParts}. The product's element
     * count is divided by that figure and rounded up.
     *
     * <p>All arithmetic is done in {@code double}, so large geometries neither overflow int or long
     * nor get truncated by integer division. An estimate of +Infinity narrows to
     * {@code Integer.MAX_VALUE}. A NaN or zero estimate (degenerate geometry) is clamped to 1, so
     * callers never divide by a zero partition count.
     */
    private static int estimateProductPartitions(double prodNRow, double prodNCol, double aNRow, int aParts, int bParts) {
        double elementsPerPartition = Math.max(aNRow * prodNRow / aParts, aNRow * prodNCol / bParts);
        int estimate = (int) Math.ceil(prodNCol * prodNRow / elementsPerPartition);
        return Math.max(estimate, 1);
    }

    /** Number of product rows per block when {@code nrow} rows are split into {@code numPartitions} blocks. */
    private static int rowsPerBlock(long nrow, int numPartitions) {
        return org.apache.mahout.math.drm.package$.MODULE$.safeToNonNegInt((nrow - 1) / numPartitions) + 1;
    }
}
