Initial import: Music_Server, MusicFree, catalog-sync
This commit is contained in:
@@ -0,0 +1,383 @@
|
||||
'''
|
||||
Function:
|
||||
Implementation of HLSDownloader
|
||||
Author:
|
||||
Zhenchao Jin
|
||||
WeChat Official Account (微信公众号):
|
||||
Charles的皮卡丘
|
||||
'''
|
||||
import os
|
||||
import re
|
||||
import copy
|
||||
import time
|
||||
import math
|
||||
import m3u8
|
||||
import base64
|
||||
import shutil
|
||||
import hashlib
|
||||
import requests
|
||||
import threading
|
||||
import concurrent.futures as cf
|
||||
from pathlib import Path
|
||||
from .misc import touchdir
|
||||
from .logger import LoggerHandle
|
||||
from urllib.parse import urljoin
|
||||
from dataclasses import dataclass
|
||||
from rich.progress import Progress
|
||||
from typing import Optional, Dict, Any, Tuple, List, Union, Callable
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
|
||||
'''SegmentJob'''
|
||||
@dataclass(frozen=True)
|
||||
class SegmentJob:
|
||||
index: int
|
||||
uri: str
|
||||
byterange: Optional[str]
|
||||
key_method: Optional[str]
|
||||
key_uri: Optional[str]
|
||||
key_iv: Optional[str]
|
||||
keyformat: Optional[str]
|
||||
media_sequence: int
|
||||
map_uri: Optional[str]
|
||||
map_byterange: Optional[str]
|
||||
|
||||
|
||||
'''HLSDownloader'''
|
||||
class HLSDownloader:
|
||||
def __init__(self, output_dir: str = "downloads", proxies: Optional[Dict[str, str]] = None, headers: Optional[Dict[str, str]] = None, cookies: Optional[Dict[str, str]] = None, timeout: Tuple[float, float] = (10.0, 30.0), logger_handle: LoggerHandle = None,
|
||||
verify_tls: bool = True, concurrency: int = 16, max_retries: int = 8, backoff_base: float = 0.6, backoff_cap: float = 10.0, chunk_size: int = 1024 * 256, strict_key_length: bool = False, disable_print: bool = False, request_overrides: dict = None):
|
||||
# work dir
|
||||
self.output_dir = output_dir
|
||||
touchdir(self.output_dir)
|
||||
# logger
|
||||
self.logger_handle = logger_handle
|
||||
self.disable_print = disable_print
|
||||
# http requests
|
||||
self.proxies = proxies or {}
|
||||
self.headers = headers or {}
|
||||
self.cookies = cookies or {}
|
||||
self.timeout = timeout
|
||||
self.verify_tls = verify_tls
|
||||
self.chunk_size = int(chunk_size)
|
||||
self.backoff_cap = float(backoff_cap)
|
||||
self.backoff_base = float(backoff_base)
|
||||
self.concurrency = max(1, int(concurrency))
|
||||
self.max_retries = max(1, int(max_retries))
|
||||
self.strict_key_length = bool(strict_key_length)
|
||||
self.request_overrides = request_overrides or {}
|
||||
# threading
|
||||
self._tls = threading.local()
|
||||
self._key_cache: Dict[str, bytes] = {}
|
||||
self._key_cache_lock = threading.Lock()
|
||||
'''download'''
|
||||
def download(self, m3u8_url: str, output_path: str, quality: Union[str, int, Callable[[List[Dict[str, Any]]], int]] = "best", keep_segments: bool = False, temp_subdir: Optional[str] = None, progress: Progress = None, progress_id: int = 0) -> str:
|
||||
master_or_media = self._loadm3u8(m3u8_url)
|
||||
if master_or_media.is_variant:
|
||||
variant_url = self._selectvariant(master_or_media, quality)
|
||||
self.logger_handle.info(f"Selected variant: {variant_url}", disable_print=self.disable_print)
|
||||
playlist = self._loadm3u8(variant_url)
|
||||
else:
|
||||
playlist = master_or_media
|
||||
jobs, global_init_map = self._buildjobs(playlist)
|
||||
temp_folder, global_init_path = os.path.join(self.output_dir, temp_subdir or f".hls_tmp_{self._safenamefromurl(m3u8_url)}"), None
|
||||
touchdir(temp_folder)
|
||||
if global_init_map:
|
||||
global_init_path = os.path.join(temp_folder, "_global_init.bin")
|
||||
if not self._fileok(global_init_path): self._atomicwrite(global_init_path, self._fetchbytes(global_init_map["uri"], global_init_map.get("byterange")))
|
||||
seg_paths = self._downloadallsegments(jobs, temp_folder, progress=progress, progress_id=progress_id)
|
||||
touchdir(os.path.dirname(os.path.abspath(output_path)) or ".")
|
||||
self._mergefiles(global_init_path, seg_paths, output_path)
|
||||
if not keep_segments: shutil.rmtree(temp_folder, ignore_errors=True)
|
||||
return output_path
|
||||
'''_getsession'''
|
||||
def _getsession(self) -> requests.Session:
|
||||
sess = getattr(self._tls, "session", None)
|
||||
if sess is None:
|
||||
sess = requests.Session()
|
||||
sess.headers.update(self.headers)
|
||||
if self.cookies: sess.cookies.update(self.cookies)
|
||||
self._tls.session = sess
|
||||
return sess
|
||||
'''_request'''
|
||||
def _request(self, url: str, method: str = "GET", headers: Optional[Dict[str, str]] = None, stream: bool = False, **kwargs) -> requests.Response:
|
||||
kwargs.update(copy.deepcopy(self.request_overrides))
|
||||
sess, last_exc = self._getsession(), None
|
||||
hdrs = dict(self.headers)
|
||||
if headers: hdrs.update(headers)
|
||||
for attempt in range(1, self.max_retries + 1):
|
||||
try:
|
||||
resp = sess.request(method=method, url=url, headers=hdrs, proxies=self.proxies, timeout=self.timeout, verify=self.verify_tls, stream=stream, **kwargs)
|
||||
if resp.status_code in (429, 500, 502, 503, 504): resp.close(); raise requests.HTTPError(f"HTTP {resp.status_code} for {url}")
|
||||
resp.raise_for_status()
|
||||
return resp
|
||||
except Exception as e:
|
||||
last_exc = e
|
||||
t = min(self.backoff_cap, self.backoff_base * (2 ** (attempt - 1)))
|
||||
t = t + (0.1 * t * (0.5 - (time.time() % 1)))
|
||||
time.sleep(max(0.0, t))
|
||||
raise RuntimeError(f"Request failed after retries: {url}\nLast error: {last_exc}")
|
||||
'''_gettext'''
|
||||
def _gettext(self, url: str) -> str:
|
||||
resp = self._request(url, stream=False)
|
||||
return resp.text
|
||||
'''_getbytes'''
|
||||
def _getbytes(self, url: str, headers: Optional[Dict[str, str]] = None) -> bytes:
|
||||
resp = self._request(url, headers=headers, stream=True)
|
||||
chunks = []
|
||||
for c in resp.iter_content(chunk_size=self.chunk_size):
|
||||
if c: chunks.append(c)
|
||||
resp.close()
|
||||
return b"".join(chunks)
|
||||
'''_fetchbytes'''
|
||||
def _fetchbytes(self, url: str, byterange: Optional[str]) -> bytes:
|
||||
headers = {}
|
||||
if byterange:
|
||||
length, offset = self._parsebyterange(byterange)
|
||||
headers["Range"] = f"bytes={offset}-{offset + length - 1}"
|
||||
return self._getbytes(url, headers=headers)
|
||||
'''_loadm3u8'''
|
||||
def _loadm3u8(self, url: str) -> m3u8.M3U8:
|
||||
text = self._gettext(url)
|
||||
return m3u8.loads(text, uri=url)
|
||||
'''_selectvariant'''
|
||||
def _selectvariant(self, master: m3u8.M3U8, quality: Union[str, int, Callable[[List[Dict[str, Any]]], int]]) -> str:
|
||||
variants, bw_func = [], lambda v: int(v.get("average_bandwidth") or v.get("bandwidth") or 0)
|
||||
for i, p in enumerate(master.playlists or []):
|
||||
si = getattr(p, "stream_info", None)
|
||||
variants.append({
|
||||
"index": i, "absolute_uri": getattr(p, "absolute_uri", None) or urljoin(master.base_uri or master.uri, p.uri), "uri": p.uri, "bandwidth": getattr(si, "bandwidth", None) if si else None,
|
||||
"average_bandwidth": getattr(si, "average_bandwidth", None) if si else None, "resolution": getattr(si, "resolution", None) if si else None, "codecs": getattr(si, "codecs", None) if si else None,
|
||||
"frame_rate": getattr(si, "frame_rate", None) if si else None,
|
||||
})
|
||||
if not variants: raise ValueError("Master playlist has no variants.")
|
||||
if callable(quality): idx = int(quality(variants)); idx = max(0, min(idx, len(variants) - 1)); return variants[idx]["absolute_uri"]
|
||||
if isinstance(quality, str):
|
||||
q = quality.lower().strip()
|
||||
if q == "best":
|
||||
chosen = max(variants, key=bw_func)
|
||||
elif q == "lowest":
|
||||
chosen = min(variants, key=bw_func)
|
||||
else:
|
||||
m = re.search(r"(\d+)", q)
|
||||
if m: target = int(m.group(1)); chosen = min(variants, key=lambda v: abs(bw_func(v) - target))
|
||||
else: chosen = max(variants, key=bw_func)
|
||||
else:
|
||||
target = int(quality)
|
||||
chosen = min(variants, key=lambda v: abs(bw_func(v) - target))
|
||||
return chosen["absolute_uri"]
|
||||
'''_buildjobs'''
|
||||
def _buildjobs(self, playlist: m3u8.M3U8) -> Tuple[List[SegmentJob], Optional[Dict[str, Any]]]:
|
||||
media_seq = int(getattr(playlist, "media_sequence", 0) or 0)
|
||||
global_init, seg_map = None, getattr(playlist, "segment_map", None)
|
||||
if seg_map:
|
||||
try: sm0 = seg_map[0]; global_init = {"uri": getattr(sm0, "absolute_uri", None) or urljoin(playlist.base_uri, sm0.uri), "byterange": getattr(sm0, "byterange", None)}
|
||||
except Exception: global_init = None
|
||||
jobs: List[SegmentJob] = []
|
||||
session_keys = getattr(playlist, "session_keys", None) or []
|
||||
fallback_session_key, last_key_obj = session_keys[-1] if session_keys else None, None
|
||||
for i, seg in enumerate(playlist.segments or []):
|
||||
seg_uri, key_obj = getattr(seg, "absolute_uri", None) or urljoin(playlist.base_uri, seg.uri), getattr(seg, "key", None) or last_key_obj or fallback_session_key
|
||||
if getattr(seg, "key", None) is not None: last_key_obj = getattr(seg, "key", None)
|
||||
key_method, key_uri, key_iv, keyformat = (getattr(key_obj, k, None) for k in ("method", "uri", "iv", "keyformat")) if key_obj else (None, None, None, None)
|
||||
key_uri_abs = (key_uri if key_uri and (key_uri.startswith("data:") or key_uri.startswith("skd://")) else (urljoin(playlist.base_uri, key_uri) if key_uri else None))
|
||||
init_section = getattr(seg, "init_section", None)
|
||||
map_uri, map_byterange = ((getattr(init_section, "absolute_uri", None) or (urljoin(playlist.base_uri, getattr(init_section, "uri", "")) if getattr(init_section, "uri", None) else None)), getattr(init_section, "byterange", None)) if init_section is not None else (None, None)
|
||||
jobs.append(SegmentJob(index=i, uri=seg_uri, byterange=getattr(seg, "byterange", None), key_method=key_method, key_uri=key_uri_abs, key_iv=key_iv, keyformat=keyformat, media_sequence=media_seq, map_uri=map_uri, map_byterange=map_byterange))
|
||||
return jobs, global_init
|
||||
'''_downloadallsegments'''
|
||||
def _downloadallsegments(self, jobs: List[SegmentJob], temp_folder: str, progress: Progress, progress_id: int) -> List[str]:
|
||||
progress.update(progress_id, description=f"HLSDownloader._downloadallsegments >>> completed (0/{len(jobs)})", total=len(jobs), kind='hls')
|
||||
byterange_cursor: Dict[str, int] = {}; seg_paths: List[Optional[str]] = [None] * len(jobs)
|
||||
init_cache: Dict[str, str] = {}; init_inflight: Dict[str, threading.Event] = {}; init_cache_lock = threading.Lock()
|
||||
def ensureinitsection_func(map_uri: str, map_byterange: Optional[str]) -> bytes:
|
||||
key = f"{map_uri}|{map_byterange or ''}"
|
||||
with init_cache_lock:
|
||||
cached = init_cache.get(key)
|
||||
if cached and self._fileok(cached): return Path(cached).read_bytes()
|
||||
leader = (evt := init_inflight.get(key)) is None; evt = init_inflight[key] = threading.Event() if leader else evt
|
||||
if not leader:
|
||||
evt.wait()
|
||||
with init_cache_lock: cached = init_cache.get(key)
|
||||
return Path(cached).read_bytes() if cached and self._fileok(cached) else (_ for _ in ()).throw(RuntimeError(f"init_section download failed: {key}"))
|
||||
try:
|
||||
data = self._fetchbytes(map_uri, map_byterange)
|
||||
path = os.path.join(temp_folder, f"_initsec_{abs(hash(key)) & 0xffffffff:08x}.bin")
|
||||
self._atomicwrite(path, data)
|
||||
with init_cache_lock: init_cache[key] = path
|
||||
return data
|
||||
finally:
|
||||
with init_cache_lock: (evt := init_inflight.pop(key, None)) and evt.set()
|
||||
def worker_func(job: SegmentJob) -> Tuple[int, str]:
|
||||
seg_path = os.path.join(temp_folder, f"seg_{job.index:06d}.bin")
|
||||
if self._fileok(seg_path): return job.index, seg_path
|
||||
prepend = ensureinitsection_func(job.map_uri, job.map_byterange) if job.map_uri else b""
|
||||
eff_byterange = self._normalizebyterange(job.uri, job.byterange, byterange_cursor) if job.byterange else job.byterange
|
||||
data = self._fetchandmaybedecrypt(job, eff_byterange)
|
||||
self._atomicwrite(seg_path, prepend + data)
|
||||
return job.index, seg_path
|
||||
exceptions: List[Exception] = []
|
||||
with cf.ThreadPoolExecutor(max_workers=self.concurrency) as ex:
|
||||
futures = [ex.submit(worker_func, j) for j in jobs]
|
||||
for fut in cf.as_completed(futures):
|
||||
try:
|
||||
idx, path = fut.result()
|
||||
seg_paths[idx] = path
|
||||
except Exception as e:
|
||||
exceptions.append(e)
|
||||
finally:
|
||||
progress.advance(progress_id, 1)
|
||||
num_downloaded_segs = int(progress.tasks[progress_id].completed)
|
||||
progress.update(progress_id, description=f"HLSDownloader._downloadallsegments >>> completed ({num_downloaded_segs}/{len(jobs)})")
|
||||
if exceptions: raise exceptions[0]
|
||||
return [p for p in seg_paths if p is not None]
|
||||
'''_fetchandmaybedecrypt'''
|
||||
def _fetchandmaybedecrypt(self, job: SegmentJob, eff_byterange: Optional[str]) -> bytes:
|
||||
method_raw, keyformat = (job.key_method or "").strip(), (job.keyformat or "").strip().lower()
|
||||
if not method_raw or method_raw.upper() == "NONE": return self._fetchbytes(job.uri, eff_byterange)
|
||||
if keyformat and keyformat not in ("identity",): raise NotImplementedError(f"Unsupported KEYFORMAT={job.keyformat} (likely DRM).")
|
||||
method = method_raw.upper().replace("_", "-")
|
||||
dec_mode = self._classifyencryptionmethod(method)
|
||||
if dec_mode in ("DRM", "UNSUPPORTED"): raise NotImplementedError(f"Unsupported encryption method: {method_raw}")
|
||||
if not job.key_uri: raise RuntimeError(f"Encrypted segment missing key URI at seg {job.index}")
|
||||
key, base_iv = self._prepareaeskey(method, self._getkeybytes(job.key_uri)), self._deriveiv(job.key_iv, job.media_sequence + job.index)
|
||||
if not eff_byterange: ciphertext = self._fetchbytes(job.uri, None); return self._decryptwhole(ciphertext, dec_mode, key, base_iv)
|
||||
length, offset = self._parsebyterange(eff_byterange)
|
||||
block, end = 16, offset + length
|
||||
aligned_start, aligned_end = (offset // block) * block, int(math.ceil(end / block) * block)
|
||||
if dec_mode == "CBC":
|
||||
fetch_start, drop = ((aligned_start - block, offset - aligned_start + block) if aligned_start > 0 else (aligned_start, offset - aligned_start)); fetch_len = aligned_end - fetch_start; fetch_range = f"{fetch_len}@{fetch_start}"
|
||||
ciphertext = self._fetchbytes(job.uri, fetch_range)
|
||||
iv = (b"\x00" * 16) if fetch_start > 0 else base_iv
|
||||
plaintext = self._aescbcdecrypt(ciphertext, key, iv)
|
||||
return plaintext[drop: drop+length]
|
||||
else:
|
||||
fetch_start, drop, fetch_len, fetch_range = aligned_start, offset - aligned_start, aligned_end - aligned_start, f"{aligned_end - aligned_start}@{aligned_start}"
|
||||
ciphertext = self._fetchbytes(job.uri, fetch_range)
|
||||
block_index = fetch_start // block
|
||||
iv_int = int.from_bytes(base_iv, "big")
|
||||
adj_iv = ((iv_int + block_index) % (1 << 128)).to_bytes(16, "big")
|
||||
plaintext = self._aesctrcrypt(ciphertext, key, adj_iv)
|
||||
return plaintext[drop: drop+length]
|
||||
'''_decryptwhole'''
|
||||
def _decryptwhole(self, ciphertext: bytes, dec_mode: str, key: bytes, iv: bytes) -> bytes:
|
||||
if dec_mode == "CBC": return self._aescbcdecrypt(ciphertext, key, iv)
|
||||
if dec_mode == "CTR": return self._aesctrcrypt(ciphertext, key, iv)
|
||||
raise NotImplementedError(f"decrypt mode {dec_mode} not supported")
|
||||
'''_classifyencryptionmethod'''
|
||||
def _classifyencryptionmethod(self, method: str) -> str:
|
||||
m = method.strip().upper()
|
||||
if m in ("AES-128", "AES-128-CBC", "AES-CBC", "CBC"): return "CBC"
|
||||
if m in ("AES-CTR", "AES-128-CTR", "AES-192-CTR", "AES-256-CTR"): return "CTR"
|
||||
if m.startswith("SAMPLE-AES") or "SKD" in m: return "DRM"
|
||||
return "UNSUPPORTED"
|
||||
'''_getkeybytes'''
|
||||
def _getkeybytes(self, key_uri: str) -> bytes:
|
||||
if key_uri.startswith("data:"):
|
||||
if "base64," in key_uri: b64 = key_uri.split("base64,", 1)[1]; return base64.b64decode(b64)
|
||||
if "," in key_uri: raw = key_uri.split(",", 1)[1]; return raw.encode("utf-8", errors="ignore")
|
||||
raise ValueError("Unsupported data: key URI")
|
||||
if key_uri.startswith("skd://"): raise NotImplementedError("skd:// indicates DRM (FairPlay). Not supported.")
|
||||
with self._key_cache_lock:
|
||||
if key_uri in self._key_cache: return self._key_cache[key_uri]
|
||||
b = self._getbytes(key_uri)
|
||||
with self._key_cache_lock: self._key_cache[key_uri] = b
|
||||
return b
|
||||
'''_decodekeyguess'''
|
||||
def _decodekeyguess(self, key_bytes: bytes) -> bytes:
|
||||
b = key_bytes.strip()
|
||||
if b"\x00" in b: return b
|
||||
b2 = b
|
||||
if b2.lower().startswith(b"0x"): b2 = b2[2:]
|
||||
if re.fullmatch(rb"[0-9a-fA-F]+", b2) and len(b2) in (32, 48, 64):
|
||||
try: return bytes.fromhex(b2.decode("ascii"))
|
||||
except Exception: pass
|
||||
if re.fullmatch(rb"[A-Za-z0-9+/=\r\n]+", b) and (len(b) % 4 == 0):
|
||||
try:
|
||||
dec = base64.b64decode(b, validate=False)
|
||||
if len(dec) in (16, 24, 32): return dec
|
||||
except Exception:
|
||||
pass
|
||||
return b
|
||||
'''_expectedkeylen'''
|
||||
def _expectedkeylen(self, method: str) -> int:
|
||||
m = method.upper()
|
||||
if "256" in m: return 32
|
||||
if "192" in m: return 24
|
||||
return 16
|
||||
'''_prepareaeskey'''
|
||||
def _prepareaeskey(self, method: str, key_bytes: bytes) -> bytes:
|
||||
k = self._decodekeyguess(key_bytes)
|
||||
want = self._expectedkeylen(method)
|
||||
if len(k) == want: return k
|
||||
if self.strict_key_length: raise ValueError(f"Bad key length for {method}: got {len(k)} bytes, expected {want}")
|
||||
self.logger_handle.warning(f"Key length mismatch for {method}: got {len(k)}, expected {want}. Best-effort fix.", disable_print=self.disable_print)
|
||||
if len(k) > want: return k[:want]
|
||||
return (k + b"\x00" * want)[:want]
|
||||
'''_deriveiv'''
|
||||
def _deriveiv(self, iv_str: Optional[str], seq_num: int) -> bytes:
|
||||
if not iv_str: return seq_num.to_bytes(16, byteorder="big", signed=False)
|
||||
s = str(iv_str).strip().lower()
|
||||
if s.startswith("0x"): s = s[2:]
|
||||
try: iv = bytes.fromhex(s)
|
||||
except Exception: iv = s.encode("utf-8", errors="ignore")
|
||||
if len(iv) < 16: iv = (b"\x00" * (16 - len(iv))) + iv
|
||||
if len(iv) > 16: iv = iv[-16:]
|
||||
return iv
|
||||
'''_aescbcdecrypt'''
|
||||
def _aescbcdecrypt(self, ciphertext: bytes, key: bytes, iv: bytes) -> bytes:
|
||||
if len(ciphertext) % 16 != 0: raise ValueError(f"CBC ciphertext length not multiple of 16: {len(ciphertext)} bytes")
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
|
||||
dec = cipher.decryptor()
|
||||
return dec.update(ciphertext) + dec.finalize()
|
||||
'''_aesctrcrypt'''
|
||||
def _aesctrcrypt(self, data: bytes, key: bytes, iv: bytes) -> bytes:
|
||||
cipher = Cipher(algorithms.AES(key), modes.CTR(iv))
|
||||
dec = cipher.decryptor()
|
||||
return dec.update(data) + dec.finalize()
|
||||
'''_parsebyterange'''
|
||||
def _parsebyterange(self, s: str) -> Tuple[int, int]:
|
||||
s = s.strip()
|
||||
if "@" in s: a, b = s.split("@", 1); return int(a), int(b)
|
||||
raise ValueError(f"BYTERANGE missing offset: {s}")
|
||||
'''_normalizebyterange'''
|
||||
def _normalizebyterange(self, uri: str, byterange: str, cursor: Dict[str, int]) -> str:
|
||||
s = byterange.strip()
|
||||
if "@" in s: length, offset = s.split("@", 1); length_i, offset_i = int(length), int(offset); cursor[uri] = offset_i + length_i; return f"{length_i}@{offset_i}"
|
||||
length_i = int(s)
|
||||
prev = cursor.get(uri, 0)
|
||||
cursor[uri] = prev + length_i
|
||||
return f"{length_i}@{prev}"
|
||||
'''_mergefiles'''
|
||||
def _mergefiles(self, global_init_path: Optional[str], seg_paths: List[str], output_path: str) -> None:
|
||||
tmp_out = output_path + ".part"
|
||||
with open(tmp_out, "wb") as out:
|
||||
if global_init_path and self._fileok(global_init_path):
|
||||
with open(global_init_path, "rb") as fp: shutil.copyfileobj(fp, out, length=1024 * 1024)
|
||||
for p in seg_paths:
|
||||
with open(p, "rb") as fp: shutil.copyfileobj(fp, out, length=1024 * 1024)
|
||||
os.replace(tmp_out, output_path)
|
||||
'''_safenamefromurl'''
|
||||
def _safenamefromurl(self, url: str, max_len: int = 20) -> str:
|
||||
return hashlib.sha256(url.encode("utf-8")).hexdigest()[:max_len]
|
||||
'''_fileok'''
|
||||
def _fileok(self, path: str) -> bool:
|
||||
return os.path.exists(path) and os.path.getsize(path) > 0
|
||||
'''_atomicwrite'''
|
||||
def _atomicwrite(self, path: str, data: bytes) -> None:
|
||||
touchdir(os.path.dirname(os.path.abspath(path)) or ".")
|
||||
pid, tid = os.getpid(), threading.get_ident()
|
||||
tmp, last = f"{path}.tmp.{pid}.{tid}.{time.time_ns()}", None
|
||||
with open(tmp, "wb") as fp:
|
||||
fp.write(data)
|
||||
try: fp.flush(); os.fsync(fp.fileno())
|
||||
except Exception: pass
|
||||
for i in range(12):
|
||||
try: os.replace(tmp, path); return
|
||||
except PermissionError as e: last = e; time.sleep(min(0.5, 0.03 * (2 ** i)))
|
||||
except OSError as e: last = e; time.sleep(min(0.5, 0.03 * (2 ** i)))
|
||||
try:
|
||||
if os.path.exists(tmp): os.remove(tmp)
|
||||
except Exception:
|
||||
pass
|
||||
raise last
|
||||
Reference in New Issue
Block a user