Replace tqdm progress bars with rich progress bars

This commit is contained in:
rlaphoenix
2023-02-25 13:45:17 +00:00
parent cc69423374
commit 92895426b3
5 changed files with 202 additions and 87 deletions

View File

@@ -8,6 +8,7 @@ import time
import traceback
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from hashlib import md5
from pathlib import Path
from queue import Queue
@@ -21,7 +22,7 @@ from m3u8 import M3U8
from pywidevine.cdm import Cdm as WidevineCdm
from pywidevine.pssh import PSSH
from requests import Session
from tqdm import tqdm
from rich import filesize
from devine.core.console import console
from devine.core.constants import AnyTrack
@@ -183,6 +184,7 @@ class HLS:
def download_track(
track: AnyTrack,
save_dir: Path,
progress: partial,
session: Optional[Session] = None,
proxy: Optional[str] = None,
license_widevine: Optional[Callable] = None
@@ -214,16 +216,10 @@ class HLS:
state_event = Event()
def download_segment(
filename: str,
segment: m3u8.Segment,
init_data: Queue,
segment_key: Queue,
range_offset: Queue
) -> None:
def download_segment(filename: str, segment, init_data: Queue, segment_key: Queue) -> int:
time.sleep(0.1)
if state_event.is_set():
return
return 0
segment_save_path = (save_dir / filename).with_suffix(".mp4")
@@ -255,7 +251,7 @@ class HLS:
segment_key.put(newest_segment_key)
if callable(track.OnSegmentFilter) and track.OnSegmentFilter(segment):
return
return 0
newest_init_data = init_data.get()
if segment.init_section and (not newest_init_data or segment.discontinuity):
@@ -302,6 +298,8 @@ class HLS:
silent=True
))
data_size = len(newest_init_data or b"")
if isinstance(track, Audio) or newest_init_data:
with open(segment_save_path, "rb+") as f:
segment_data = f.read()
@@ -318,6 +316,7 @@ class HLS:
f.seek(0)
f.write(newest_init_data)
f.write(segment_data)
data_size += len(segment_data)
if newest_segment_key[0]:
newest_segment_key[0].decrypt(segment_save_path)
@@ -325,6 +324,8 @@ class HLS:
if callable(track.OnDecrypted):
track.OnDecrypted(track)
return data_size
segment_key = Queue(maxsize=1)
init_data = Queue(maxsize=1)
range_offset = Queue(maxsize=1)
@@ -344,36 +345,48 @@ class HLS:
init_data.put(None)
range_offset.put(0)
with tqdm(total=len(master.segments), unit="segments") as pbar:
with ThreadPoolExecutor(max_workers=16) as pool:
try:
for download in futures.as_completed((
pool.submit(
download_segment,
filename=str(i).zfill(len(str(len(master.segments)))),
segment=segment,
init_data=init_data,
segment_key=segment_key,
range_offset=range_offset
progress(total=len(master.segments))
download_start_time = time.time()
download_sizes = []
with ThreadPoolExecutor(max_workers=16) as pool:
try:
for download in futures.as_completed((
pool.submit(
download_segment,
filename=str(i).zfill(len(str(len(master.segments)))),
segment=segment,
init_data=init_data,
segment_key=segment_key
)
for i, segment in enumerate(master.segments)
)):
if download.cancelled():
continue
e = download.exception()
if e:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
sys.exit(1)
else:
download_size = download.result()
elapsed_time = time.time() - download_start_time
download_sizes.append(download_size)
while elapsed_time - len(download_sizes) > 10:
download_sizes.pop(0)
download_speed = sum(download_sizes) / len(download_sizes)
progress(
advance=1,
downloaded=f"HLS {filesize.decimal(download_speed)}/s"
)
for i, segment in enumerate(master.segments)
)):
if download.cancelled():
continue
e = download.exception()
if e:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
sys.exit(1)
else:
pbar.update(1)
except KeyboardInterrupt:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
console.log("Received Keyboard Interrupt, stopping...")
return
except KeyboardInterrupt:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
console.log("Received Keyboard Interrupt, stopping...")
return
@staticmethod
def get_drm(