/*
 * Decompiled with CFR 0.152.
 */
package org.apache.solr.llm.texttovector.model;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.embedding.EmbeddingModel;
import java.lang.reflect.Method;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.llm.texttovector.store.TextToVectorModelException;

public class SolrTextToVectorModel
implements Accountable {
    private static final long BASE_RAM_BYTES = RamUsageEstimator.shallowSizeOfInstance(SolrTextToVectorModel.class);
    private static final String TIMEOUT_PARAM = "timeout";
    private static final String MAX_SEGMENTS_PER_BATCH_PARAM = "maxSegmentsPerBatch";
    private static final String MAX_RETRIES_PARAM = "maxRetries";
    private final String name;
    private final Map<String, Object> params;
    private final EmbeddingModel textToVector;
    private final int hashCode;

    public static SolrTextToVectorModel getInstance(SolrResourceLoader solrResourceLoader, String className, String name, Map<String, Object> params) throws TextToVectorModelException {
        try {
            Class modelClass = solrResourceLoader.findClass(className, EmbeddingModel.class);
            Object builder = modelClass.getMethod("builder", new Class[0]).invoke(null, new Object[0]);
            if (params != null) {
                Iterator<String> iterator = params.keySet().iterator();
                block12: while (iterator.hasNext()) {
                    String paramName;
                    switch (paramName = iterator.next()) {
                        case "timeout": {
                            Duration timeOut = Duration.ofSeconds((Long)params.get(paramName));
                            builder.getClass().getMethod(paramName, Duration.class).invoke(builder, timeOut);
                            continue block12;
                        }
                        case "maxSegmentsPerBatch": {
                            builder.getClass().getMethod(paramName, Integer.class).invoke(builder, ((Long)params.get(paramName)).intValue());
                            continue block12;
                        }
                        case "maxRetries": {
                            builder.getClass().getMethod(paramName, Integer.class).invoke(builder, ((Long)params.get(paramName)).intValue());
                            continue block12;
                        }
                    }
                    ArrayList<Method> paramNameMatches = new ArrayList<Method>();
                    for (Method method : builder.getClass().getMethods()) {
                        if (!paramName.equals(method.getName()) || method.getParameterCount() != 1) continue;
                        paramNameMatches.add(method);
                    }
                    if (paramNameMatches.size() == 1) {
                        ((Method)paramNameMatches.get(0)).invoke(builder, params.get(paramName));
                        continue;
                    }
                    builder.getClass().getMethod(paramName, String.class).invoke(builder, params.get(paramName).toString());
                }
            }
            EmbeddingModel textToVector = (EmbeddingModel)builder.getClass().getMethod("build", new Class[0]).invoke(builder, new Object[0]);
            return new SolrTextToVectorModel(name, textToVector, params);
        }
        catch (Exception e) {
            throw new TextToVectorModelException("Model loading failed for " + className, e);
        }
    }

    public SolrTextToVectorModel(String name, EmbeddingModel textToVector, Map<String, Object> params) {
        this.name = name;
        this.textToVector = textToVector;
        this.params = params;
        this.hashCode = this.calculateHashCode();
    }

    public float[] vectorise(String text) {
        Embedding vector = (Embedding)this.textToVector.embed(text).content();
        return vector.vector();
    }

    public String toString() {
        return this.getClass().getSimpleName() + "(name=" + this.getName() + ")";
    }

    public long ramBytesUsed() {
        return BASE_RAM_BYTES + RamUsageEstimator.sizeOfObject((Object)this.name) + RamUsageEstimator.sizeOfObject((Object)this.textToVector);
    }

    public int hashCode() {
        return this.hashCode;
    }

    private int calculateHashCode() {
        int prime = 31;
        int result = 1;
        result = 31 * result + Objects.hashCode(this.name);
        result = 31 * result + Objects.hashCode(this.textToVector);
        return result;
    }

    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof SolrTextToVectorModel)) {
            return false;
        }
        SolrTextToVectorModel other = (SolrTextToVectorModel)obj;
        return Objects.equals(this.textToVector, other.textToVector) && Objects.equals(this.name, other.name);
    }

    public String getName() {
        return this.name;
    }

    public String getEmbeddingModelClassName() {
        return this.textToVector.getClass().getName();
    }

    public Map<String, Object> getParams() {
        return this.params;
    }
}

