"""Caching of formatted files with feature-based invalidation."""

import hashlib
import os
import pickle
import sys
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Iterable, NamedTuple

from platformdirs import user_cache_dir

from _black_version import version as __version__
from black.mode import Mode
from black.output import err

if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self


class FileData(NamedTuple):
    """Stat snapshot and content hash of a source file, as stored in the cache."""

    st_mtime: float  # modification time reported by os.stat at cache time
    st_size: int  # file size in bytes at cache time
    hash: str  # sha256 hex digest of the file contents


def get_cache_dir() -> Path:
    """Get the cache directory used by black.

    The `BLACK_CACHE_DIR` environment variable overrides the default, which is
    the user cache directory for the "black" application. A per-version
    subdirectory keeps caches from different black releases separate.

    This result is immediately set to a constant `black.cache.CACHE_DIR` as to
    avoid repeated calls.
    """
    # NOTE: Function mostly exists as a clean way to test getting the cache directory.
    default = user_cache_dir("black")
    base = os.environ.get("BLACK_CACHE_DIR", default)
    return Path(base) / __version__


# Computed once at import time so every formatting run shares one lookup.
CACHE_DIR = get_cache_dir()


def get_cache_file(mode: Mode) -> Path:
    """Return the path of the pickle cache file for the given mode."""
    filename = f"cache.{mode.get_cache_key()}.pickle"
    return CACHE_DIR / filename


@dataclass
class Cache:
    """On-disk cache of formatted files, keyed by resolved source path.

    One pickle file exists per ``Mode`` (see `get_cache_file`). For each file
    the cache records mtime, size, and a content hash so unchanged files can
    be skipped on subsequent runs. All disk operations are best-effort: a
    missing, unreadable, or corrupt cache simply behaves like an empty one.
    """

    mode: Mode
    cache_file: Path
    # Maps str(resolved_path) -> FileData for every file known to the cache.
    file_data: dict[str, FileData] = field(default_factory=dict)

    @classmethod
    def read(cls, mode: Mode) -> Self:
        """Read the cache if it exists and is well-formed.

        If it is not well-formed, the call to write later should
        resolve the issue.
        """
        cache_file = get_cache_file(mode)
        try:
            exists = cache_file.exists()
        except OSError as e:
            # Likely file too long; see #4172 and #4174
            err(f"Unable to read cache file {cache_file} due to {e}")
            return cls(mode, cache_file)
        if not exists:
            return cls(mode, cache_file)

        try:
            with cache_file.open("rb") as fobj:
                # write() stores raw tuples; rebuild FileData entries here.
                data: dict[str, tuple[float, int, str]] = pickle.load(fobj)
                file_data = {k: FileData(*v) for k, v in data.items()}
        except OSError:
            # The file vanished or became unreadable between exists() and
            # open(); treat the cache as empty, like the guard above.
            return cls(mode, cache_file)
        except (pickle.UnpicklingError, ValueError, IndexError):
            # Corrupt or incompatible cache contents; a later write()
            # replaces the file wholesale.
            return cls(mode, cache_file)

        return cls(mode, cache_file, file_data)

    @staticmethod
    def hash_digest(path: Path) -> str:
        """Return the sha256 hex digest of the contents of path."""

        data = path.read_bytes()
        return hashlib.sha256(data).hexdigest()

    @staticmethod
    def get_file_data(path: Path) -> FileData:
        """Return file data (stat info plus content hash) for path."""

        stat = path.stat()
        # Named `digest` to avoid shadowing the builtin `hash`.
        digest = Cache.hash_digest(path)
        return FileData(stat.st_mtime, stat.st_size, digest)

    def is_changed(self, source: Path) -> bool:
        """Check if source has changed compared to cached version."""
        res_src = source.resolve()
        old = self.file_data.get(str(res_src))
        if old is None:
            # Never seen before.
            return True

        st = res_src.stat()
        if st.st_size != old.st_size:
            return True
        if st.st_mtime != old.st_mtime:
            # mtime moved but size didn't; hash only in this case to avoid
            # reading file contents when the stat info already matches.
            new_hash = Cache.hash_digest(res_src)
            if new_hash != old.hash:
                return True
        return False

    def filtered_cached(self, sources: Iterable[Path]) -> tuple[set[Path], set[Path]]:
        """Split an iterable of paths in `sources` into two sets.

        The first contains paths of files that were modified on disk or are
        not in the cache. The other contains paths to non-modified files.
        """
        changed: set[Path] = set()
        done: set[Path] = set()
        for src in sources:
            if self.is_changed(src):
                changed.add(src)
            else:
                done.add(src)
        return changed, done

    def write(self, sources: Iterable[Path]) -> None:
        """Update the cache file data and write a new cache file."""
        self.file_data.update(
            {str(src.resolve()): Cache.get_file_data(src) for src in sources}
        )
        tmp_name = None
        try:
            CACHE_DIR.mkdir(parents=True, exist_ok=True)
            # Write to a temporary file in the target directory, then
            # atomically replace, so concurrent readers never observe a
            # partially written cache.
            with tempfile.NamedTemporaryFile(
                dir=str(self.cache_file.parent), delete=False
            ) as f:
                tmp_name = f.name
                # We store raw tuples in the cache because it's faster.
                data: dict[str, tuple[float, int, str]] = {
                    k: (*v,) for k, v in self.file_data.items()
                }
                pickle.dump(data, f, protocol=4)
            os.replace(tmp_name, self.cache_file)
        except OSError:
            # Caching is best-effort; never fail the formatting run. Do try
            # to remove the stray temp file the failure may have left behind
            # (delete=False means nobody else will).
            if tmp_name is not None:
                try:
                    os.unlink(tmp_name)
                except OSError:
                    pass