package org.opensearch.ml.indices;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.client.Client;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataframe.DataFrameBuilder;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;

/* loaded from: input_file:org/opensearch/ml/indices/MLInputDatasetHandler.class */
public class MLInputDatasetHandler {

    @Generated
    private static final Logger log = LogManager.getLogger(MLInputDatasetHandler.class);
    private final Client client;

    public DataFrame parseDataFrameInput(MLInputDataset mLInputDataset) {
        if (mLInputDataset.getInputDataType().equals(MLInputDataType.DATA_FRAME)) {
            return ((DataFrameInputDataset) mLInputDataset).getDataFrame();
        }
        throw new IllegalArgumentException("Input dataset is not DATA_FRAME type.");
    }

    public void parseSearchQueryInput(MLInputDataset mLInputDataset, ActionListener<DataFrame> actionListener) {
        if (!mLInputDataset.getInputDataType().equals(MLInputDataType.SEARCH_QUERY)) {
            throw new IllegalArgumentException("Input dataset is not SEARCH_QUERY type.");
        }
        SearchQueryInputDataset searchQueryInputDataset = (SearchQueryInputDataset) mLInputDataset;
        SearchRequest searchRequest = new SearchRequest();
        searchRequest.source(searchQueryInputDataset.getSearchSourceBuilder());
        List indices = searchQueryInputDataset.getIndices();
        searchRequest.indices((String[]) indices.toArray(new String[indices.size()]));
        this.client.search(searchRequest, ActionListener.wrap(searchResponse -> {
            if (searchResponse == null || searchResponse.getHits() == null || searchResponse.getHits().getTotalHits() == null || searchResponse.getHits().getTotalHits().value == 0) {
                actionListener.onFailure(new IllegalArgumentException("No document found"));
                return;
            }
            SearchHits hits = searchResponse.getHits();
            ArrayList arrayList = new ArrayList();
            for (SearchHit searchHit : hits.getHits()) {
                arrayList.add(searchHit.getSourceAsMap());
            }
            actionListener.onResponse(DataFrameBuilder.load(arrayList));
        }, exc -> {
            log.error("Failed to search" + exc);
            actionListener.onFailure(exc);
        }));
    }

    @Generated
    public MLInputDatasetHandler(Client client) {
        this.client = client;
    }
}
