Replace tqdm progress bars with rich progress bars
This commit is contained in:
@@ -11,6 +11,7 @@ import traceback
|
||||
from concurrent import futures
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from copy import copy
|
||||
from functools import partial
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
from threading import Event
|
||||
@@ -23,7 +24,7 @@ from langcodes import Language, tag_is_valid
|
||||
from pywidevine.cdm import Cdm as WidevineCdm
|
||||
from pywidevine.pssh import PSSH
|
||||
from requests import Session
|
||||
from tqdm import tqdm
|
||||
from rich import filesize
|
||||
|
||||
from devine.core.console import console
|
||||
from devine.core.constants import AnyTrack
|
||||
@@ -274,6 +275,7 @@ class DASH:
|
||||
def download_track(
|
||||
track: AnyTrack,
|
||||
save_dir: Path,
|
||||
progress: partial,
|
||||
session: Optional[Session] = None,
|
||||
proxy: Optional[str] = None,
|
||||
license_widevine: Optional[Callable] = None
|
||||
@@ -447,10 +449,10 @@ class DASH:
|
||||
|
||||
state_event = Event()
|
||||
|
||||
def download_segment(filename: str, segment: tuple[str, Optional[str]]):
|
||||
def download_segment(filename: str, segment: tuple[str, Optional[str]]) -> int:
|
||||
time.sleep(0.1)
|
||||
if state_event.is_set():
|
||||
return
|
||||
return 0
|
||||
|
||||
segment_save_path = (save_dir / filename).with_suffix(".mp4")
|
||||
|
||||
@@ -476,6 +478,8 @@ class DASH:
|
||||
silent=True
|
||||
))
|
||||
|
||||
data_size = len(init_data or b"")
|
||||
|
||||
if isinstance(track, Audio) or init_data:
|
||||
with open(segment_save_path, "rb+") as f:
|
||||
segment_data = f.read()
|
||||
@@ -492,6 +496,7 @@ class DASH:
|
||||
f.seek(0)
|
||||
f.write(init_data)
|
||||
f.write(segment_data)
|
||||
data_size += len(segment_data)
|
||||
|
||||
if drm:
|
||||
# TODO: What if the manifest does not mention DRM, but has DRM
|
||||
@@ -500,33 +505,48 @@ class DASH:
|
||||
if callable(track.OnDecrypted):
|
||||
track.OnDecrypted(track)
|
||||
|
||||
with tqdm(total=len(segments), unit="segments") as pbar:
|
||||
with ThreadPoolExecutor(max_workers=16) as pool:
|
||||
try:
|
||||
for download in futures.as_completed((
|
||||
pool.submit(
|
||||
download_segment,
|
||||
filename=str(i).zfill(len(str(len(segments)))),
|
||||
segment=segment
|
||||
return data_size
|
||||
|
||||
progress(total=len(segments))
|
||||
|
||||
download_start_time = time.time()
|
||||
download_sizes = []
|
||||
|
||||
with ThreadPoolExecutor(max_workers=16) as pool:
|
||||
try:
|
||||
for download in futures.as_completed((
|
||||
pool.submit(
|
||||
download_segment,
|
||||
filename=str(i).zfill(len(str(len(segments)))),
|
||||
segment=segment
|
||||
)
|
||||
for i, segment in enumerate(segments)
|
||||
)):
|
||||
if download.cancelled():
|
||||
continue
|
||||
e = download.exception()
|
||||
if e:
|
||||
state_event.set()
|
||||
pool.shutdown(wait=False, cancel_futures=True)
|
||||
traceback.print_exception(e)
|
||||
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
download_size = download.result()
|
||||
elapsed_time = time.time() - download_start_time
|
||||
download_sizes.append(download_size)
|
||||
while elapsed_time - len(download_sizes) > 10:
|
||||
download_sizes.pop(0)
|
||||
download_speed = sum(download_sizes) / len(download_sizes)
|
||||
progress(
|
||||
advance=1,
|
||||
downloaded=f"DASH {filesize.decimal(download_speed)}/s"
|
||||
)
|
||||
for i, segment in enumerate(segments)
|
||||
)):
|
||||
if download.cancelled():
|
||||
continue
|
||||
e = download.exception()
|
||||
if e:
|
||||
state_event.set()
|
||||
pool.shutdown(wait=False, cancel_futures=True)
|
||||
traceback.print_exception(e)
|
||||
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
pbar.update(1)
|
||||
except KeyboardInterrupt:
|
||||
state_event.set()
|
||||
pool.shutdown(wait=False, cancel_futures=True)
|
||||
console.log("Received Keyboard Interrupt, stopping...")
|
||||
return
|
||||
except KeyboardInterrupt:
|
||||
state_event.set()
|
||||
pool.shutdown(wait=False, cancel_futures=True)
|
||||
console.log("Received Keyboard Interrupt, stopping...")
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def get_language(*options: Any) -> Optional[Language]:
|
||||
|
||||
Reference in New Issue
Block a user