Completely rewrite downloading system

The new system now downloads and decrypts segments individually instead of downloading all segments, merging them, and then decrypting. Overall the download system now acts more like a normal player.

This fixes #23 as the new HLS download system detects changes in keys and init segments as segments are downloaded. DASH still only supports one period, and one period only, but hopefully I can change that in the future.

Downloading code is now also moved from the Track classes to the manifest classes. Download progress is now also actually helpful for segmented downloads (all HLS, and most DASH streams). It uses TQDM to show a progress bar based on how many segments it needs to download, and how fast it downloads them.

There's only one down side currently. Downloading of segmented videos no longer have the benefit of aria2c's -j parameter. Where it can download n URLs concurrently. Aria2c is still used but only -x and -s is going to make a difference.

In the future I will make HLS and DASH download in a multi-threaded way, sort of a manual version of -j.
This commit is contained in:
rlaphoenix
2023-02-21 05:42:00 +00:00
parent c925cb8af9
commit 42aaa03941
6 changed files with 591 additions and 436 deletions

View File

@@ -1,10 +1,14 @@
from __future__ import annotations
import asyncio
import base64
import logging
import math
import re
import sys
from copy import copy
from hashlib import md5
from pathlib import Path
from typing import Any, Callable, Optional, Union
from urllib.parse import urljoin, urlparse
from uuid import UUID
@@ -15,6 +19,8 @@ from pywidevine.cdm import Cdm as WidevineCdm
from pywidevine.pssh import PSSH
from requests import Session
from devine.core.constants import AnyTrack
from devine.core.downloaders import aria2c
from devine.core.drm import Widevine
from devine.core.tracks import Audio, Subtitle, Tracks, Video
from devine.core.utilities import is_close_match
@@ -92,12 +98,7 @@ class DASH:
if callable(period_filter) and period_filter(period):
continue
period_base_url = period.findtext("BaseURL") or self.manifest.findtext("BaseURL")
if not period_base_url or not re.match("^https?://", period_base_url, re.IGNORECASE):
period_base_url = urljoin(self.url, period_base_url)
for adaptation_set in period.findall("AdaptationSet"):
# flags
trick_mode = any(
x.get("schemeIdUri") == "http://dashif.org/guidelines/trickmode"
for x in (
@@ -105,6 +106,10 @@ class DASH:
adaptation_set.findall("SupplementalProperty")
)
)
if trick_mode:
# we don't want trick mode streams (they are only used for fast-forward/rewind)
continue
descriptive = any(
(x.get("schemeIdUri"), x.get("value")) == ("urn:mpeg:dash:role:2011", "descriptive")
for x in adaptation_set.findall("Accessibility")
@@ -121,12 +126,8 @@ class DASH:
for x in adaptation_set.findall("Role")
)
if trick_mode:
# we don't want trick mode streams (they are only used for fast-forward/rewind)
continue
for rep in adaptation_set.findall("Representation"):
supplements = rep.findall("SupplementalProperty") + adaptation_set.findall("SupplementalProperty")
codecs = rep.get("codecs") or adaptation_set.get("codecs")
content_type = adaptation_set.get("contentType") or \
adaptation_set.get("mimeType") or \
@@ -136,8 +137,6 @@ class DASH:
raise ValueError("No content type value could be found")
content_type = content_type.split("/")[0]
codecs = rep.get("codecs") or adaptation_set.get("codecs")
if content_type.startswith("image"):
# we don't want what's likely thumbnails for the seekbar
continue
@@ -154,6 +153,8 @@ class DASH:
if mime and not mime.endswith("/mp4"):
codecs = mime.split("/")[1]
supplements = rep.findall("SupplementalProperty") + adaptation_set.findall("SupplementalProperty")
joc = next((
x.get("value")
for x in supplements
@@ -167,18 +168,6 @@ class DASH:
"The provided fallback language is not valid or is `None` or `und`."
)
drm = DASH.get_drm(rep.findall("ContentProtection") + adaptation_set.findall("ContentProtection"))
# from here we need to calculate the Segment Template and compute a final list of URLs
segment_urls = DASH.get_segment_urls(
representation=rep,
period_duration=period.get("duration") or self.manifest.get("mediaPresentationDuration"),
fallback_segment_template=adaptation_set.find("SegmentTemplate"),
fallback_base_url=period_base_url,
fallback_query=urlparse(self.url).query
)
# for some reason it's incredibly common for services to not provide
# a good and actually unique track ID, sometimes because of the lang
# dialect not being represented in the id, or the bitrate, or such.
@@ -206,7 +195,7 @@ class DASH:
tracks.add(track_type(
id_=track_id,
url=segment_urls,
url=(self.url, rep, adaptation_set, period),
codec=track_codec,
language=track_lang,
is_original_lang=not track_lang or not language or is_close_match(track_lang, [language]),
@@ -254,8 +243,7 @@ class DASH:
rep.find("SegmentBase").get("timescale") if
rep.find("SegmentBase") is not None else None
)
),
drm=drm
)
) if track_type is Video else dict(
bitrate=rep.get("bandwidth"),
channels=next(iter(
@@ -263,8 +251,7 @@ class DASH:
or adaptation_set.xpath("AudioChannelConfiguration/@value")
), None),
joc=joc,
descriptive=descriptive,
drm=drm
descriptive=descriptive
) if track_type is Audio else dict(
forced=forced,
cc=cc
@@ -276,6 +263,241 @@ class DASH:
return tracks
@staticmethod
def download_track(
track: AnyTrack,
save_dir: Path,
session: Optional[Session] = None,
proxy: Optional[str] = None,
license_widevine: Optional[Callable] = None
):
if not session:
session = Session()
elif not isinstance(session, Session):
raise TypeError(f"Expected session to be a {Session}, not {session!r}")
if not track.needs_proxy and proxy:
proxy = None
if proxy:
session.proxies.update({
"all": proxy
})
log = logging.getLogger("DASH")
manifest_url, representation, adaptation_set, period = track.url
drm = DASH.get_drm(
representation.findall("ContentProtection") +
adaptation_set.findall("ContentProtection")
)
if drm:
drm = drm[0] # just use the first supported DRM system for now
if isinstance(drm, Widevine):
# license and grab content keys
if not license_widevine:
raise ValueError("license_widevine func must be supplied to use Widevine DRM")
license_widevine(drm)
else:
drm = None
segment_urls: list[str] = []
manifest = load_xml(session.get(manifest_url).text)
manifest_url_query = urlparse(manifest_url).query
period_base_url = period.findtext("BaseURL") or manifest.findtext("BaseURL")
if not period_base_url or not re.match("^https?://", period_base_url, re.IGNORECASE):
period_base_url = urljoin(manifest_url, period_base_url)
period_duration = period.get("duration") or manifest.get("mediaPresentationDuration")
base_url = representation.findtext("BaseURL") or period_base_url
segment_template = representation.find("SegmentTemplate")
if segment_template is None:
segment_template = adaptation_set.find("SegmentTemplate")
segment_base = representation.find("SegmentBase")
if segment_base is None:
segment_base = adaptation_set.find("SegmentBase")
segment_list = representation.find("SegmentList")
if segment_list is None:
segment_list = adaptation_set.find("SegmentList")
if segment_template is not None:
segment_template = copy(segment_template)
start_number = int(segment_template.get("startNumber") or 1)
segment_timeline = segment_template.find("SegmentTimeline")
for item in ("initialization", "media"):
value = segment_template.get(item)
if not value:
continue
if not re.match("^https?://", value, re.IGNORECASE):
if not base_url:
raise ValueError("Resolved Segment URL is not absolute, and no Base URL is available.")
value = urljoin(base_url, value)
if not urlparse(value).query and manifest_url_query:
value += f"?{manifest_url_query}"
segment_template.set(item, value)
if segment_timeline is not None:
seg_time_list = []
current_time = 0
for s in segment_timeline.findall("S"):
if s.get("t"):
current_time = int(s.get("t"))
for _ in range(1 + (int(s.get("r") or 0))):
seg_time_list.append(current_time)
current_time += int(s.get("d"))
seg_num_list = list(range(start_number, len(seg_time_list) + start_number))
segment_urls += [
DASH.replace_fields(
segment_template.get("media"),
Bandwidth=representation.get("bandwidth"),
Number=n,
RepresentationID=representation.get("id"),
Time=t
)
for t, n in zip(seg_time_list, seg_num_list)
]
else:
if not period_duration:
raise ValueError("Duration of the Period was unable to be determined.")
period_duration = DASH.pt_to_sec(period_duration)
segment_duration = float(segment_template.get("duration"))
segment_timescale = float(segment_template.get("timescale") or 1)
total_segments = math.ceil(period_duration / (segment_duration / segment_timescale))
segment_urls += [
DASH.replace_fields(
segment_template.get("media"),
Bandwidth=representation.get("bandwidth"),
Number=s,
RepresentationID=representation.get("id"),
Time=s
)
for s in range(start_number, start_number + total_segments)
]
init_data = None
init_url = segment_template.get("initialization")
if init_url:
res = session.get(DASH.replace_fields(
init_url,
Bandwidth=representation.get("bandwidth"),
RepresentationID=representation.get("id")
))
res.raise_for_status()
init_data = res.content
for i, segment_url in enumerate(segment_urls):
segment_filename = str(i).zfill(len(str(len(segment_urls))))
segment_save_path = (save_dir / segment_filename).with_suffix(".mp4")
asyncio.run(aria2c(
segment_url,
segment_save_path,
session.headers,
proxy
))
# TODO: More like `segment.path`, but this will do for now
# Needed for the drm.decrypt() call couple lines down
track.path = segment_save_path
if isinstance(track, Audio) or init_data:
with open(track.path, "rb+") as f:
segment_data = f.read()
if isinstance(track, Audio):
# fix audio decryption on ATVP by fixing the sample description index
# TODO: Is this in mpeg data, or init data?
segment_data = re.sub(
b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
b"\\g<1>\x01",
segment_data
)
# prepend the init data to be able to decrypt
if init_data:
f.seek(0)
f.write(init_data)
f.write(segment_data)
if drm:
# TODO: What if the manifest does not mention DRM, but has DRM
drm.decrypt(track)
if callable(track.OnDecrypted):
track.OnDecrypted(track)
elif segment_list is not None:
base_media_url = urljoin(period_base_url, base_url)
if any(x.get("media") is not None for x in segment_list.findall("SegmentURL")):
# at least one segment has no URL specified, it uses the base url and ranges
track.url = base_media_url
track.descriptor = track.Descriptor.URL
track.drm = [drm] if drm else []
else:
init_data = None
initialization = segment_list.find("Initialization")
if initialization:
source_url = initialization.get("sourceURL")
if source_url is None:
source_url = base_media_url
res = session.get(source_url)
res.raise_for_status()
init_data = res.content
for i, segment_url in enumerate(segment_list.findall("SegmentURL")):
segment_filename = str(i).zfill(len(str(len(segment_urls))))
segment_save_path = (save_dir / segment_filename).with_suffix(".mp4")
media_url = segment_url.get("media")
if media_url is None:
media_url = base_media_url
asyncio.run(aria2c(
media_url,
segment_save_path,
session.headers,
proxy,
byte_range=segment_url.get("mediaRange")
))
# TODO: More like `segment.path`, but this will do for now
# Needed for the drm.decrypt() call couple lines down
track.path = segment_save_path
if isinstance(track, Audio) or init_data:
with open(track.path, "rb+") as f:
segment_data = f.read()
if isinstance(track, Audio):
# fix audio decryption on ATVP by fixing the sample description index
# TODO: Is this in mpeg data, or init data?
segment_data = re.sub(
b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
b"\\g<1>\x01",
segment_data
)
# prepend the init data to be able to decrypt
if init_data:
f.seek(0)
f.write(init_data)
f.write(segment_data)
if drm:
# TODO: What if the manifest does not mention DRM, but has DRM
drm.decrypt(track)
if callable(track.OnDecrypted):
track.OnDecrypted(track)
elif segment_base is not None or base_url:
# SegmentBase more or less boils down to defined ByteRanges
# So, we don't care, just download the full file
track.url = urljoin(period_base_url, base_url)
track.descriptor = track.Descriptor.URL
track.drm = [drm] if drm else []
else:
log.error("Could not find a way to get segments from this MPD manifest.")
sys.exit(1)
@staticmethod
def get_language(*options: Any) -> Optional[Language]:
for option in options:
@@ -285,8 +507,9 @@ class DASH:
return Language.get(option)
@staticmethod
def get_drm(protections) -> Optional[list[Widevine]]:
def get_drm(protections) -> list[Widevine]:
drm = []
for protection in protections:
# TODO: Add checks for PlayReady, FairPlay, maybe more
urn = (protection.get("schemeIdUri") or "").lower()
@@ -319,9 +542,6 @@ class DASH:
kid=kid
))
if not drm:
drm = None
return drm
@staticmethod
@@ -350,91 +570,5 @@ class DASH:
url = url.replace(m.group(), f"{value:{m.group(1)}}")
return url
@staticmethod
def get_segment_urls(
representation,
period_duration: str,
fallback_segment_template,
fallback_base_url: Optional[str] = None,
fallback_query: Optional[str] = None
) -> list[str]:
segment_urls: list[str] = []
if representation.find("SegmentTemplate") is not None:
segment_template = representation.find("SegmentTemplate")
else:
segment_template = fallback_segment_template
base_url = representation.findtext("BaseURL") or fallback_base_url
if segment_template is None:
# We could implement SegmentBase, but it's basically a list of Byte Range's to download
# So just return the Base URL as a segment, why give the downloader extra effort
return [urljoin(fallback_base_url, base_url)]
segment_template = copy(segment_template)
start_number = int(segment_template.get("startNumber") or 1)
segment_timeline = segment_template.find("SegmentTimeline")
for item in ("initialization", "media"):
value = segment_template.get(item)
if not value:
continue
if not re.match("^https?://", value, re.IGNORECASE):
if not base_url:
raise ValueError("Resolved Segment URL is not absolute, and no Base URL is available.")
value = urljoin(base_url, value)
if not urlparse(value).query and fallback_query:
value += f"?{fallback_query}"
segment_template.set(item, value)
initialization = segment_template.get("initialization")
if initialization:
segment_urls.append(DASH.replace_fields(
initialization,
Bandwidth=representation.get("bandwidth"),
RepresentationID=representation.get("id")
))
if segment_timeline is not None:
seg_time_list = []
current_time = 0
for s in segment_timeline.findall("S"):
if s.get("t"):
current_time = int(s.get("t"))
for _ in range(1 + (int(s.get("r") or 0))):
seg_time_list.append(current_time)
current_time += int(s.get("d"))
seg_num_list = list(range(start_number, len(seg_time_list) + start_number))
segment_urls += [
DASH.replace_fields(
segment_template.get("media"),
Bandwidth=representation.get("bandwidth"),
Number=n,
RepresentationID=representation.get("id"),
Time=t
)
for t, n in zip(seg_time_list, seg_num_list)
]
else:
if not period_duration:
raise ValueError("Duration of the Period was unable to be determined.")
period_duration = DASH.pt_to_sec(period_duration)
segment_duration = (
float(segment_template.get("duration")) / float(segment_template.get("timescale") or 1)
)
total_segments = math.ceil(period_duration / segment_duration)
segment_urls += [
DASH.replace_fields(
segment_template.get("media"),
Bandwidth=representation.get("bandwidth"),
Number=s,
RepresentationID=representation.get("id"),
Time=s
)
for s in range(start_number, start_number + total_segments)
]
return segment_urls
__ALL__ = (DASH,)