"""Caching of formatted files with feature-based invalidation.""" import hashlib import os import pickle import sys import tempfile from dataclasses import dataclass, field from pathlib import Path from typing import Dict, Iterable, NamedTuple, Set, Tuple from platformdirs import user_cache_dir from _black_version import version as __version__ from black.mode import Mode from black.output import err if sys.version_info >= (3, 11): from typing import Self else: from typing_extensions import Self class FileData(NamedTuple): st_mtime: float st_size: int hash: str def get_cache_dir() -> Path: """Get the cache directory used by black. Users can customize this directory on all systems using `BLACK_CACHE_DIR` environment variable. By default, the cache directory is the user cache directory under the black application. This result is immediately set to a constant `black.cache.CACHE_DIR` as to avoid repeated calls. """ # NOTE: Function mostly exists as a clean way to test getting the cache directory. default_cache_dir = user_cache_dir("black") cache_dir = Path(os.environ.get("BLACK_CACHE_DIR", default_cache_dir)) cache_dir = cache_dir / __version__ return cache_dir CACHE_DIR = get_cache_dir() def get_cache_file(mode: Mode) -> Path: return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle" @dataclass class Cache: mode: Mode cache_file: Path file_data: Dict[str, FileData] = field(default_factory=dict) @classmethod def read(cls, mode: Mode) -> Self: """Read the cache if it exists and is well-formed. If it is not well-formed, the call to write later should resolve the issue. """ cache_file = get_cache_file(mode) try: exists = cache_file.exists() except OSError as e: # Likely file too long; see #4172 and #4174 err(f"Unable to read cache file {cache_file} due to {e}") return cls(mode, cache_file) if not exists: return cls(mode, cache_file) with cache_file.open("rb") as fobj: try: data: Dict[str, Tuple[float, int, str]] = pickle.load(fobj) file_data = {k: FileData(*v) for k, v in data.items()} except (pickle.UnpicklingError, ValueError, IndexError): return cls(mode, cache_file) return cls(mode, cache_file, file_data) @staticmethod def hash_digest(path: Path) -> str: """Return hash digest for path.""" data = path.read_bytes() return hashlib.sha256(data).hexdigest() @staticmethod def get_file_data(path: Path) -> FileData: """Return file data for path.""" stat = path.stat() hash = Cache.hash_digest(path) return FileData(stat.st_mtime, stat.st_size, hash) def is_changed(self, source: Path) -> bool: """Check if source has changed compared to cached version.""" res_src = source.resolve() old = self.file_data.get(str(res_src)) if old is None: return True st = res_src.stat() if st.st_size != old.st_size: return True if st.st_mtime != old.st_mtime: new_hash = Cache.hash_digest(res_src) if new_hash != old.hash: return True return False def filtered_cached(self, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]: """Split an iterable of paths in `sources` into two sets. The first contains paths of files that modified on disk or are not in the cache. The other contains paths to non-modified files. """ changed: Set[Path] = set() done: Set[Path] = set() for src in sources: if self.is_changed(src): changed.add(src) else: done.add(src) return changed, done def write(self, sources: Iterable[Path]) -> None: """Update the cache file data and write a new cache file.""" self.file_data.update( **{str(src.resolve()): Cache.get_file_data(src) for src in sources} ) try: CACHE_DIR.mkdir(parents=True, exist_ok=True) with tempfile.NamedTemporaryFile( dir=str(self.cache_file.parent), delete=False ) as f: # We store raw tuples in the cache because pickling NamedTuples # doesn't work with mypyc on Python 3.8, and because it's faster. data: Dict[str, Tuple[float, int, str]] = { k: (*v,) for k, v in self.file_data.items() } pickle.dump(data, f, protocol=4) os.replace(f.name, self.cache_file) except OSError: pass