# SPDX-License-Identifier: Apache-2.0
#
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
#
# Modifications Copyright OpenSearch Contributors. See
# GitHub history for details.
#
#  Licensed to Elasticsearch B.V. under one or more contributor
#  license agreements. See the NOTICE file distributed with
#  this work for additional information regarding copyright
#  ownership. Elasticsearch B.V. licenses this file to you under
#  the Apache License, Version 2.0 (the "License"); you may
#  not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
# 	http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing,
#  software distributed under the License is distributed on an
#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#  KIND, either express or implied.  See the License for the
#  specific language governing permissions and limitations
#  under the License.


import collections.abc as collections_abc
from typing import Any, Optional

from .response.aggs import AggResponse, BucketData, FieldBucketData, TopHitsData
from .utils import DslBase


def A(  # pylint: disable=invalid-name
    name_or_agg: Any, filter: Any = None, **params: Any
) -> Any:
    """
    Build an aggregation object from one of several input forms.

    Accepted forms:

    * ``A("terms", field="tags")`` -- an aggregation type name plus keyword
      parameters;
    * ``A({"terms": {"field": "tags"}, "aggs": {...}, "meta": {...}})`` -- a
      raw aggregation mapping, optionally carrying nested ``aggs`` and
      ``meta``; no extra keyword parameters are allowed;
    * ``A(Terms(field="tags"))`` -- an existing :class:`Agg`, returned as-is;
      no extra keyword parameters are allowed.

    :arg name_or_agg: aggregation type name, aggregation mapping, or
        :class:`Agg` instance.
    :arg filter: positional query argument; only accepted when
        ``name_or_agg`` is ``"filter"``.
    :arg params: keyword parameters for the aggregation.
    :raises ValueError: if ``filter`` is given for anything other than
        ``"filter"``, if parameters are combined with a mapping or an
        :class:`Agg`, or if a mapping does not contain exactly one
        aggregation type besides ``aggs``/``meta``.
    """
    if filter is not None:
        if name_or_agg != "filter":
            # wrap in a 1-tuple so a tuple-valued argument still yields this
            # ValueError instead of a %-formatting TypeError
            raise ValueError(
                "Aggregation %r doesn't accept positional argument 'filter'."
                % (name_or_agg,)
            )
        params["filter"] = filter

    # {"terms": {"field": "tags"}, "aggs": {...}}
    if isinstance(name_or_agg, collections_abc.Mapping):
        if params:
            raise ValueError("A() cannot accept parameters when passing in a dict.")
        # copy to avoid modifying in-place; dict() works for any Mapping,
        # whereas .copy() is not part of the Mapping interface
        agg = dict(name_or_agg)
        # pop out nested aggs and meta data
        aggs = agg.pop("aggs", None)
        meta = agg.pop("meta", None)
        # what is left should be {"terms": {"field": "tags"}}
        if len(agg) != 1:
            raise ValueError(
                'A() can only accept dict with an aggregation ({"terms": {...}}). '
                "Instead it got (%r)" % (name_or_agg,)
            )
        agg_type, agg_params = agg.popitem()
        if aggs or meta:
            # copy once so the caller's nested dict is never mutated
            agg_params = dict(agg_params)
            if aggs:
                agg_params["aggs"] = aggs
            if meta:
                agg_params["meta"] = meta
        return Agg.get_dsl_class(agg_type)(_expand__to_dot=False, **agg_params)

    # Terms(...) just return the nested agg
    elif isinstance(name_or_agg, Agg):
        if params:
            raise ValueError(
                "A() cannot accept parameters when passing in an Agg object."
            )
        return name_or_agg

    # "terms", field="tags"
    return Agg.get_dsl_class(name_or_agg)(**params)


class Agg(DslBase):
    """
    Base class for all aggregation types.

    Subclasses set ``name`` to the aggregation type key used when the object
    is serialized; ``A`` is exposed as ``_type_shortcut`` for building
    instances from names, dicts or existing aggregations.
    """

    _type_name: str = "agg"
    _type_shortcut = staticmethod(A)
    name: Optional[str] = None

    def __contains__(self, key: Any) -> bool:
        # `in` is always False at this level; AggBase overrides it for
        # aggregations that hold sub-aggregations
        return False

    def to_dict(self) -> Any:
        """Serialize, moving ``meta`` out of the aggregation body to the top level."""
        serialized = super().to_dict()
        body = serialized[self.name]
        if "meta" in body:
            serialized["meta"] = body.pop("meta")
        return serialized

    def result(self, search: Any, data: Any) -> Any:
        """Wrap the raw response ``data`` for this aggregation in ``AggResponse``."""
        return AggResponse(self, search, data)


class AggBase(object):
    """
    Mixin for aggregations that can contain named sub-aggregations.

    Sub-aggregations are kept in ``self._params["aggs"]``. The mixin offers
    dict-like access (``in``, ``[]``, ``[]=``, iteration) plus the fluent
    ``metric`` / ``bucket`` / ``pipeline`` builders.
    """

    _param_defs = {
        "aggs": {"type": "agg", "hash": True},
    }

    def __contains__(self: Any, key: Any) -> bool:
        return key in self._params.get("aggs", {})

    def __getitem__(self: Any, agg_name: Any) -> Any:
        """
        Return the sub-aggregation called ``agg_name``.

        Raises ``KeyError`` when it does not exist. A Bucket sub-aggregation
        is replaced by a shallow copy that is stored back in place, so edits
        through the returned object never touch state shared elsewhere.
        """
        children = self._params.setdefault("aggs", {})
        agg = children[agg_name]  # missing names raise KeyError on purpose
        if not isinstance(agg, Bucket):
            return agg
        copied = A(agg.name, **agg._params)
        # keep the copy so later modifications through it are reflected here
        children[agg_name] = copied
        return copied

    def __setitem__(self: Any, agg_name: str, agg: Any) -> None:
        # normalize the value (name, dict or Agg) through A()
        self.aggs[agg_name] = A(agg)

    def __iter__(self: Any) -> Any:
        return iter(self.aggs)

    def _agg(
        self: Any, bucket: Any, name: Any, agg_type: Any, *args: Any, **params: Any
    ) -> Any:
        """
        Create a sub-aggregation with ``A`` and register it under ``name``.

        Returns the new aggregation when ``bucket`` is truthy (so further calls
        nest inside it), otherwise ``self._base`` so chaining stays at the
        current level.
        """
        new_agg = A(agg_type, *args, **params)
        self[name] = new_agg
        return new_agg if bucket else self._base

    def metric(self: Any, name: Any, agg_type: Any, *args: Any, **params: Any) -> Any:
        """Add a metric sub-aggregation; returns ``self._base`` for chaining."""
        return self._agg(False, name, agg_type, *args, **params)

    def bucket(self: Any, name: Any, agg_type: Any, *args: Any, **params: Any) -> Any:
        """Add a bucket sub-aggregation; returns the new bucket for nesting."""
        return self._agg(True, name, agg_type, *args, **params)

    def pipeline(self: Any, name: Any, agg_type: Any, *args: Any, **params: Any) -> Any:
        """Add a pipeline sub-aggregation; returns ``self._base`` for chaining."""
        return self._agg(False, name, agg_type, *args, **params)

    def result(self: Any, search: Any, data: Any) -> Any:
        """Wrap the raw response ``data`` in ``BucketData``."""
        return BucketData(self, search, data)


class Bucket(AggBase, Agg):
    """
    Aggregation that can own nested sub-aggregations.

    When serialized, ``aggs`` is placed next to the aggregation body instead
    of inside it.
    """

    def __init__(self, **params: Any) -> None:
        super().__init__(**params)
        # chaining anchor: AggBase._agg returns self._base for non-bucket calls
        self._base = self

    def to_dict(self) -> Any:
        # start the MRO lookup after AggBase, i.e. at Agg.to_dict
        serialized = super(AggBase, self).to_dict()
        body = serialized[self.name]
        if "aggs" in body:
            serialized["aggs"] = body.pop("aggs")
        return serialized


class Filter(Bucket):
    """The ``filter`` bucket aggregation, whose body is a single query."""

    name: Optional[str] = "filter"
    _param_defs = {
        "filter": {"type": "query"},
        "aggs": {"type": "agg", "hash": True},
    }

    def __init__(self, filter: Any = None, **params: Any) -> None:
        # allow the query to be passed positionally as well as by keyword
        if filter is not None:
            params["filter"] = filter
        super().__init__(**params)

    def to_dict(self) -> Any:
        """Serialize, inlining the serialized filter query as the aggregation body."""
        serialized = super().to_dict()
        body = serialized[self.name]
        body.update(body.pop("filter", {}))
        return serialized


class Pipeline(Agg):
    """Common base class for pipeline aggregations."""


# bucket aggregations
class Filters(Bucket):
    """The ``filters`` bucket aggregation: a hash of named queries."""

    name: str = "filters"
    _param_defs = {
        "filters": {"type": "query", "hash": True},
        "aggs": {"type": "agg", "hash": True},
    }


class Children(Bucket):
    """The ``children`` bucket aggregation."""

    name = "children"


class Parent(Bucket):
    """The ``parent`` bucket aggregation."""

    name = "parent"


class DateHistogram(Bucket):
    """The ``date_histogram`` bucket aggregation."""

    name = "date_histogram"

    def result(self, search: Any, data: Any) -> Any:
        """Wrap the raw response ``data`` in ``FieldBucketData``."""
        wrapped = FieldBucketData(self, search, data)
        return wrapped


class AutoDateHistogram(DateHistogram):
    """The ``auto_date_histogram`` aggregation; reuses DateHistogram's result wrapping."""

    name = "auto_date_histogram"


class DateRange(Bucket):
    """The ``date_range`` bucket aggregation."""

    name = "date_range"


class GeoDistance(Bucket):
    """The ``geo_distance`` bucket aggregation."""

    name = "geo_distance"


class GeohashGrid(Bucket):
    """The ``geohash_grid`` bucket aggregation."""

    name = "geohash_grid"


class GeotileGrid(Bucket):
    """The ``geotile_grid`` bucket aggregation."""

    name = "geotile_grid"


class GeoCentroid(Bucket):
    """The ``geo_centroid`` aggregation."""

    # NOTE(review): declared as a Bucket although geo_centroid is usually
    # described as a metric aggregation -- confirm before changing the base.
    name = "geo_centroid"


class Global(Bucket):
    """The ``global`` bucket aggregation."""

    name = "global"


class Histogram(Bucket):
    """The ``histogram`` bucket aggregation."""

    name = "histogram"

    def result(self, search: Any, data: Any) -> Any:
        """Wrap the raw response ``data`` in ``FieldBucketData``."""
        wrapped = FieldBucketData(self, search, data)
        return wrapped


class IPRange(Bucket):
    """The ``ip_range`` bucket aggregation."""

    name = "ip_range"


class Missing(Bucket):
    """The ``missing`` bucket aggregation."""

    name = "missing"


class Nested(Bucket):
    """The ``nested`` bucket aggregation."""

    name = "nested"


class Range(Bucket):
    """The ``range`` bucket aggregation."""

    name = "range"


class RareTerms(Bucket):
    """The ``rare_terms`` bucket aggregation."""

    name = "rare_terms"

    def result(self, search: Any, data: Any) -> Any:
        """Wrap the raw response ``data`` in ``FieldBucketData``."""
        wrapped = FieldBucketData(self, search, data)
        return wrapped


class ReverseNested(Bucket):
    """The ``reverse_nested`` bucket aggregation."""

    name = "reverse_nested"


class SignificantTerms(Bucket):
    """The ``significant_terms`` bucket aggregation."""

    name = "significant_terms"


class SignificantText(Bucket):
    """The ``significant_text`` bucket aggregation."""

    name = "significant_text"


class Terms(Bucket):
    """The ``terms`` bucket aggregation."""

    name = "terms"

    def result(self, search: Any, data: Any) -> Any:
        """Wrap the raw response ``data`` in ``FieldBucketData``."""
        wrapped = FieldBucketData(self, search, data)
        return wrapped


class Sampler(Bucket):
    """The ``sampler`` bucket aggregation."""

    name = "sampler"


class DiversifiedSampler(Bucket):
    """The ``diversified_sampler`` bucket aggregation."""

    name = "diversified_sampler"


class Composite(Bucket):
    """The ``composite`` bucket aggregation; ``sources`` is a list of named aggregations."""

    name = "composite"
    _param_defs = {
        "sources": {"type": "agg", "hash": True, "multi": True},
        "aggs": {"type": "agg", "hash": True},
    }


class VariableWidthHistogram(Bucket):
    """The ``variable_width_histogram`` bucket aggregation."""

    name = "variable_width_histogram"

    def result(self, search: Any, data: Any) -> Any:
        """Wrap the raw response ``data`` in ``FieldBucketData``."""
        wrapped = FieldBucketData(self, search, data)
        return wrapped


# metric aggregations
class TopHits(Agg):
    """The ``top_hits`` metric aggregation."""

    name = "top_hits"

    def result(self, search: Any, data: Any) -> Any:
        """Wrap the raw response ``data`` in ``TopHitsData``."""
        wrapped = TopHitsData(self, search, data)
        return wrapped


class Avg(Agg):
    """The ``avg`` metric aggregation."""

    name = "avg"


class WeightedAvg(Agg):
    """The ``weighted_avg`` metric aggregation."""

    name = "weighted_avg"


class Cardinality(Agg):
    """The ``cardinality`` metric aggregation."""

    name = "cardinality"


class ExtendedStats(Agg):
    """The ``extended_stats`` metric aggregation."""

    name = "extended_stats"


class Boxplot(Agg):
    """The ``boxplot`` metric aggregation."""

    name = "boxplot"


class GeoBounds(Agg):
    """The ``geo_bounds`` metric aggregation."""

    name = "geo_bounds"


class Max(Agg):
    """The ``max`` metric aggregation."""

    name = "max"


class MedianAbsoluteDeviation(Agg):
    """The ``median_absolute_deviation`` metric aggregation."""

    name = "median_absolute_deviation"


class Min(Agg):
    """The ``min`` metric aggregation."""

    name = "min"


class Percentiles(Agg):
    """The ``percentiles`` metric aggregation."""

    name = "percentiles"


class PercentileRanks(Agg):
    """The ``percentile_ranks`` metric aggregation."""

    name = "percentile_ranks"


class ScriptedMetric(Agg):
    """The ``scripted_metric`` metric aggregation."""

    name = "scripted_metric"


class Stats(Agg):
    """The ``stats`` metric aggregation."""

    name = "stats"


class Sum(Agg):
    """The ``sum`` metric aggregation."""

    name = "sum"


class TTest(Agg):
    """The ``t_test`` metric aggregation."""

    name = "t_test"


class ValueCount(Agg):
    """The ``value_count`` metric aggregation."""

    name = "value_count"


# pipeline aggregations
class AvgBucket(Pipeline):
    """The ``avg_bucket`` pipeline aggregation."""

    name = "avg_bucket"


class BucketScript(Pipeline):
    """The ``bucket_script`` pipeline aggregation."""

    name = "bucket_script"


class BucketSelector(Pipeline):
    """The ``bucket_selector`` pipeline aggregation."""

    name = "bucket_selector"


class CumulativeSum(Pipeline):
    """The ``cumulative_sum`` pipeline aggregation."""

    name = "cumulative_sum"


class CumulativeCardinality(Pipeline):
    """The ``cumulative_cardinality`` pipeline aggregation."""

    name = "cumulative_cardinality"


class Derivative(Pipeline):
    """The ``derivative`` pipeline aggregation."""

    name = "derivative"


class ExtendedStatsBucket(Pipeline):
    """The ``extended_stats_bucket`` pipeline aggregation."""

    name = "extended_stats_bucket"


class Inference(Pipeline):
    """The ``inference`` pipeline aggregation."""

    name = "inference"


class MaxBucket(Pipeline):
    """The ``max_bucket`` pipeline aggregation."""

    name = "max_bucket"


class MinBucket(Pipeline):
    """The ``min_bucket`` pipeline aggregation."""

    name = "min_bucket"


class MovingFn(Pipeline):
    """The ``moving_fn`` pipeline aggregation."""

    name = "moving_fn"


class MovingAvg(Pipeline):
    """The ``moving_avg`` pipeline aggregation."""

    name = "moving_avg"


class MovingPercentiles(Pipeline):
    """The ``moving_percentiles`` pipeline aggregation."""

    name = "moving_percentiles"


class Normalize(Pipeline):
    """The ``normalize`` pipeline aggregation."""

    name = "normalize"


class PercentilesBucket(Pipeline):
    """The ``percentiles_bucket`` pipeline aggregation."""

    name = "percentiles_bucket"


class SerialDiff(Pipeline):
    """The ``serial_diff`` pipeline aggregation."""

    name = "serial_diff"


class StatsBucket(Pipeline):
    """The ``stats_bucket`` pipeline aggregation."""

    name = "stats_bucket"


class SumBucket(Pipeline):
    """The ``sum_bucket`` pipeline aggregation."""

    name = "sum_bucket"


class BucketSort(Pipeline):
    """The ``bucket_sort`` pipeline aggregation."""

    name = "bucket_sort"