Move download_segment() from DASH/HLS download_track() to Class

Various overall small readability improvements have also been made.
This commit is contained in:
rlaphoenix
2023-05-17 03:12:29 +01:00
parent 03c012f88e
commit dd64212ad2
2 changed files with 288 additions and 216 deletions

View File

@@ -214,137 +214,6 @@ class HLS:
log.error("Track's HLS playlist has no segments, expecting an invariant M3U8 playlist.")
sys.exit(1)
drm_lock = Lock()
def download_segment(filename: str, segment: m3u8.Segment, init_data: Queue, segment_key: Queue) -> int:
if stop_event.is_set():
# the track already started downloading, but another failed or was stopped
raise KeyboardInterrupt()
if callable(track.OnSegmentFilter) and track.OnSegmentFilter(segment):
return 0
segment_save_path = (save_dir / filename).with_suffix(".mp4")
newest_init_data = init_data.get()
try:
if segment.init_section and (not newest_init_data or segment.discontinuity):
# Only use the init data if there's no init data yet (e.g., start of file)
# or if EXT-X-DISCONTINUITY is reached at the same time as EXT-X-MAP.
# Even if a new EXT-X-MAP is supplied, it may just be duplicate and would
# be unnecessary and slow to re-download the init data each time.
if not segment.init_section.uri.startswith(segment.init_section.base_uri):
segment.init_section.uri = segment.init_section.base_uri + segment.init_section.uri
if segment.init_section.byterange:
byte_range = HLS.calculate_byte_range(segment.init_section.byterange)
_ = range_offset.get()
range_offset.put(byte_range.split("-")[0])
headers = {
"Range": f"bytes={byte_range}"
}
else:
headers = {}
log.debug("Got new init segment, %s", segment.init_section.uri)
res = session.get(segment.init_section.uri, headers=headers)
res.raise_for_status()
newest_init_data = res.content
finally:
init_data.put(newest_init_data)
with drm_lock:
newest_segment_key = segment_key.get()
try:
if segment.keys and newest_segment_key[1] != segment.keys:
try:
drm = HLS.get_drm(
keys=segment.keys,
proxy=proxy
)
except NotImplementedError as e:
log.error(str(e))
sys.exit(1)
else:
if drm:
track.drm = drm
drm = drm[0] # just use the first supported DRM system for now
log.debug("Got segment key, %s", drm)
if isinstance(drm, Widevine):
# license and grab content keys
track_kid = track.get_key_id(newest_init_data)
if not license_widevine:
raise ValueError("license_widevine func must be supplied to use Widevine DRM")
license_widevine(drm, track_kid=track_kid)
newest_segment_key = (drm, segment.keys)
finally:
segment_key.put(newest_segment_key)
if skip_event.is_set():
progress(downloaded="[yellow]SKIPPING")
return 0
if not segment.uri.startswith(segment.base_uri):
segment.uri = segment.base_uri + segment.uri
attempts = 1
while True:
try:
downloader_ = downloader
headers_ = session.headers
if segment.byterange:
# aria2(c) doesn't support byte ranges, let's use python-requests (likely slower)
previous_range_offset = range_offset.get()
byte_range = HLS.calculate_byte_range(segment.byterange, previous_range_offset)
range_offset.put(byte_range.split("-")[0])
downloader_ = requests_downloader
headers_["Range"] = f"bytes={byte_range}"
downloader_(
uri=segment.uri,
out=segment_save_path,
headers=headers_,
proxy=proxy,
silent=attempts != 5,
segmented=True
)
break
except Exception as ee:
if stop_event.is_set() or attempts == 5:
raise ee
time.sleep(2)
attempts += 1
data_size = segment_save_path.stat().st_size
if isinstance(track, Audio) or newest_init_data:
with open(segment_save_path, "rb+") as f:
segment_data = f.read()
if isinstance(track, Audio):
# fix audio decryption on ATVP by fixing the sample description index
# TODO: Is this in mpeg data, or init data?
segment_data = re.sub(
b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
b"\\g<1>\x01",
segment_data
)
# prepend the init data to be able to decrypt
if newest_init_data:
f.seek(0)
f.write(newest_init_data)
f.write(segment_data)
if newest_segment_key[0]:
newest_segment_key[0].decrypt(segment_save_path)
track.drm = None
if callable(track.OnDecrypted):
track.OnDecrypted(track)
return data_size
segment_key = Queue(maxsize=1)
init_data = Queue(maxsize=1)
range_offset = Queue(maxsize=1)
if track.drm:
session_drm = track.drm[0] # just use the first supported DRM system for now
if isinstance(session_drm, Widevine):
@@ -355,30 +224,39 @@ class HLS:
else:
session_drm = None
# have data to begin with, or it will be stuck waiting on the first pool forever
segment_key.put((session_drm, None))
init_data.put(None)
range_offset.put(0)
progress(total=len(master.segments))
finished_threads = 0
download_sizes = []
download_speed_window = 5
last_speed_refresh = time.time()
with ThreadPoolExecutor(max_workers=16) as pool:
for download in futures.as_completed((
pool.submit(
download_segment,
filename=str(i).zfill(len(str(len(master.segments)))),
segment=segment,
init_data=init_data,
segment_key=segment_key
)
for i, segment in enumerate(master.segments)
)):
finished_threads += 1
segment_key = Queue(maxsize=1)
segment_key.put((session_drm, None))
init_data = Queue(maxsize=1)
init_data.put(None)
range_offset = Queue(maxsize=1)
range_offset.put(0)
drm_lock = Lock()
with ThreadPoolExecutor(max_workers=16) as pool:
for i, download in enumerate(futures.as_completed((
pool.submit(
HLS.download_segment,
segment=segment,
out_path=(save_dir / str(n).zfill(len(str(len(master.segments))))).with_suffix(".mp4"),
track=track,
init_data=init_data,
segment_key=segment_key,
range_offset=range_offset,
drm_lock=drm_lock,
license_widevine=license_widevine,
session=session,
proxy=proxy,
stop_event=stop_event,
skip_event=skip_event
)
for n, segment in enumerate(master.segments)
))):
try:
download_size = download.result()
except KeyboardInterrupt:
@@ -401,13 +279,17 @@ class HLS:
# it successfully downloaded, and it was not cancelled
progress(advance=1)
if download_size == -1: # skipped for --skip-dl
progress(downloaded="[yellow]SKIPPING")
continue
now = time.time()
time_since = now - last_speed_refresh
if download_size: # no size == skipped dl
download_sizes.append(download_size)
if download_sizes and (time_since > 5 or finished_threads == len(master.segments)):
if download_sizes and (time_since > download_speed_window or i == len(master.segments)):
data_size = sum(download_sizes)
download_speed = data_size / (time_since or 1)
progress(downloaded=f"HLS {filesize.decimal(download_speed)}/s")
@@ -424,6 +306,174 @@ class HLS:
track.path = save_path
save_dir.rmdir()
@staticmethod
def download_segment(
segment: m3u8.Segment,
out_path: Path,
track: AnyTrack,
init_data: Queue,
segment_key: Queue,
range_offset: Queue,
drm_lock: Lock,
license_widevine: Optional[Callable] = None,
session: Optional[Session] = None,
proxy: Optional[str] = None,
stop_event: Optional[Event] = None,
skip_event: Optional[Event] = None
) -> int:
"""
Download (and Decrypt) an HLS Media Segment.
Note: Make sure all Queue objects passed are appropriately initialized with
a starting value or this function may get permanently stuck.
Parameters:
segment: The m3u8.Segment Object to Download.
out_path: Path to save the downloaded Segment file to.
track: The Track object of which this Segment is for. Currently used to fix an
invalid value in the TFHD box of Audio Tracks, for the OnSegmentFilter, and
for DRM-related operations like getting the Track ID and Decryption.
init_data: Queue for saving and loading the most recent init section data.
segment_key: Queue for saving and loading the most recent DRM object, and it's
adjacent Segment.Key object.
range_offset: Queue for saving and loading the most recent Segment Bytes Range.
drm_lock: Prevent more than one Download from doing anything DRM-related at the
same time. Make sure all calls to download_segment() use the same Lock object.
license_widevine: Function used to license Widevine DRM objects. It must be passed
if the Segment's DRM uses Widevine.
proxy: Proxy URI to use when downloading the Segment file.
session: Python-Requests Session used when requesting init data.
stop_event: Prematurely stop the Download from beginning. Useful if ran from
a Thread Pool. It will raise a KeyboardInterrupt if set.
skip_event: Prematurely stop the Download from beginning. It returns with a
file size of -1 directly after DRM licensing occurs, even if it's DRM-free.
This is mainly for `--skip-dl` to allow licensing without downloading.
Returns the file size of the downloaded Segment in bytes.
"""
if stop_event.is_set():
raise KeyboardInterrupt()
if callable(track.OnSegmentFilter) and track.OnSegmentFilter(segment):
return 0
# handle init section changes
newest_init_data = init_data.get()
try:
if segment.init_section and (not newest_init_data or segment.discontinuity):
# Only use the init data if there's no init data yet (e.g., start of file)
# or if EXT-X-DISCONTINUITY is reached at the same time as EXT-X-MAP.
# Even if a new EXT-X-MAP is supplied, it may just be duplicate and would
# be unnecessary and slow to re-download the init data each time.
if not segment.init_section.uri.startswith(segment.init_section.base_uri):
segment.init_section.uri = segment.init_section.base_uri + segment.init_section.uri
if segment.init_section.byterange:
byte_range = HLS.calculate_byte_range(segment.init_section.byterange)
_ = range_offset.get()
range_offset.put(byte_range.split("-")[0])
range_header = {
"Range": f"bytes={byte_range}"
}
else:
range_header = {}
res = session.get(segment.init_section.uri, headers=range_header)
res.raise_for_status()
newest_init_data = res.content
finally:
init_data.put(newest_init_data)
# handle segment key changes
with drm_lock:
newest_segment_key = segment_key.get()
try:
if segment.keys and newest_segment_key[1] != segment.keys:
drm = HLS.get_drm(
keys=segment.keys,
proxy=proxy
)
if drm:
track.drm = drm
# license and grab content keys
# TODO: What if we don't want to use the first DRM system?
drm = drm[0]
if isinstance(drm, Widevine):
track_kid = track.get_key_id(newest_init_data)
if not license_widevine:
raise ValueError("license_widevine func must be supplied to use Widevine DRM")
license_widevine(drm, track_kid=track_kid)
newest_segment_key = (drm, segment.keys)
finally:
segment_key.put(newest_segment_key)
if skip_event.is_set():
return -1
if not segment.uri.startswith(segment.base_uri):
segment.uri = segment.base_uri + segment.uri
attempts = 1
while True:
try:
headers_ = session.headers
if segment.byterange:
# aria2(c) doesn't support byte ranges, use python-requests
downloader_ = requests_downloader
previous_range_offset = range_offset.get()
byte_range = HLS.calculate_byte_range(segment.byterange, previous_range_offset)
range_offset.put(byte_range.split("-")[0])
headers_["Range"] = f"bytes={byte_range}"
else:
downloader_ = downloader
downloader_(
uri=segment.uri,
out=out_path,
headers=headers_,
proxy=proxy,
silent=attempts != 5,
segmented=True
)
break
except Exception as ee:
if stop_event.is_set() or attempts == 5:
raise ee
time.sleep(2)
attempts += 1
download_size = out_path.stat().st_size
# fix audio decryption on ATVP by fixing the sample description index
# TODO: Should this be done in the video data or the init data?
if isinstance(track, Audio):
with open(out_path, "rb+") as f:
segment_data = f.read()
fixed_segment_data = re.sub(
b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
b"\\g<1>\x01",
segment_data
)
if fixed_segment_data != segment_data:
f.seek(0)
f.write(fixed_segment_data)
# prepend the init data to be able to decrypt
if newest_init_data:
with open(out_path, "rb+") as f:
segment_data = f.read()
f.seek(0)
f.write(newest_init_data)
f.write(segment_data)
# decrypt segment if encrypted
if newest_segment_key[0]:
newest_segment_key[0].decrypt(out_path)
track.drm = None
if callable(track.OnDecrypted):
track.OnDecrypted(track)
return download_size
@staticmethod
def get_drm(
keys: list[Union[m3u8.model.SessionKey, m3u8.model.Key]],