package boofcv.deepboof;

import boofcv.abst.scene.ImageClassifier;
import boofcv.struct.image.GrayF32;
import boofcv.struct.image.ImageType;
import boofcv.struct.image.Planar;
import deepboof.Function;
import deepboof.graph.FunctionSequence;
import deepboof.tensors.Tensor_F32;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import org.ddogleg.struct.FastQueue;

/* loaded from: input_file:boofcv/deepboof/BaseImageClassifier.class */
/**
 * Base class for image classifiers that run a DeepBoof {@link FunctionSequence} on a 3-band
 * {@link GrayF32} {@link Planar} image of fixed square size.
 *
 * <p>The input image is copied (when already {@code imageSize x imageSize}) or passed through
 * {@link ClipAndReduce} into {@link #imageRgb}. It is then converted into {@link #tensorInput}
 * and fed to {@link #network}. After each call to {@link #classify}, the per-category scores
 * are available sorted from highest to lowest score.
 *
 * <p>{@link #network}, {@link #tensorOutput} and the contents of {@link #categories} are not
 * assigned by this base class; subclasses are expected to set them up before {@link #classify}
 * is called.
 */
public abstract class BaseImageClassifier implements ImageClassifier<Planar<GrayF32>> {
    /** Network applied to {@link #tensorInput}. Not assigned by this class. */
    protected FunctionSequence<Tensor_F32, Function<Tensor_F32>> network;

    /** Width and height, in pixels, of the square image fed to the network. */
    protected int imageSize;

    /** Working 3-band image of size {@code imageSize x imageSize} that the input is written into. */
    protected Planar<GrayF32> imageRgb;

    /** Network input tensor, allocated with shape (1, 3, imageSize, imageSize). */
    protected Tensor_F32 tensorInput;

    /** Network output tensor; category scores are read from index (0, i). Not assigned by this class. */
    protected Tensor_F32 tensorOutput;

    /** Index of the highest-scoring category from the last classification, or -1 if none. */
    protected int categoryBest;

    /** Category names exposed through {@link #getCategories()}. */
    protected List<String> categories = new ArrayList<>();

    /** Input image type: a 3-band planar GrayF32 image. */
    protected ImageType<Planar<GrayF32>> imageType = ImageType.pl(3, GrayF32.class);

    // Must be initialized after imageType, which it depends on.
    protected ClipAndReduce<Planar<GrayF32>> massage = new ClipAndReduce<>(true, this.imageType);

    /** Per-category scores from the last classification; sorted in descending score order. */
    protected FastQueue<ImageClassifier.Score> categoryScores = new FastQueue<>(ImageClassifier.Score.class, true);

    /**
     * Orders scores from highest to lowest. Explicit {@code <}/{@code >} comparisons are used
     * instead of {@code Double.compare} so that NaN and +/-0.0 keep comparing as equal.
     */
    Comparator<ImageClassifier.Score> comparator = new Comparator<ImageClassifier.Score>() {
        @Override
        public int compare(ImageClassifier.Score a, ImageClassifier.Score b) {
            if (a.score < b.score) {
                return 1;
            }
            return a.score > b.score ? -1 : 0;
        }
    };

    /**
     * Allocates the working image and the input tensor for the given square input size.
     *
     * @param i width and height, in pixels, of the image the network consumes
     */
    public BaseImageClassifier(int i) {
        this.imageSize = i;
        this.imageRgb = new Planar<>(GrayF32.class, i, i, 3);
        this.tensorInput = new Tensor_F32(1, 3, i, i);
    }

    /** Returns the image type accepted by {@link #classify}: 3-band planar GrayF32. */
    @Override
    public ImageType<Planar<GrayF32>> getInputType() {
        return this.imageType;
    }

    /**
     * Classifies the given image.
     *
     * <p>The image is preprocessed into {@link #imageRgb}, written into {@link #tensorInput},
     * and evaluated by the network. Results are then available through {@link #getBestResult()}
     * and {@link #getAllResults()}.
     *
     * @param planar input image; must be at least {@code imageSize} in both dimensions
     * @throws IllegalArgumentException if the image is smaller than {@code imageSize} in either dimension
     */
    @Override
    public void classify(Planar<GrayF32> planar) {
        DataManipulationOps.imageToTensor(preprocess(planar), this.tensorInput, 0);
        innerProcess(this.tensorInput);
    }

    /**
     * Produces the {@code imageSize x imageSize} image fed to the network.
     *
     * <p>An input that already has the target size is copied directly. Larger inputs are passed
     * through {@link ClipAndReduce}.
     *
     * @param planar input image
     * @return {@link #imageRgb}, which is overwritten on every call
     * @throws IllegalArgumentException if the width or height is smaller than {@code imageSize}
     */
    public Planar<GrayF32> preprocess(Planar<GrayF32> planar) {
        if (planar.width == this.imageSize && planar.height == this.imageSize) {
            this.imageRgb.setTo(planar);
        } else {
            if (planar.width < this.imageSize || planar.height < this.imageSize) {
                throw new IllegalArgumentException("Image width or height is too small");
            }
            this.massage.massage(planar, this.imageRgb);
        }
        return this.imageRgb;
    }

    /**
     * Runs the network on the given tensor and records the per-category scores.
     *
     * <p>Reads {@code tensorOutput.get(0, i)} for every {@code i} along dimension 1.
     * {@link #categoryBest} is set to the first index with the largest score. It stays -1 when
     * the output is empty or no value is greater than {@code -Double.MAX_VALUE} (e.g. all NaN).
     * {@link #categoryScores} is then sorted from highest to lowest score.
     *
     * @param input tensor passed to the network
     */
    protected void innerProcess(Tensor_F32 input) {
        this.network.process(input, this.tensorOutput);
        this.categoryScores.reset();

        double bestScore = -Double.MAX_VALUE;
        this.categoryBest = -1;
        // The output tensor is not modified inside the loop, so its length is read once.
        final int numCategories = this.tensorOutput.length(1);
        for (int i = 0; i < numCategories; i++) {
            double score = this.tensorOutput.get(0, i);
            this.categoryScores.grow().set(score, i);
            if (score > bestScore) {
                bestScore = score;
                this.categoryBest = i;
            }
        }
        Collections.sort(this.categoryScores.toList(), this.comparator);
    }

    /** Returns the index of the highest-scoring category from the last classification, or -1. */
    @Override
    public int getBestResult() {
        return this.categoryBest;
    }

    /**
     * Returns the scores from the last classification, sorted from highest to lowest.
     * The list is derived from internal storage that is reset by the next classification.
     */
    @Override
    public List<ImageClassifier.Score> getAllResults() {
        return this.categoryScores.toList();
    }

    /** Returns the internal (mutable) list of category names. */
    @Override
    public List<String> getCategories() {
        return this.categories;
    }

    /** Returns the working image most recently fed to the network. */
    public Planar<GrayF32> getImageRgb() {
        return this.imageRgb;
    }
}
