Replace tqdm progress bars with rich progress bars
This commit is contained in:
@@ -8,6 +8,7 @@ import time
|
||||
import traceback
|
||||
from concurrent import futures
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
@@ -21,7 +22,7 @@ from m3u8 import M3U8
|
||||
from pywidevine.cdm import Cdm as WidevineCdm
|
||||
from pywidevine.pssh import PSSH
|
||||
from requests import Session
|
||||
from tqdm import tqdm
|
||||
from rich import filesize
|
||||
|
||||
from devine.core.console import console
|
||||
from devine.core.constants import AnyTrack
|
||||
@@ -183,6 +184,7 @@ class HLS:
|
||||
def download_track(
|
||||
track: AnyTrack,
|
||||
save_dir: Path,
|
||||
progress: partial,
|
||||
session: Optional[Session] = None,
|
||||
proxy: Optional[str] = None,
|
||||
license_widevine: Optional[Callable] = None
|
||||
@@ -214,16 +216,10 @@ class HLS:
|
||||
|
||||
state_event = Event()
|
||||
|
||||
def download_segment(
|
||||
filename: str,
|
||||
segment: m3u8.Segment,
|
||||
init_data: Queue,
|
||||
segment_key: Queue,
|
||||
range_offset: Queue
|
||||
) -> None:
|
||||
def download_segment(filename: str, segment, init_data: Queue, segment_key: Queue) -> int:
|
||||
time.sleep(0.1)
|
||||
if state_event.is_set():
|
||||
return
|
||||
return 0
|
||||
|
||||
segment_save_path = (save_dir / filename).with_suffix(".mp4")
|
||||
|
||||
@@ -255,7 +251,7 @@ class HLS:
|
||||
segment_key.put(newest_segment_key)
|
||||
|
||||
if callable(track.OnSegmentFilter) and track.OnSegmentFilter(segment):
|
||||
return
|
||||
return 0
|
||||
|
||||
newest_init_data = init_data.get()
|
||||
if segment.init_section and (not newest_init_data or segment.discontinuity):
|
||||
@@ -302,6 +298,8 @@ class HLS:
|
||||
silent=True
|
||||
))
|
||||
|
||||
data_size = len(newest_init_data or b"")
|
||||
|
||||
if isinstance(track, Audio) or newest_init_data:
|
||||
with open(segment_save_path, "rb+") as f:
|
||||
segment_data = f.read()
|
||||
@@ -318,6 +316,7 @@ class HLS:
|
||||
f.seek(0)
|
||||
f.write(newest_init_data)
|
||||
f.write(segment_data)
|
||||
data_size += len(segment_data)
|
||||
|
||||
if newest_segment_key[0]:
|
||||
newest_segment_key[0].decrypt(segment_save_path)
|
||||
@@ -325,6 +324,8 @@ class HLS:
|
||||
if callable(track.OnDecrypted):
|
||||
track.OnDecrypted(track)
|
||||
|
||||
return data_size
|
||||
|
||||
segment_key = Queue(maxsize=1)
|
||||
init_data = Queue(maxsize=1)
|
||||
range_offset = Queue(maxsize=1)
|
||||
@@ -344,36 +345,48 @@ class HLS:
|
||||
init_data.put(None)
|
||||
range_offset.put(0)
|
||||
|
||||
with tqdm(total=len(master.segments), unit="segments") as pbar:
|
||||
with ThreadPoolExecutor(max_workers=16) as pool:
|
||||
try:
|
||||
for download in futures.as_completed((
|
||||
pool.submit(
|
||||
download_segment,
|
||||
filename=str(i).zfill(len(str(len(master.segments)))),
|
||||
segment=segment,
|
||||
init_data=init_data,
|
||||
segment_key=segment_key,
|
||||
range_offset=range_offset
|
||||
progress(total=len(master.segments))
|
||||
|
||||
download_start_time = time.time()
|
||||
download_sizes = []
|
||||
|
||||
with ThreadPoolExecutor(max_workers=16) as pool:
|
||||
try:
|
||||
for download in futures.as_completed((
|
||||
pool.submit(
|
||||
download_segment,
|
||||
filename=str(i).zfill(len(str(len(master.segments)))),
|
||||
segment=segment,
|
||||
init_data=init_data,
|
||||
segment_key=segment_key
|
||||
)
|
||||
for i, segment in enumerate(master.segments)
|
||||
)):
|
||||
if download.cancelled():
|
||||
continue
|
||||
e = download.exception()
|
||||
if e:
|
||||
state_event.set()
|
||||
pool.shutdown(wait=False, cancel_futures=True)
|
||||
traceback.print_exception(e)
|
||||
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
download_size = download.result()
|
||||
elapsed_time = time.time() - download_start_time
|
||||
download_sizes.append(download_size)
|
||||
while elapsed_time - len(download_sizes) > 10:
|
||||
download_sizes.pop(0)
|
||||
download_speed = sum(download_sizes) / len(download_sizes)
|
||||
progress(
|
||||
advance=1,
|
||||
downloaded=f"HLS {filesize.decimal(download_speed)}/s"
|
||||
)
|
||||
for i, segment in enumerate(master.segments)
|
||||
)):
|
||||
if download.cancelled():
|
||||
continue
|
||||
e = download.exception()
|
||||
if e:
|
||||
state_event.set()
|
||||
pool.shutdown(wait=False, cancel_futures=True)
|
||||
traceback.print_exception(e)
|
||||
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
pbar.update(1)
|
||||
except KeyboardInterrupt:
|
||||
state_event.set()
|
||||
pool.shutdown(wait=False, cancel_futures=True)
|
||||
console.log("Received Keyboard Interrupt, stopping...")
|
||||
return
|
||||
except KeyboardInterrupt:
|
||||
state_event.set()
|
||||
pool.shutdown(wait=False, cancel_futures=True)
|
||||
console.log("Received Keyboard Interrupt, stopping...")
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def get_drm(
|
||||
|
||||
Reference in New Issue
Block a user