Completely rewrite downloading system
The new system now downloads and decrypts segments individually instead of downloading all segments, merging them, and then decrypting. Overall the download system now acts more like a normal player. This fixes #23 as the new HLS download system detects changes in keys and init segments as segments are downloaded. DASH still only supports one period, and one period only, but hopefully I can change that in the future. Downloading code is now also moved from the Track classes to the manifest classes. Download progress is now also actually helpful for segmented downloads (all HLS, and most DASH streams). It uses TQDM to show a progress bar based on how many segments it needs to download, and how fast it downloads them. There's only one down side currently. Downloading of segmented videos no longer have the benefit of aria2c's -j parameter. Where it can download n URLs concurrently. Aria2c is still used but only -x and -s is going to make a difference. In the future I will make HLS and DASH download in a multi-threaded way, sort of a manual version of -j.
This commit is contained in:
@@ -1,10 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
import sys
|
||||
from copy import copy
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from urllib.parse import urljoin, urlparse
|
||||
from uuid import UUID
|
||||
@@ -15,6 +19,8 @@ from pywidevine.cdm import Cdm as WidevineCdm
|
||||
from pywidevine.pssh import PSSH
|
||||
from requests import Session
|
||||
|
||||
from devine.core.constants import AnyTrack
|
||||
from devine.core.downloaders import aria2c
|
||||
from devine.core.drm import Widevine
|
||||
from devine.core.tracks import Audio, Subtitle, Tracks, Video
|
||||
from devine.core.utilities import is_close_match
|
||||
@@ -92,12 +98,7 @@ class DASH:
|
||||
if callable(period_filter) and period_filter(period):
|
||||
continue
|
||||
|
||||
period_base_url = period.findtext("BaseURL") or self.manifest.findtext("BaseURL")
|
||||
if not period_base_url or not re.match("^https?://", period_base_url, re.IGNORECASE):
|
||||
period_base_url = urljoin(self.url, period_base_url)
|
||||
|
||||
for adaptation_set in period.findall("AdaptationSet"):
|
||||
# flags
|
||||
trick_mode = any(
|
||||
x.get("schemeIdUri") == "http://dashif.org/guidelines/trickmode"
|
||||
for x in (
|
||||
@@ -105,6 +106,10 @@ class DASH:
|
||||
adaptation_set.findall("SupplementalProperty")
|
||||
)
|
||||
)
|
||||
if trick_mode:
|
||||
# we don't want trick mode streams (they are only used for fast-forward/rewind)
|
||||
continue
|
||||
|
||||
descriptive = any(
|
||||
(x.get("schemeIdUri"), x.get("value")) == ("urn:mpeg:dash:role:2011", "descriptive")
|
||||
for x in adaptation_set.findall("Accessibility")
|
||||
@@ -121,12 +126,8 @@ class DASH:
|
||||
for x in adaptation_set.findall("Role")
|
||||
)
|
||||
|
||||
if trick_mode:
|
||||
# we don't want trick mode streams (they are only used for fast-forward/rewind)
|
||||
continue
|
||||
|
||||
for rep in adaptation_set.findall("Representation"):
|
||||
supplements = rep.findall("SupplementalProperty") + adaptation_set.findall("SupplementalProperty")
|
||||
codecs = rep.get("codecs") or adaptation_set.get("codecs")
|
||||
|
||||
content_type = adaptation_set.get("contentType") or \
|
||||
adaptation_set.get("mimeType") or \
|
||||
@@ -136,8 +137,6 @@ class DASH:
|
||||
raise ValueError("No content type value could be found")
|
||||
content_type = content_type.split("/")[0]
|
||||
|
||||
codecs = rep.get("codecs") or adaptation_set.get("codecs")
|
||||
|
||||
if content_type.startswith("image"):
|
||||
# we don't want what's likely thumbnails for the seekbar
|
||||
continue
|
||||
@@ -154,6 +153,8 @@ class DASH:
|
||||
if mime and not mime.endswith("/mp4"):
|
||||
codecs = mime.split("/")[1]
|
||||
|
||||
supplements = rep.findall("SupplementalProperty") + adaptation_set.findall("SupplementalProperty")
|
||||
|
||||
joc = next((
|
||||
x.get("value")
|
||||
for x in supplements
|
||||
@@ -167,18 +168,6 @@ class DASH:
|
||||
"The provided fallback language is not valid or is `None` or `und`."
|
||||
)
|
||||
|
||||
drm = DASH.get_drm(rep.findall("ContentProtection") + adaptation_set.findall("ContentProtection"))
|
||||
|
||||
# from here we need to calculate the Segment Template and compute a final list of URLs
|
||||
|
||||
segment_urls = DASH.get_segment_urls(
|
||||
representation=rep,
|
||||
period_duration=period.get("duration") or self.manifest.get("mediaPresentationDuration"),
|
||||
fallback_segment_template=adaptation_set.find("SegmentTemplate"),
|
||||
fallback_base_url=period_base_url,
|
||||
fallback_query=urlparse(self.url).query
|
||||
)
|
||||
|
||||
# for some reason it's incredibly common for services to not provide
|
||||
# a good and actually unique track ID, sometimes because of the lang
|
||||
# dialect not being represented in the id, or the bitrate, or such.
|
||||
@@ -206,7 +195,7 @@ class DASH:
|
||||
|
||||
tracks.add(track_type(
|
||||
id_=track_id,
|
||||
url=segment_urls,
|
||||
url=(self.url, rep, adaptation_set, period),
|
||||
codec=track_codec,
|
||||
language=track_lang,
|
||||
is_original_lang=not track_lang or not language or is_close_match(track_lang, [language]),
|
||||
@@ -254,8 +243,7 @@ class DASH:
|
||||
rep.find("SegmentBase").get("timescale") if
|
||||
rep.find("SegmentBase") is not None else None
|
||||
)
|
||||
),
|
||||
drm=drm
|
||||
)
|
||||
) if track_type is Video else dict(
|
||||
bitrate=rep.get("bandwidth"),
|
||||
channels=next(iter(
|
||||
@@ -263,8 +251,7 @@ class DASH:
|
||||
or adaptation_set.xpath("AudioChannelConfiguration/@value")
|
||||
), None),
|
||||
joc=joc,
|
||||
descriptive=descriptive,
|
||||
drm=drm
|
||||
descriptive=descriptive
|
||||
) if track_type is Audio else dict(
|
||||
forced=forced,
|
||||
cc=cc
|
||||
@@ -276,6 +263,241 @@ class DASH:
|
||||
|
||||
return tracks
|
||||
|
||||
@staticmethod
|
||||
def download_track(
|
||||
track: AnyTrack,
|
||||
save_dir: Path,
|
||||
session: Optional[Session] = None,
|
||||
proxy: Optional[str] = None,
|
||||
license_widevine: Optional[Callable] = None
|
||||
):
|
||||
if not session:
|
||||
session = Session()
|
||||
elif not isinstance(session, Session):
|
||||
raise TypeError(f"Expected session to be a {Session}, not {session!r}")
|
||||
|
||||
if not track.needs_proxy and proxy:
|
||||
proxy = None
|
||||
|
||||
if proxy:
|
||||
session.proxies.update({
|
||||
"all": proxy
|
||||
})
|
||||
|
||||
log = logging.getLogger("DASH")
|
||||
|
||||
manifest_url, representation, adaptation_set, period = track.url
|
||||
|
||||
drm = DASH.get_drm(
|
||||
representation.findall("ContentProtection") +
|
||||
adaptation_set.findall("ContentProtection")
|
||||
)
|
||||
if drm:
|
||||
drm = drm[0] # just use the first supported DRM system for now
|
||||
if isinstance(drm, Widevine):
|
||||
# license and grab content keys
|
||||
if not license_widevine:
|
||||
raise ValueError("license_widevine func must be supplied to use Widevine DRM")
|
||||
license_widevine(drm)
|
||||
else:
|
||||
drm = None
|
||||
|
||||
segment_urls: list[str] = []
|
||||
manifest = load_xml(session.get(manifest_url).text)
|
||||
manifest_url_query = urlparse(manifest_url).query
|
||||
|
||||
period_base_url = period.findtext("BaseURL") or manifest.findtext("BaseURL")
|
||||
if not period_base_url or not re.match("^https?://", period_base_url, re.IGNORECASE):
|
||||
period_base_url = urljoin(manifest_url, period_base_url)
|
||||
period_duration = period.get("duration") or manifest.get("mediaPresentationDuration")
|
||||
|
||||
base_url = representation.findtext("BaseURL") or period_base_url
|
||||
|
||||
segment_template = representation.find("SegmentTemplate")
|
||||
if segment_template is None:
|
||||
segment_template = adaptation_set.find("SegmentTemplate")
|
||||
|
||||
segment_base = representation.find("SegmentBase")
|
||||
if segment_base is None:
|
||||
segment_base = adaptation_set.find("SegmentBase")
|
||||
|
||||
segment_list = representation.find("SegmentList")
|
||||
if segment_list is None:
|
||||
segment_list = adaptation_set.find("SegmentList")
|
||||
|
||||
if segment_template is not None:
|
||||
segment_template = copy(segment_template)
|
||||
start_number = int(segment_template.get("startNumber") or 1)
|
||||
segment_timeline = segment_template.find("SegmentTimeline")
|
||||
|
||||
for item in ("initialization", "media"):
|
||||
value = segment_template.get(item)
|
||||
if not value:
|
||||
continue
|
||||
if not re.match("^https?://", value, re.IGNORECASE):
|
||||
if not base_url:
|
||||
raise ValueError("Resolved Segment URL is not absolute, and no Base URL is available.")
|
||||
value = urljoin(base_url, value)
|
||||
if not urlparse(value).query and manifest_url_query:
|
||||
value += f"?{manifest_url_query}"
|
||||
segment_template.set(item, value)
|
||||
|
||||
if segment_timeline is not None:
|
||||
seg_time_list = []
|
||||
current_time = 0
|
||||
for s in segment_timeline.findall("S"):
|
||||
if s.get("t"):
|
||||
current_time = int(s.get("t"))
|
||||
for _ in range(1 + (int(s.get("r") or 0))):
|
||||
seg_time_list.append(current_time)
|
||||
current_time += int(s.get("d"))
|
||||
seg_num_list = list(range(start_number, len(seg_time_list) + start_number))
|
||||
segment_urls += [
|
||||
DASH.replace_fields(
|
||||
segment_template.get("media"),
|
||||
Bandwidth=representation.get("bandwidth"),
|
||||
Number=n,
|
||||
RepresentationID=representation.get("id"),
|
||||
Time=t
|
||||
)
|
||||
for t, n in zip(seg_time_list, seg_num_list)
|
||||
]
|
||||
else:
|
||||
if not period_duration:
|
||||
raise ValueError("Duration of the Period was unable to be determined.")
|
||||
period_duration = DASH.pt_to_sec(period_duration)
|
||||
segment_duration = float(segment_template.get("duration"))
|
||||
segment_timescale = float(segment_template.get("timescale") or 1)
|
||||
|
||||
total_segments = math.ceil(period_duration / (segment_duration / segment_timescale))
|
||||
segment_urls += [
|
||||
DASH.replace_fields(
|
||||
segment_template.get("media"),
|
||||
Bandwidth=representation.get("bandwidth"),
|
||||
Number=s,
|
||||
RepresentationID=representation.get("id"),
|
||||
Time=s
|
||||
)
|
||||
for s in range(start_number, start_number + total_segments)
|
||||
]
|
||||
|
||||
init_data = None
|
||||
init_url = segment_template.get("initialization")
|
||||
if init_url:
|
||||
res = session.get(DASH.replace_fields(
|
||||
init_url,
|
||||
Bandwidth=representation.get("bandwidth"),
|
||||
RepresentationID=representation.get("id")
|
||||
))
|
||||
res.raise_for_status()
|
||||
init_data = res.content
|
||||
|
||||
for i, segment_url in enumerate(segment_urls):
|
||||
segment_filename = str(i).zfill(len(str(len(segment_urls))))
|
||||
segment_save_path = (save_dir / segment_filename).with_suffix(".mp4")
|
||||
|
||||
asyncio.run(aria2c(
|
||||
segment_url,
|
||||
segment_save_path,
|
||||
session.headers,
|
||||
proxy
|
||||
))
|
||||
# TODO: More like `segment.path`, but this will do for now
|
||||
# Needed for the drm.decrypt() call couple lines down
|
||||
track.path = segment_save_path
|
||||
|
||||
if isinstance(track, Audio) or init_data:
|
||||
with open(track.path, "rb+") as f:
|
||||
segment_data = f.read()
|
||||
if isinstance(track, Audio):
|
||||
# fix audio decryption on ATVP by fixing the sample description index
|
||||
# TODO: Is this in mpeg data, or init data?
|
||||
segment_data = re.sub(
|
||||
b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
|
||||
b"\\g<1>\x01",
|
||||
segment_data
|
||||
)
|
||||
# prepend the init data to be able to decrypt
|
||||
if init_data:
|
||||
f.seek(0)
|
||||
f.write(init_data)
|
||||
f.write(segment_data)
|
||||
|
||||
if drm:
|
||||
# TODO: What if the manifest does not mention DRM, but has DRM
|
||||
drm.decrypt(track)
|
||||
if callable(track.OnDecrypted):
|
||||
track.OnDecrypted(track)
|
||||
elif segment_list is not None:
|
||||
base_media_url = urljoin(period_base_url, base_url)
|
||||
if any(x.get("media") is not None for x in segment_list.findall("SegmentURL")):
|
||||
# at least one segment has no URL specified, it uses the base url and ranges
|
||||
track.url = base_media_url
|
||||
track.descriptor = track.Descriptor.URL
|
||||
track.drm = [drm] if drm else []
|
||||
else:
|
||||
init_data = None
|
||||
initialization = segment_list.find("Initialization")
|
||||
if initialization:
|
||||
source_url = initialization.get("sourceURL")
|
||||
if source_url is None:
|
||||
source_url = base_media_url
|
||||
|
||||
res = session.get(source_url)
|
||||
res.raise_for_status()
|
||||
init_data = res.content
|
||||
|
||||
for i, segment_url in enumerate(segment_list.findall("SegmentURL")):
|
||||
segment_filename = str(i).zfill(len(str(len(segment_urls))))
|
||||
segment_save_path = (save_dir / segment_filename).with_suffix(".mp4")
|
||||
|
||||
media_url = segment_url.get("media")
|
||||
if media_url is None:
|
||||
media_url = base_media_url
|
||||
|
||||
asyncio.run(aria2c(
|
||||
media_url,
|
||||
segment_save_path,
|
||||
session.headers,
|
||||
proxy,
|
||||
byte_range=segment_url.get("mediaRange")
|
||||
))
|
||||
# TODO: More like `segment.path`, but this will do for now
|
||||
# Needed for the drm.decrypt() call couple lines down
|
||||
track.path = segment_save_path
|
||||
|
||||
if isinstance(track, Audio) or init_data:
|
||||
with open(track.path, "rb+") as f:
|
||||
segment_data = f.read()
|
||||
if isinstance(track, Audio):
|
||||
# fix audio decryption on ATVP by fixing the sample description index
|
||||
# TODO: Is this in mpeg data, or init data?
|
||||
segment_data = re.sub(
|
||||
b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
|
||||
b"\\g<1>\x01",
|
||||
segment_data
|
||||
)
|
||||
# prepend the init data to be able to decrypt
|
||||
if init_data:
|
||||
f.seek(0)
|
||||
f.write(init_data)
|
||||
f.write(segment_data)
|
||||
|
||||
if drm:
|
||||
# TODO: What if the manifest does not mention DRM, but has DRM
|
||||
drm.decrypt(track)
|
||||
if callable(track.OnDecrypted):
|
||||
track.OnDecrypted(track)
|
||||
elif segment_base is not None or base_url:
|
||||
# SegmentBase more or less boils down to defined ByteRanges
|
||||
# So, we don't care, just download the full file
|
||||
track.url = urljoin(period_base_url, base_url)
|
||||
track.descriptor = track.Descriptor.URL
|
||||
track.drm = [drm] if drm else []
|
||||
else:
|
||||
log.error("Could not find a way to get segments from this MPD manifest.")
|
||||
sys.exit(1)
|
||||
|
||||
@staticmethod
|
||||
def get_language(*options: Any) -> Optional[Language]:
|
||||
for option in options:
|
||||
@@ -285,8 +507,9 @@ class DASH:
|
||||
return Language.get(option)
|
||||
|
||||
@staticmethod
|
||||
def get_drm(protections) -> Optional[list[Widevine]]:
|
||||
def get_drm(protections) -> list[Widevine]:
|
||||
drm = []
|
||||
|
||||
for protection in protections:
|
||||
# TODO: Add checks for PlayReady, FairPlay, maybe more
|
||||
urn = (protection.get("schemeIdUri") or "").lower()
|
||||
@@ -319,9 +542,6 @@ class DASH:
|
||||
kid=kid
|
||||
))
|
||||
|
||||
if not drm:
|
||||
drm = None
|
||||
|
||||
return drm
|
||||
|
||||
@staticmethod
|
||||
@@ -350,91 +570,5 @@ class DASH:
|
||||
url = url.replace(m.group(), f"{value:{m.group(1)}}")
|
||||
return url
|
||||
|
||||
@staticmethod
|
||||
def get_segment_urls(
|
||||
representation,
|
||||
period_duration: str,
|
||||
fallback_segment_template,
|
||||
fallback_base_url: Optional[str] = None,
|
||||
fallback_query: Optional[str] = None
|
||||
) -> list[str]:
|
||||
segment_urls: list[str] = []
|
||||
if representation.find("SegmentTemplate") is not None:
|
||||
segment_template = representation.find("SegmentTemplate")
|
||||
else:
|
||||
segment_template = fallback_segment_template
|
||||
base_url = representation.findtext("BaseURL") or fallback_base_url
|
||||
|
||||
if segment_template is None:
|
||||
# We could implement SegmentBase, but it's basically a list of Byte Range's to download
|
||||
# So just return the Base URL as a segment, why give the downloader extra effort
|
||||
return [urljoin(fallback_base_url, base_url)]
|
||||
|
||||
segment_template = copy(segment_template)
|
||||
start_number = int(segment_template.get("startNumber") or 1)
|
||||
segment_timeline = segment_template.find("SegmentTimeline")
|
||||
|
||||
for item in ("initialization", "media"):
|
||||
value = segment_template.get(item)
|
||||
if not value:
|
||||
continue
|
||||
if not re.match("^https?://", value, re.IGNORECASE):
|
||||
if not base_url:
|
||||
raise ValueError("Resolved Segment URL is not absolute, and no Base URL is available.")
|
||||
value = urljoin(base_url, value)
|
||||
if not urlparse(value).query and fallback_query:
|
||||
value += f"?{fallback_query}"
|
||||
segment_template.set(item, value)
|
||||
|
||||
initialization = segment_template.get("initialization")
|
||||
if initialization:
|
||||
segment_urls.append(DASH.replace_fields(
|
||||
initialization,
|
||||
Bandwidth=representation.get("bandwidth"),
|
||||
RepresentationID=representation.get("id")
|
||||
))
|
||||
|
||||
if segment_timeline is not None:
|
||||
seg_time_list = []
|
||||
current_time = 0
|
||||
for s in segment_timeline.findall("S"):
|
||||
if s.get("t"):
|
||||
current_time = int(s.get("t"))
|
||||
for _ in range(1 + (int(s.get("r") or 0))):
|
||||
seg_time_list.append(current_time)
|
||||
current_time += int(s.get("d"))
|
||||
seg_num_list = list(range(start_number, len(seg_time_list) + start_number))
|
||||
segment_urls += [
|
||||
DASH.replace_fields(
|
||||
segment_template.get("media"),
|
||||
Bandwidth=representation.get("bandwidth"),
|
||||
Number=n,
|
||||
RepresentationID=representation.get("id"),
|
||||
Time=t
|
||||
)
|
||||
for t, n in zip(seg_time_list, seg_num_list)
|
||||
]
|
||||
else:
|
||||
if not period_duration:
|
||||
raise ValueError("Duration of the Period was unable to be determined.")
|
||||
period_duration = DASH.pt_to_sec(period_duration)
|
||||
|
||||
segment_duration = (
|
||||
float(segment_template.get("duration")) / float(segment_template.get("timescale") or 1)
|
||||
)
|
||||
total_segments = math.ceil(period_duration / segment_duration)
|
||||
segment_urls += [
|
||||
DASH.replace_fields(
|
||||
segment_template.get("media"),
|
||||
Bandwidth=representation.get("bandwidth"),
|
||||
Number=s,
|
||||
RepresentationID=representation.get("id"),
|
||||
Time=s
|
||||
)
|
||||
for s in range(start_number, start_number + total_segments)
|
||||
]
|
||||
|
||||
return segment_urls
|
||||
|
||||
|
||||
__ALL__ = (DASH,)
|
||||
|
||||
Reference in New Issue
Block a user