/*
 * Decompiled with CFR 0.152.
 */
package org.ojalgo.tensor;

import java.util.Objects;

import org.ojalgo.function.FunctionSet;
import org.ojalgo.scalar.Scalar;
import org.ojalgo.structure.Access1D;
import org.ojalgo.structure.Access2D;
import org.ojalgo.structure.Factory2D;
import org.ojalgo.structure.Mutate2D;

/**
 * Builds two-dimensional structures (matrices) on top of an existing {@link Factory2D}.
 * <p>
 * Every instance is allocated by the wrapped factory. This class adds tensor-style constructions
 * on top of it: element-wise copy, identity, outer product, Kronecker product and block-diagonal
 * assembly.
 *
 * @param <N> the element type
 * @param <T> the mutable 2D structure type produced by the wrapped factory
 */
public final class TensorFactory2D<N extends Comparable<N>, T extends Mutate2D>
implements Factory2D<T> {
    private final Factory2D<T> myFactory;

    /**
     * Creates a tensor factory that delegates all instance creation to {@code factory}.
     *
     * @param factory the factory used to allocate result structures
     * @return a new tensor factory wrapping {@code factory}
     */
    public static <N extends Comparable<N>, T extends Mutate2D> TensorFactory2D<N, T> of(Factory2D<T> factory) {
        return new TensorFactory2D<N, T>(factory);
    }

    TensorFactory2D(Factory2D<T> factory) {
        this.myFactory = factory;
    }

    /**
     * Creates a new structure with the same shape as {@code elements} and copies every element
     * into it by linear index.
     * <p>
     * NOTE(review): this assumes the source and the newly made structure share the same linear
     * index ordering for a given shape.
     *
     * @param elements the source of shape and values
     * @return a new structure holding a copy of {@code elements}
     */
    public T copy(Access2D<N> elements) {
        T retVal = this.myFactory.make(elements.countRows(), elements.countColumns());
        long count = elements.count();
        for (long i = 0L; i < count; i++) {
            retVal.set(i, elements.get(i));
        }
        return retVal;
    }

    /**
     * Two tensor factories are equal when they wrap equal underlying factories.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        // The class is final, so an instanceof check keeps equals symmetric.
        // Do NOT delegate to Object.equals here: that is identity and would make
        // the factory comparison below unreachable.
        if (!(obj instanceof TensorFactory2D)) {
            return false;
        }
        TensorFactory2D<?, ?> other = (TensorFactory2D<?, ?>) obj;
        return Objects.equals(this.myFactory, other.myFactory);
    }

    /**
     * Returns the function set of the wrapped factory.
     */
    public FunctionSet<N> function() {
        return this.myFactory.function();
    }

    /**
     * Hash code consistent with {@link #equals(Object)}: derived only from the wrapped factory.
     */
    @Override
    public int hashCode() {
        return 31 + Objects.hashCode(this.myFactory);
    }

    /**
     * Creates a square structure with {@code scalar().cast(1.0)} on the main diagonal.
     * <p>
     * Off-diagonal entries are left as the wrapped factory's {@code make} initialises them
     * (presumably zero — confirm for the wrapped factory).
     *
     * @param dimensions the number of rows and columns
     * @return the identity-like structure
     */
    public T identity(int dimensions) {
        T retVal = this.myFactory.make(dimensions, dimensions);
        N one = this.scalar().cast(1.0);
        for (int ij = 0; ij < dimensions; ij++) {
            retVal.set((long) ij, (long) ij, one);
        }
        return retVal;
    }

    @Override
    public T make(long rows, long columns) {
        return this.myFactory.make(rows, columns);
    }

    /**
     * Outer product: {@code result[i, j] = vector1[i] * vector2[j]}.
     * <p>
     * Values are read and multiplied as {@code double}, so element types with more precision than
     * {@code double} are rounded.
     *
     * @param vector1 supplies the rows (its count is the row count)
     * @param vector2 supplies the columns (its count is the column count)
     * @return a {@code vector1.count() x vector2.count()} structure
     */
    public T product(Access1D<N> vector1, Access1D<N> vector2) {
        long rows = vector1.count();
        long cols = vector2.count();
        T retVal = this.myFactory.make(rows, cols);
        for (long j = 0L; j < cols; j++) {
            // Constant across the inner loop; read it once per column.
            double right = vector2.doubleValue(j);
            for (long i = 0L; i < rows; i++) {
                retVal.set(i, j, vector1.doubleValue(i) * right);
            }
        }
        return retVal;
    }

    /**
     * Outer product of {@code vector} with itself; same as {@code product(vector, vector)}.
     *
     * @param vector the vector
     * @return a square {@code vector.count() x vector.count()} structure
     */
    public T power2(Access1D<N> vector) {
        return this.product(vector, vector);
    }

    /**
     * Kronecker product.
     * <p>
     * {@code result[i1 * rows2 + i2, j1 * cols2 + j2] = matrix1[i1, j1] * matrix2[i2, j2]}.
     * Values are read and multiplied as {@code double}.
     *
     * @param matrix1 the left operand
     * @param matrix2 the right operand
     * @return a {@code (rows1 * rows2) x (cols1 * cols2)} structure
     */
    public T kronecker(Access2D<N> matrix1, Access2D<N> matrix2) {
        long rows1 = matrix1.countRows();
        long cols1 = matrix1.countColumns();
        long rows2 = matrix2.countRows();
        long cols2 = matrix2.countColumns();
        long rows = rows1 * rows2;
        long cols = cols1 * cols2;
        T retVal = this.myFactory.make(rows, cols);
        for (long j1 = 0L; j1 < cols1; j1++) {
            for (long j2 = 0L; j2 < cols2; j2++) {
                long j = j1 * cols2 + j2;
                for (long i1 = 0L; i1 < rows1; i1++) {
                    double val1 = matrix1.doubleValue(i1, j1);
                    for (long i2 = 0L; i2 < rows2; i2++) {
                        long i = i1 * rows2 + i2;
                        double val2 = matrix2.doubleValue(i2, j2);
                        retVal.set(i, j, val1 * val2);
                    }
                }
            }
        }
        return retVal;
    }

    /**
     * Returns the scalar factory of the wrapped factory.
     */
    public Scalar.Factory<N> scalar() {
        return this.myFactory.scalar();
    }

    /**
     * Assembles the given matrices along the main diagonal of a new structure.
     * <p>
     * The result has as many rows as the sum of the inputs' row counts, and as many columns as the
     * sum of their column counts. Entries outside the diagonal blocks are left as the wrapped
     * factory's {@code make} initialises them (presumably zero — confirm). Passing no matrices
     * yields a 0x0 structure.
     *
     * @param matrices the diagonal blocks, in order from top-left to bottom-right
     * @return the block-diagonal structure
     */
    @SafeVarargs // the array is only iterated, never stored or exposed
    public final T blocks(Access2D<N>... matrices) {
        long rows = 0L;
        long cols = 0L;
        for (Access2D<N> matrix : matrices) {
            rows += matrix.countRows();
            cols += matrix.countColumns();
        }
        T retVal = this.myFactory.make(rows, cols);
        long rowOffset = 0L;
        long colOffset = 0L;
        for (Access2D<N> matrix : matrices) {
            long m = matrix.countRows();
            long n = matrix.countColumns();
            // long counters: the bounds are long, so int counters could overflow and never terminate.
            for (long j = 0L; j < n; j++) {
                for (long i = 0L; i < m; i++) {
                    retVal.set(rowOffset + i, colOffset + j, matrix.get(i, j));
                }
            }
            rowOffset += m;
            colOffset += n;
        }
        return retVal;
    }
}

