Move download_segment() from DASH/HLS download_track() to Class

Various overall small readability improvements have also been made.
This commit is contained in:
rlaphoenix
2023-05-17 03:12:29 +01:00
parent 03c012f88e
commit dd64212ad2
2 changed files with 288 additions and 216 deletions

View File

@@ -13,7 +13,7 @@ from functools import partial
from hashlib import md5
from pathlib import Path
from threading import Event
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Union, MutableMapping
from urllib.parse import urljoin, urlparse
from uuid import UUID
@@ -392,7 +392,8 @@ class DASH:
# last chance to find the KID, assumes first segment will hold the init data
track_kid = track_kid or track.get_key_id(url=segments[0][0], session=session)
# license and grab content keys
drm = track.drm[0] # just use the first supported DRM system for now
# TODO: What if we don't want to use the first DRM system?
drm = track.drm[0]
if isinstance(drm, Widevine):
if not license_widevine:
raise ValueError("license_widevine func must be supplied to use Widevine DRM")
@@ -404,74 +405,26 @@ class DASH:
progress(downloaded="[yellow]SKIPPED")
return
def download_segment(filename: str, segment: tuple[str, Optional[str]]) -> int:
    """
    Download a single DASH segment into the track's save directory.

    Parameters:
        filename: Base name of the segment file. It is saved under `save_dir`
            with its suffix set to ".mp4".
        segment: A `(uri, byte_range)` pair. When `byte_range` is set it is sent
            as a Range header and python-requests is used for the download.

    Returns the size in bytes of the downloaded segment file.

    Raises KeyboardInterrupt if `stop_event` is already set before starting.
    Re-raises the downloader's exception after 5 failed attempts, or right away
    if `stop_event` gets set while retrying.
    """
    if stop_event.is_set():
        # the track already started downloading, but another failed or was stopped
        raise KeyboardInterrupt()

    segment_save_path = (save_dir / filename).with_suffix(".mp4")
    segment_uri, segment_range = segment

    # Work on a private copy of the headers. This function is submitted to a
    # thread pool and every call reads the same `session.headers`; writing the
    # Range header into that shared mapping would let one segment's byte range
    # leak into other segments' requests (and into the session itself).
    headers_ = dict(session.headers)
    if segment_range:
        # aria2(c) doesn't support byte ranges, let's use python-requests (likely slower)
        downloader_ = requests_downloader
        headers_["Range"] = f"bytes={segment_range}"
    else:
        downloader_ = downloader

    attempts = 1
    while True:
        try:
            downloader_(
                uri=segment_uri,
                out=segment_save_path,
                headers=headers_,
                proxy=proxy,
                # only the fifth (final) attempt is not silenced
                silent=attempts != 5,
                segmented=True
            )
            break
        except Exception:
            if stop_event.is_set() or attempts == 5:
                raise
            time.sleep(2)
            attempts += 1

    data_size = segment_save_path.stat().st_size

    # fix audio decryption on ATVP by fixing the sample description index
    # TODO: Should this be done in the video data or the init data?
    if isinstance(track, Audio):
        with open(segment_save_path, "rb+") as f:
            segment_data = f.read()
            fixed_segment_data = re.sub(
                b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
                b"\\g<1>\x01",
                segment_data
            )
            if fixed_segment_data != segment_data:
                # the patch swaps a single byte, so the length is unchanged and
                # overwriting from the start needs no truncation
                f.seek(0)
                f.write(fixed_segment_data)

    return data_size
progress(total=len(segments))
finished_threads = 0
download_sizes = []
download_speed_window = 5
last_speed_refresh = time.time()
with ThreadPoolExecutor(max_workers=16) as pool:
for download in futures.as_completed((
for i, download in enumerate(futures.as_completed((
pool.submit(
download_segment,
filename=str(i).zfill(len(str(len(segments)))),
segment=segment
DASH.download_segment,
url=url,
out_path=(save_dir / str(n).zfill(len(str(len(segments))))).with_suffix(".mp4"),
track=track,
proxy=proxy,
headers=session.headers,
bytes_range=bytes_range,
stop_event=stop_event
)
for i, segment in enumerate(segments)
)):
finished_threads += 1
for n, (url, bytes_range) in enumerate(segments)
))):
try:
download_size = download.result()
except KeyboardInterrupt:
@@ -482,16 +435,15 @@ class DASH:
# tell dl that it was cancelled
# the pool is already shut down, so exiting loop is fine
raise
except Exception as e:
except Exception:
stop_event.set() # skip pending track downloads
progress(downloaded="[red]FAILING")
pool.shutdown(wait=True, cancel_futures=True)
progress(downloaded="[red]FAILED")
# tell dl that it failed
# the pool is already shut down, so exiting loop is fine
raise e
raise
else:
# it successfully downloaded, and it was not cancelled
progress(advance=1)
now = time.time()
@@ -500,7 +452,7 @@ class DASH:
if download_size: # no size == skipped dl
download_sizes.append(download_size)
if download_sizes and (time_since > 5 or finished_threads == len(segments)):
if download_sizes and (time_since > download_speed_window or i == len(segments)):
data_size = sum(download_sizes)
download_speed = data_size / (time_since or 1)
progress(downloaded=f"DASH {filesize.decimal(download_speed)}/s")
@@ -527,6 +479,76 @@ class DASH:
progress(downloaded="Downloaded")
@staticmethod
def download_segment(
    url: str,
    out_path: Path,
    track: AnyTrack,
    proxy: Optional[str] = None,
    headers: Optional[MutableMapping[str, str | bytes]] = None,
    bytes_range: Optional[str] = None,
    stop_event: Optional[Event] = None
) -> int:
    """
    Download a DASH Media Segment.

    Parameters:
        url: Full HTTP(S) URL to the Segment you want to download.
        out_path: Path to save the downloaded Segment file to.
        track: The Track object of which this Segment is for. Currently only used to
            fix an invalid value in the TFHD box of Audio Tracks.
        proxy: Proxy URI to use when downloading the Segment file.
        headers: HTTP Headers to send when requesting the Segment file. The mapping
            is copied and never modified, so a shared mapping (e.g. a session's
            headers) can be passed safely from multiple threads.
        bytes_range: Download only specific bytes of the Segment file using the Range header.
        stop_event: Prematurely stop the Download from beginning. Useful if ran from
            a Thread Pool. It will raise a KeyboardInterrupt if set.

    Returns the file size of the downloaded Segment in bytes.

    Raises KeyboardInterrupt if `stop_event` is set before the download begins.
    Re-raises the downloader's exception after 5 failed attempts, or right away
    if `stop_event` gets set while retrying.
    """
    if stop_event and stop_event.is_set():
        raise KeyboardInterrupt()

    # Always work on a private copy of the headers. Writing the Range header into
    # the caller's mapping would leak this Segment's byte range into every other
    # request that shares the same mapping.
    headers_ = dict(headers or {})
    if bytes_range:
        # aria2(c) doesn't support byte ranges, use python-requests
        downloader_ = requests_downloader
        headers_["Range"] = f"bytes={bytes_range}"
    else:
        downloader_ = downloader

    attempts = 1
    while True:
        try:
            downloader_(
                uri=url,
                out=out_path,
                headers=headers_,
                proxy=proxy,
                # only the fifth (final) attempt is not silenced
                silent=attempts != 5,
                segmented=True
            )
            break
        except Exception:
            if (stop_event and stop_event.is_set()) or attempts == 5:
                raise
            time.sleep(2)
            attempts += 1

    # fix audio decryption on ATVP by fixing the sample description index
    # TODO: Should this be done in the video data or the init data?
    if isinstance(track, Audio):
        with open(out_path, "rb+") as f:
            segment_data = f.read()
            fixed_segment_data = re.sub(
                b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
                b"\\g<1>\x01",
                segment_data
            )
            if fixed_segment_data != segment_data:
                # the patch swaps a single byte, so the length is unchanged and
                # overwriting from the start needs no truncation
                f.seek(0)
                f.write(fixed_segment_data)

    return out_path.stat().st_size
@staticmethod
def _get(
item: str,