package org.deeplearning4j.nn.conf.memory;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import lombok.NonNull;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.nd4j.linalg.api.buffer.DataBuffer;

/**
 * Memory use report for a single layer.
 * <p>
 * All sizes are stored as numbers of array elements, not bytes. {@link #getMemoryBytes} converts
 * them to bytes using the element size of the requested data type. "Fixed" quantities do not
 * depend on the minibatch size. "Variable" quantities are per-example values that are multiplied
 * by the minibatch size. Training-time working memory and cache memory are stored per
 * {@link CacheMode}.
 */
public class LayerMemoryReport extends MemoryReport {

    /** Multiplier used by {@link #hashCode()} (same algorithm Lombok generates). */
    private static final int HASH_PRIME = 59;
    /** Hash contribution of a {@code null} field in {@link #hashCode()}. */
    private static final int NULL_HASH = 43;

    private String layerName;
    private Class<?> layerType;
    private InputType inputType;
    private InputType outputType;
    /** Number of parameter elements. */
    private long parameterSize;
    /** Number of updater-state elements. */
    private long updaterStateSize;
    /** Inference working memory (elements) that does not scale with the minibatch size. */
    private long workingMemoryFixedInference;
    /** Inference working memory (elements) per example. */
    private long workingMemoryVariableInference;
    /** Training working memory (elements) independent of minibatch size, keyed by cache mode. */
    private Map<CacheMode, Long> workingMemoryFixedTrain;
    /** Training working memory (elements) per example, keyed by cache mode. */
    private Map<CacheMode, Long> workingMemoryVariableTrain;
    /** Cache memory (elements) independent of minibatch size, keyed by cache mode. */
    Map<CacheMode, Long> cacheModeMemFixed;
    /** Cache memory (elements) per example, keyed by cache mode. */
    Map<CacheMode, Long> cacheModeMemVariablePerEx;

    /**
     * Fluent builder for {@link LayerMemoryReport}. Values that are never set keep Java defaults:
     * {@code 0} for sizes and {@code null} for the per-cache-mode maps.
     */
    public static class Builder {
        private final String layerName;
        private final Class<?> layerType;
        private final InputType inputType;
        private final InputType outputType;
        private long parameterSize;
        private long updaterStateSize;
        private long workingMemoryFixedInference;
        private long workingMemoryVariableInference;
        private Map<CacheMode, Long> workingMemoryFixedTrain;
        private Map<CacheMode, Long> workingMemoryVariableTrain;
        Map<CacheMode, Long> cacheModeMemFixed;
        Map<CacheMode, Long> cacheModeMemVariablePerEx;

        /**
         * @param layerName  name of the layer being reported
         * @param layerType  class of the layer being reported
         * @param inputType  layer input type (used for activation-gradient size)
         * @param outputType layer output type (used for activation size)
         */
        public Builder(String layerName, Class<?> layerType, InputType inputType, InputType outputType) {
            this.layerName = layerName;
            this.layerType = layerType;
            this.inputType = inputType;
            this.outputType = outputType;
        }

        /**
         * Sets the parameter and updater-state sizes, in elements.
         */
        public Builder standardMemory(long parameterSize, long updaterStateSize) {
            this.parameterSize = parameterSize;
            this.updaterStateSize = updaterStateSize;
            return this;
        }

        /**
         * Sets working memory, using the same training values for every cache mode
         * (via {@link MemoryReport#cacheModeMapFor(long)}).
         */
        public Builder workingMemory(long fixedInference, long variableInference, long fixedTrain,
                        long variableTrain) {
            return workingMemory(fixedInference, variableInference, MemoryReport.cacheModeMapFor(fixedTrain),
                            MemoryReport.cacheModeMapFor(variableTrain));
        }

        /**
         * Sets working memory, in elements. "Variable" values are per example.
         */
        public Builder workingMemory(long fixedInference, long variableInference,
                        Map<CacheMode, Long> fixedTrain, Map<CacheMode, Long> variableTrain) {
            this.workingMemoryFixedInference = fixedInference;
            this.workingMemoryVariableInference = variableInference;
            this.workingMemoryFixedTrain = fixedTrain;
            this.workingMemoryVariableTrain = variableTrain;
            return this;
        }

        /**
         * Sets cache memory, using the same values for every cache mode
         * (via {@link MemoryReport#cacheModeMapFor(long)}).
         */
        public Builder cacheMemory(long fixed, long variablePerEx) {
            return cacheMemory(MemoryReport.cacheModeMapFor(fixed), MemoryReport.cacheModeMapFor(variablePerEx));
        }

        /**
         * Sets cache memory, in elements, keyed by cache mode. The variable map is per example.
         */
        public Builder cacheMemory(Map<CacheMode, Long> fixed, Map<CacheMode, Long> variablePerEx) {
            this.cacheModeMemFixed = fixed;
            this.cacheModeMemVariablePerEx = variablePerEx;
            return this;
        }

        /** Creates the report from the values collected so far. */
        public LayerMemoryReport build() {
            return new LayerMemoryReport(this);
        }
    }

    protected LayerMemoryReport(Builder builder) {
        this.layerName = builder.layerName;
        this.layerType = builder.layerType;
        this.inputType = builder.inputType;
        this.outputType = builder.outputType;
        this.parameterSize = builder.parameterSize;
        this.updaterStateSize = builder.updaterStateSize;
        this.workingMemoryFixedInference = builder.workingMemoryFixedInference;
        this.workingMemoryVariableInference = builder.workingMemoryVariableInference;
        this.workingMemoryFixedTrain = builder.workingMemoryFixedTrain;
        this.workingMemoryVariableTrain = builder.workingMemoryVariableTrain;
        this.cacheModeMemFixed = builder.cacheModeMemFixed;
        this.cacheModeMemVariablePerEx = builder.cacheModeMemVariablePerEx;
    }

    @Override
    public Class<?> getReportClass() {
        return layerType;
    }

    @Override
    public String getName() {
        return layerName;
    }

    /**
     * Sums {@link #getMemoryBytes} over every {@link MemoryType}.
     *
     * @throws NullPointerException if {@code memoryUseMode}, {@code cacheMode} or {@code dataType} is null
     */
    @Override
    public long getTotalMemoryBytes(int minibatchSize, @NonNull MemoryUseMode memoryUseMode,
                    @NonNull CacheMode cacheMode, @NonNull DataBuffer.Type dataType) {
        if (memoryUseMode == null) {
            throw new NullPointerException("memoryUseMode is marked @NonNull but is null");
        }
        if (cacheMode == null) {
            throw new NullPointerException("cacheMode is marked @NonNull but is null");
        }
        if (dataType == null) {
            throw new NullPointerException("dataType is marked @NonNull but is null");
        }
        long total = 0;
        for (MemoryType memoryType : MemoryType.values()) {
            total += getMemoryBytes(memoryType, minibatchSize, memoryUseMode, cacheMode, dataType);
        }
        return total;
    }

    /**
     * Returns the number of bytes used for one memory type.
     * <p>
     * In {@code INFERENCE} mode, parameter gradients, activation gradients, updater state and cached
     * memory are all zero. Per-example ("variable") quantities and activations are multiplied by
     * {@code minibatchSize}. Training working memory and cache memory are looked up by
     * {@code cacheMode}.
     *
     * @throws IllegalStateException if {@code memoryType} is not handled
     */
    @Override
    public long getMemoryBytes(MemoryType memoryType, int minibatchSize, MemoryUseMode memoryUseMode,
                    CacheMode cacheMode, DataBuffer.Type dataType) {
        int bytesPerElement = getBytesPerElement(dataType);
        boolean inference = memoryUseMode == MemoryUseMode.INFERENCE;
        // (long) cast keeps the minibatch multiplication in 64-bit arithmetic
        long mb = (long) minibatchSize;
        switch (memoryType) {
            case PARAMETERS:
                return parameterSize * bytesPerElement;
            case PARAMATER_GRADIENTS:
                return inference ? 0L : parameterSize * bytesPerElement;
            case ACTIVATIONS:
                return mb * outputType.arrayElementsPerExample() * bytesPerElement;
            case ACTIVATION_GRADIENTS:
                return inference ? 0L : mb * inputType.arrayElementsPerExample() * bytesPerElement;
            case UPDATER_STATE:
                return inference ? 0L : updaterStateSize * bytesPerElement;
            case WORKING_MEMORY_FIXED:
                if (inference) {
                    return workingMemoryFixedInference * bytesPerElement;
                }
                return workingMemoryFixedTrain.get(cacheMode).longValue() * bytesPerElement;
            case WORKING_MEMORY_VARIABLE:
                if (inference) {
                    // Per-example value: scales with minibatch size, like the training branch
                    return mb * workingMemoryVariableInference * bytesPerElement;
                }
                return mb * workingMemoryVariableTrain.get(cacheMode).longValue() * bytesPerElement;
            case CACHED_MEMORY_FIXED:
                if (inference) {
                    return 0L;
                }
                return cacheModeMemFixed.get(cacheMode).longValue() * bytesPerElement;
            case CACHED_MEMORY_VARIABLE:
                if (inference) {
                    return 0L;
                }
                return mb * cacheModeMemVariablePerEx.get(cacheMode).longValue() * bytesPerElement;
            default:
                throw new IllegalStateException("Unknown memory type: " + memoryType);
        }
    }

    @Override
    public String toString() {
        // layerType is null for reports built with the no-arg constructor
        String typeName = layerType == null ? null : layerType.getSimpleName();
        return "LayerMemoryReport(layerName=" + layerName + ",layerType=" + typeName + ")";
    }

    /**
     * Multiplies every stored size by {@code scale}, in place. This covers parameters, updater
     * state, inference and training working memory, and cache memory.
     *
     * @param scale multiplier applied to every size
     */
    public void scale(int scale) {
        parameterSize *= scale;
        updaterStateSize *= scale;
        workingMemoryFixedInference *= scale;
        workingMemoryVariableInference *= scale;
        // Training working memory must be scaled too, otherwise scaled reports under-count it
        workingMemoryFixedTrain = scaleEntries(workingMemoryFixedTrain, scale);
        workingMemoryVariableTrain = scaleEntries(workingMemoryVariableTrain, scale);
        cacheModeMemFixed = scaleEntries(cacheModeMemFixed, scale);
        cacheModeMemVariablePerEx = scaleEntries(cacheModeMemVariablePerEx, scale);
    }

    /**
     * Returns a new map with every value multiplied by {@code scale}, or {@code null} if
     * {@code map} is null. The input map is not modified.
     */
    private static Map<CacheMode, Long> scaleEntries(Map<CacheMode, Long> map, int scale) {
        if (map == null) {
            return null;
        }
        Map<CacheMode, Long> scaled = new HashMap<>();
        for (Map.Entry<CacheMode, Long> entry : map.entrySet()) {
            scaled.put(entry.getKey(), Long.valueOf(scale * entry.getValue().longValue()));
        }
        return scaled;
    }

    public String getLayerName() {
        return layerName;
    }

    public Class<?> getLayerType() {
        return layerType;
    }

    public InputType getInputType() {
        return inputType;
    }

    public InputType getOutputType() {
        return outputType;
    }

    public long getParameterSize() {
        return parameterSize;
    }

    public long getUpdaterStateSize() {
        return updaterStateSize;
    }

    public long getWorkingMemoryFixedInference() {
        return workingMemoryFixedInference;
    }

    public long getWorkingMemoryVariableInference() {
        return workingMemoryVariableInference;
    }

    public Map<CacheMode, Long> getWorkingMemoryFixedTrain() {
        return workingMemoryFixedTrain;
    }

    public Map<CacheMode, Long> getWorkingMemoryVariableTrain() {
        return workingMemoryVariableTrain;
    }

    public Map<CacheMode, Long> getCacheModeMemFixed() {
        return cacheModeMemFixed;
    }

    public Map<CacheMode, Long> getCacheModeMemVariablePerEx() {
        return cacheModeMemVariablePerEx;
    }

    public void setLayerName(String layerName) {
        this.layerName = layerName;
    }

    public void setLayerType(Class<?> layerType) {
        this.layerType = layerType;
    }

    public void setInputType(InputType inputType) {
        this.inputType = inputType;
    }

    public void setOutputType(InputType outputType) {
        this.outputType = outputType;
    }

    public void setParameterSize(long parameterSize) {
        this.parameterSize = parameterSize;
    }

    public void setUpdaterStateSize(long updaterStateSize) {
        this.updaterStateSize = updaterStateSize;
    }

    public void setWorkingMemoryFixedInference(long workingMemoryFixedInference) {
        this.workingMemoryFixedInference = workingMemoryFixedInference;
    }

    public void setWorkingMemoryVariableInference(long workingMemoryVariableInference) {
        this.workingMemoryVariableInference = workingMemoryVariableInference;
    }

    public void setWorkingMemoryFixedTrain(Map<CacheMode, Long> workingMemoryFixedTrain) {
        this.workingMemoryFixedTrain = workingMemoryFixedTrain;
    }

    public void setWorkingMemoryVariableTrain(Map<CacheMode, Long> workingMemoryVariableTrain) {
        this.workingMemoryVariableTrain = workingMemoryVariableTrain;
    }

    public void setCacheModeMemFixed(Map<CacheMode, Long> cacheModeMemFixed) {
        this.cacheModeMemFixed = cacheModeMemFixed;
    }

    public void setCacheModeMemVariablePerEx(Map<CacheMode, Long> cacheModeMemVariablePerEx) {
        this.cacheModeMemVariablePerEx = cacheModeMemVariablePerEx;
    }

    /**
     * Field-by-field equality over all report fields. The {@link #canEqual} check keeps the
     * comparison symmetric for subclasses.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof LayerMemoryReport)) {
            return false;
        }
        LayerMemoryReport other = (LayerMemoryReport) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        return Objects.equals(getLayerName(), other.getLayerName())
                        && Objects.equals(getLayerType(), other.getLayerType())
                        && Objects.equals(getInputType(), other.getInputType())
                        && Objects.equals(getOutputType(), other.getOutputType())
                        && getParameterSize() == other.getParameterSize()
                        && getUpdaterStateSize() == other.getUpdaterStateSize()
                        && getWorkingMemoryFixedInference() == other.getWorkingMemoryFixedInference()
                        && getWorkingMemoryVariableInference() == other.getWorkingMemoryVariableInference()
                        && Objects.equals(getWorkingMemoryFixedTrain(), other.getWorkingMemoryFixedTrain())
                        && Objects.equals(getWorkingMemoryVariableTrain(), other.getWorkingMemoryVariableTrain())
                        && Objects.equals(getCacheModeMemFixed(), other.getCacheModeMemFixed())
                        && Objects.equals(getCacheModeMemVariablePerEx(), other.getCacheModeMemVariablePerEx());
    }

    @Override
    protected boolean canEqual(Object obj) {
        return obj instanceof LayerMemoryReport;
    }

    /**
     * Hash over the same fields as {@link #equals}. It uses the original Lombok-style algorithm,
     * so hash values are unchanged.
     */
    @Override
    public int hashCode() {
        int result = 1;
        result = result * HASH_PRIME + hashOf(getLayerName());
        result = result * HASH_PRIME + hashOf(getLayerType());
        result = result * HASH_PRIME + hashOf(getInputType());
        result = result * HASH_PRIME + hashOf(getOutputType());
        result = result * HASH_PRIME + hashOf(getParameterSize());
        result = result * HASH_PRIME + hashOf(getUpdaterStateSize());
        result = result * HASH_PRIME + hashOf(getWorkingMemoryFixedInference());
        result = result * HASH_PRIME + hashOf(getWorkingMemoryVariableInference());
        result = result * HASH_PRIME + hashOf(getWorkingMemoryFixedTrain());
        result = result * HASH_PRIME + hashOf(getWorkingMemoryVariableTrain());
        result = result * HASH_PRIME + hashOf(getCacheModeMemFixed());
        return result * HASH_PRIME + hashOf(getCacheModeMemVariablePerEx());
    }

    /** Null-safe object hash; {@code null} contributes {@link #NULL_HASH}. */
    private static int hashOf(Object value) {
        return value == null ? NULL_HASH : value.hashCode();
    }

    /** Folds a long into an int by XOR-ing its high and low 32 bits. */
    private static int hashOf(long value) {
        return (int) ((value >>> 32) ^ value);
    }

    public LayerMemoryReport(String layerName, Class<?> layerType, InputType inputType, InputType outputType,
                    long parameterSize, long updaterStateSize, long workingMemoryFixedInference,
                    long workingMemoryVariableInference, Map<CacheMode, Long> workingMemoryFixedTrain,
                    Map<CacheMode, Long> workingMemoryVariableTrain, Map<CacheMode, Long> cacheModeMemFixed,
                    Map<CacheMode, Long> cacheModeMemVariablePerEx) {
        this.layerName = layerName;
        this.layerType = layerType;
        this.inputType = inputType;
        this.outputType = outputType;
        this.parameterSize = parameterSize;
        this.updaterStateSize = updaterStateSize;
        this.workingMemoryFixedInference = workingMemoryFixedInference;
        this.workingMemoryVariableInference = workingMemoryVariableInference;
        this.workingMemoryFixedTrain = workingMemoryFixedTrain;
        this.workingMemoryVariableTrain = workingMemoryVariableTrain;
        this.cacheModeMemFixed = cacheModeMemFixed;
        this.cacheModeMemVariablePerEx = cacheModeMemVariablePerEx;
    }

    /** Creates an empty report; all fields keep their defaults until set. */
    public LayerMemoryReport() {
    }
}
