Replace tqdm progress bars with rich progress bars

This commit is contained in:
rlaphoenix
2023-02-25 13:45:17 +00:00
parent cc69423374
commit 92895426b3
5 changed files with 202 additions and 87 deletions

View File

@@ -11,6 +11,7 @@ import traceback
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from copy import copy
from functools import partial
from hashlib import md5
from pathlib import Path
from threading import Event
@@ -23,7 +24,7 @@ from langcodes import Language, tag_is_valid
from pywidevine.cdm import Cdm as WidevineCdm
from pywidevine.pssh import PSSH
from requests import Session
from tqdm import tqdm
from rich import filesize
from devine.core.console import console
from devine.core.constants import AnyTrack
@@ -274,6 +275,7 @@ class DASH:
def download_track(
track: AnyTrack,
save_dir: Path,
progress: partial,
session: Optional[Session] = None,
proxy: Optional[str] = None,
license_widevine: Optional[Callable] = None
@@ -447,10 +449,10 @@ class DASH:
state_event = Event()
def download_segment(filename: str, segment: tuple[str, Optional[str]]):
def download_segment(filename: str, segment: tuple[str, Optional[str]]) -> int:
time.sleep(0.1)
if state_event.is_set():
return
return 0
segment_save_path = (save_dir / filename).with_suffix(".mp4")
@@ -476,6 +478,8 @@ class DASH:
silent=True
))
data_size = len(init_data or b"")
if isinstance(track, Audio) or init_data:
with open(segment_save_path, "rb+") as f:
segment_data = f.read()
@@ -492,6 +496,7 @@ class DASH:
f.seek(0)
f.write(init_data)
f.write(segment_data)
data_size += len(segment_data)
if drm:
# TODO: What if the manifest does not mention DRM, but has DRM
@@ -500,33 +505,48 @@ class DASH:
if callable(track.OnDecrypted):
track.OnDecrypted(track)
with tqdm(total=len(segments), unit="segments") as pbar:
with ThreadPoolExecutor(max_workers=16) as pool:
try:
for download in futures.as_completed((
pool.submit(
download_segment,
filename=str(i).zfill(len(str(len(segments)))),
segment=segment
return data_size
progress(total=len(segments))
download_start_time = time.time()
download_sizes = []
with ThreadPoolExecutor(max_workers=16) as pool:
try:
for download in futures.as_completed((
pool.submit(
download_segment,
filename=str(i).zfill(len(str(len(segments)))),
segment=segment
)
for i, segment in enumerate(segments)
)):
if download.cancelled():
continue
e = download.exception()
if e:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
sys.exit(1)
else:
download_size = download.result()
elapsed_time = time.time() - download_start_time
download_sizes.append(download_size)
while elapsed_time - len(download_sizes) > 10:
download_sizes.pop(0)
download_speed = sum(download_sizes) / len(download_sizes)
progress(
advance=1,
downloaded=f"DASH {filesize.decimal(download_speed)}/s"
)
for i, segment in enumerate(segments)
)):
if download.cancelled():
continue
e = download.exception()
if e:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
sys.exit(1)
else:
pbar.update(1)
except KeyboardInterrupt:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
console.log("Received Keyboard Interrupt, stopping...")
return
except KeyboardInterrupt:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
console.log("Received Keyboard Interrupt, stopping...")
return
@staticmethod
def get_language(*options: Any) -> Optional[Language]: