mirror of
				https://github.com/devine-dl/pywidevine.git
				synced 2025-11-04 03:44:50 +00:00 
			
		
		
		
	Various typing/linting fixes and improvements
This commit is contained in:
		
							parent
							
								
									97ec2e1c60
								
							
						
					
					
						commit
						0e6aa1d5e8
					
				@ -510,8 +510,7 @@ class Cdm:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        input_file = Path(input_file)
 | 
					        input_file = Path(input_file)
 | 
				
			||||||
        output_file = Path(output_file)
 | 
					        output_file = Path(output_file)
 | 
				
			||||||
        if temp_dir:
 | 
					        temp_dir_ = Path(temp_dir) if temp_dir else None
 | 
				
			||||||
            temp_dir = Path(temp_dir)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if not input_file.is_file():
 | 
					        if not input_file.is_file():
 | 
				
			||||||
            raise FileNotFoundError(f"Input file does not exist, {input_file}")
 | 
					            raise FileNotFoundError(f"Input file does not exist, {input_file}")
 | 
				
			||||||
@ -545,9 +544,9 @@ class Cdm:
 | 
				
			|||||||
            ])
 | 
					            ])
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if temp_dir:
 | 
					        if temp_dir_:
 | 
				
			||||||
            temp_dir.mkdir(parents=True, exist_ok=True)
 | 
					            temp_dir_.mkdir(parents=True, exist_ok=True)
 | 
				
			||||||
            args.extend(["--temp_dir", temp_dir])
 | 
					            args.extend(["--temp_dir", str(temp_dir_)])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return subprocess.check_call([executable, *args])
 | 
					        return subprocess.check_call([executable, *args])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -555,8 +554,8 @@ class Cdm:
 | 
				
			|||||||
    def encrypt_client_id(
 | 
					    def encrypt_client_id(
 | 
				
			||||||
        client_id: ClientIdentification,
 | 
					        client_id: ClientIdentification,
 | 
				
			||||||
        service_certificate: Union[SignedDrmCertificate, DrmCertificate],
 | 
					        service_certificate: Union[SignedDrmCertificate, DrmCertificate],
 | 
				
			||||||
        key: bytes = None,
 | 
					        key: Optional[bytes] = None,
 | 
				
			||||||
        iv: bytes = None
 | 
					        iv: Optional[bytes] = None
 | 
				
			||||||
    ) -> EncryptedClientIdentification:
 | 
					    ) -> EncryptedClientIdentification:
 | 
				
			||||||
        """Encrypt the Client ID with the Service's Privacy Certificate."""
 | 
					        """Encrypt the Client ID with the Service's Privacy Certificate."""
 | 
				
			||||||
        privacy_key = key or get_random_bytes(16)
 | 
					        privacy_key = key or get_random_bytes(16)
 | 
				
			||||||
 | 
				
			|||||||
@ -199,36 +199,36 @@ class Device:
 | 
				
			|||||||
            raise ValueError("Device Data does not seem to be a WVD file (v0).")
 | 
					            raise ValueError("Device Data does not seem to be a WVD file (v0).")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if header.version == 1:  # v1 to v2
 | 
					        if header.version == 1:  # v1 to v2
 | 
				
			||||||
            data = _Structures.v1.parse(data)
 | 
					            v1_struct = _Structures.v1.parse(data)
 | 
				
			||||||
            data.version = 2  # update version to 2 to allow loading
 | 
					            v1_struct.version = 2  # update version to 2 to allow loading
 | 
				
			||||||
            data.flags = Container()  # blank flags that may have been used in v1
 | 
					            v1_struct.flags = Container()  # blank flags that may have been used in v1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            vmp = FileHashes()
 | 
					            vmp = FileHashes()
 | 
				
			||||||
            if data.vmp:
 | 
					            if v1_struct.vmp:
 | 
				
			||||||
                try:
 | 
					                try:
 | 
				
			||||||
                    vmp.ParseFromString(data.vmp)
 | 
					                    vmp.ParseFromString(v1_struct.vmp)
 | 
				
			||||||
                    if vmp.SerializeToString() != data.vmp:
 | 
					                    if vmp.SerializeToString() != v1_struct.vmp:
 | 
				
			||||||
                        raise DecodeError("partial parse")
 | 
					                        raise DecodeError("partial parse")
 | 
				
			||||||
                except DecodeError as e:
 | 
					                except DecodeError as e:
 | 
				
			||||||
                    raise DecodeError(f"Failed to parse VMP data as FileHashes, {e}")
 | 
					                    raise DecodeError(f"Failed to parse VMP data as FileHashes, {e}")
 | 
				
			||||||
                data.vmp = vmp
 | 
					                v1_struct.vmp = vmp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                client_id = ClientIdentification()
 | 
					                client_id = ClientIdentification()
 | 
				
			||||||
                try:
 | 
					                try:
 | 
				
			||||||
                    client_id.ParseFromString(data.client_id)
 | 
					                    client_id.ParseFromString(v1_struct.client_id)
 | 
				
			||||||
                    if client_id.SerializeToString() != data.client_id:
 | 
					                    if client_id.SerializeToString() != v1_struct.client_id:
 | 
				
			||||||
                        raise DecodeError("partial parse")
 | 
					                        raise DecodeError("partial parse")
 | 
				
			||||||
                except DecodeError as e:
 | 
					                except DecodeError as e:
 | 
				
			||||||
                    raise DecodeError(f"Failed to parse VMP data as FileHashes, {e}")
 | 
					                    raise DecodeError(f"Failed to parse VMP data as FileHashes, {e}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                new_vmp_data = data.vmp.SerializeToString()
 | 
					                new_vmp_data = v1_struct.vmp.SerializeToString()
 | 
				
			||||||
                if client_id.vmp_data and client_id.vmp_data != new_vmp_data:
 | 
					                if client_id.vmp_data and client_id.vmp_data != new_vmp_data:
 | 
				
			||||||
                    logging.getLogger("migrate").warning("Client ID already has Verified Media Path data")
 | 
					                    logging.getLogger("migrate").warning("Client ID already has Verified Media Path data")
 | 
				
			||||||
                client_id.vmp_data = new_vmp_data
 | 
					                client_id.vmp_data = new_vmp_data
 | 
				
			||||||
                data.client_id = client_id.SerializeToString()
 | 
					                v1_struct.client_id = client_id.SerializeToString()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
                data = _Structures.v2.build(data)
 | 
					                data = _Structures.v2.build(v1_struct)
 | 
				
			||||||
            except ConstructError as e:
 | 
					            except ConstructError as e:
 | 
				
			||||||
                raise ValueError(f"Migration failed, {e}")
 | 
					                raise ValueError(f"Migration failed, {e}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -26,10 +26,8 @@ def main(version: bool, debug: bool) -> None:
 | 
				
			|||||||
    logging.basicConfig(level=logging.DEBUG if debug else logging.INFO)
 | 
					    logging.basicConfig(level=logging.DEBUG if debug else logging.INFO)
 | 
				
			||||||
    log = logging.getLogger()
 | 
					    log = logging.getLogger()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    copyright_years = 2022
 | 
					 | 
				
			||||||
    current_year = datetime.now().year
 | 
					    current_year = datetime.now().year
 | 
				
			||||||
    if copyright_years != current_year:
 | 
					    copyright_years = f"2022-{current_year}"
 | 
				
			||||||
        copyright_years = f"{copyright_years}-{current_year}"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    log.info("pywidevine version %s Copyright (c) %s rlaphoenix", __version__, copyright_years)
 | 
					    log.info("pywidevine version %s Copyright (c) %s rlaphoenix", __version__, copyright_years)
 | 
				
			||||||
    log.info("https://github.com/rlaphoenix/pywidevine")
 | 
					    log.info("https://github.com/rlaphoenix/pywidevine")
 | 
				
			||||||
@ -38,15 +36,15 @@ def main(version: bool, debug: bool) -> None:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@main.command(name="license")
 | 
					@main.command(name="license")
 | 
				
			||||||
@click.argument("device", type=Path)
 | 
					@click.argument("device_path", type=Path)
 | 
				
			||||||
@click.argument("pssh", type=str)
 | 
					@click.argument("pssh", type=PSSH)
 | 
				
			||||||
@click.argument("server", type=str)
 | 
					@click.argument("server", type=str)
 | 
				
			||||||
@click.option("-t", "--type", "type_", type=click.Choice(LicenseType.keys(), case_sensitive=False),
 | 
					@click.option("-t", "--type", "type_", type=click.Choice(LicenseType.keys(), case_sensitive=False),
 | 
				
			||||||
              default="STREAMING",
 | 
					              default="STREAMING",
 | 
				
			||||||
              help="License Type to Request.")
 | 
					              help="License Type to Request.")
 | 
				
			||||||
@click.option("-p", "--privacy", is_flag=True, default=False,
 | 
					@click.option("-p", "--privacy", is_flag=True, default=False,
 | 
				
			||||||
              help="Use Privacy Mode, off by default.")
 | 
					              help="Use Privacy Mode, off by default.")
 | 
				
			||||||
def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool):
 | 
					def license_(device_path: Path, pssh: PSSH, server: str, type_: str, privacy: bool) -> None:
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Make a License Request for PSSH to SERVER using DEVICE.
 | 
					    Make a License Request for PSSH to SERVER using DEVICE.
 | 
				
			||||||
    It will return a list of all keys within the returned license.
 | 
					    It will return a list of all keys within the returned license.
 | 
				
			||||||
@ -65,11 +63,8 @@ def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool):
 | 
				
			|||||||
    """
 | 
					    """
 | 
				
			||||||
    log = logging.getLogger("license")
 | 
					    log = logging.getLogger("license")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # prepare pssh
 | 
					 | 
				
			||||||
    pssh = PSSH(pssh)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # load device
 | 
					    # load device
 | 
				
			||||||
    device = Device.load(device)
 | 
					    device = Device.load(device_path)
 | 
				
			||||||
    log.info("[+] Loaded Device (%s L%s)", device.system_id, device.security_level)
 | 
					    log.info("[+] Loaded Device (%s L%s)", device.system_id, device.security_level)
 | 
				
			||||||
    log.debug(device)
 | 
					    log.debug(device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -84,18 +79,18 @@ def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    if privacy:
 | 
					    if privacy:
 | 
				
			||||||
        # get service cert for license server via cert challenge
 | 
					        # get service cert for license server via cert challenge
 | 
				
			||||||
        service_cert = requests.post(
 | 
					        service_cert_res = requests.post(
 | 
				
			||||||
            url=server,
 | 
					            url=server,
 | 
				
			||||||
            data=cdm.service_certificate_challenge
 | 
					            data=cdm.service_certificate_challenge
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        if service_cert.status_code != 200:
 | 
					        if service_cert_res.status_code != 200:
 | 
				
			||||||
            log.error(
 | 
					            log.error(
 | 
				
			||||||
                "[-] Failed to get Service Privacy Certificate: [%s] %s",
 | 
					                "[-] Failed to get Service Privacy Certificate: [%s] %s",
 | 
				
			||||||
                service_cert.status_code,
 | 
					                service_cert_res.status_code,
 | 
				
			||||||
                service_cert.text
 | 
					                service_cert_res.text
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
        service_cert = service_cert.content
 | 
					        service_cert = service_cert_res.content
 | 
				
			||||||
        provider_id = cdm.set_service_certificate(session_id, service_cert)
 | 
					        provider_id = cdm.set_service_certificate(session_id, service_cert)
 | 
				
			||||||
        log.info("[+] Set Service Privacy Certificate: %s", provider_id)
 | 
					        log.info("[+] Set Service Privacy Certificate: %s", provider_id)
 | 
				
			||||||
        log.debug(service_cert)
 | 
					        log.debug(service_cert)
 | 
				
			||||||
@ -107,14 +102,14 @@ def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool):
 | 
				
			|||||||
    log.debug(challenge)
 | 
					    log.debug(challenge)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # send license challenge
 | 
					    # send license challenge
 | 
				
			||||||
    licence = requests.post(
 | 
					    license_res = requests.post(
 | 
				
			||||||
        url=server,
 | 
					        url=server,
 | 
				
			||||||
        data=challenge
 | 
					        data=challenge
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    if licence.status_code != 200:
 | 
					    if license_res.status_code != 200:
 | 
				
			||||||
        log.error("[-] Failed to send challenge: [%s] %s", licence.status_code, licence.text)
 | 
					        log.error("[-] Failed to send challenge: [%s] %s", license_res.status_code, license_res.text)
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
    licence = licence.content
 | 
					    licence = license_res.content
 | 
				
			||||||
    log.info("[+] Got License Message")
 | 
					    log.info("[+] Got License Message")
 | 
				
			||||||
    log.debug(licence)
 | 
					    log.debug(licence)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -135,7 +130,7 @@ def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool):
 | 
				
			|||||||
@click.option("-p", "--privacy", is_flag=True, default=False,
 | 
					@click.option("-p", "--privacy", is_flag=True, default=False,
 | 
				
			||||||
              help="Use Privacy Mode, off by default.")
 | 
					              help="Use Privacy Mode, off by default.")
 | 
				
			||||||
@click.pass_context
 | 
					@click.pass_context
 | 
				
			||||||
def test(ctx: click.Context, device: Path, privacy: bool):
 | 
					def test(ctx: click.Context, device: Path, privacy: bool) -> None:
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Test the CDM code by getting Content Keys for Bitmovin's Art of Motion example.
 | 
					    Test the CDM code by getting Content Keys for Bitmovin's Art of Motion example.
 | 
				
			||||||
    https://bitmovin.com/demos/drm
 | 
					    https://bitmovin.com/demos/drm
 | 
				
			||||||
@ -161,7 +156,7 @@ def test(ctx: click.Context, device: Path, privacy: bool):
 | 
				
			|||||||
    # it will print information as it goes to the terminal
 | 
					    # it will print information as it goes to the terminal
 | 
				
			||||||
    ctx.invoke(
 | 
					    ctx.invoke(
 | 
				
			||||||
        license_,
 | 
					        license_,
 | 
				
			||||||
        device=device,
 | 
					        device_path=device,
 | 
				
			||||||
        pssh=pssh,
 | 
					        pssh=pssh,
 | 
				
			||||||
        server=license_server,
 | 
					        server=license_server,
 | 
				
			||||||
        type_=LicenseType.Name(license_type),
 | 
					        type_=LicenseType.Name(license_type),
 | 
				
			||||||
@ -382,10 +377,10 @@ def migrate(ctx: click.Context, path: Path) -> None:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@main.command("serve", short_help="Serve your local CDM and Widevine Devices Remotely.")
 | 
					@main.command("serve", short_help="Serve your local CDM and Widevine Devices Remotely.")
 | 
				
			||||||
@click.argument("config", type=Path)
 | 
					@click.argument("config_path", type=Path)
 | 
				
			||||||
@click.option("-h", "--host", type=str, default="127.0.0.1", help="Host to serve from.")
 | 
					@click.option("-h", "--host", type=str, default="127.0.0.1", help="Host to serve from.")
 | 
				
			||||||
@click.option("-p", "--port", type=int, default=8786, help="Port to serve from.")
 | 
					@click.option("-p", "--port", type=int, default=8786, help="Port to serve from.")
 | 
				
			||||||
def serve_(config: Path, host: str, port: int):
 | 
					def serve_(config_path: Path, host: str, port: int) -> None:
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Serve your local CDM and Widevine Devices Remotely.
 | 
					    Serve your local CDM and Widevine Devices Remotely.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -400,5 +395,5 @@ def serve_(config: Path, host: str, port: int):
 | 
				
			|||||||
    from pywidevine import serve  # isort:skip
 | 
					    from pywidevine import serve  # isort:skip
 | 
				
			||||||
    import yaml  # isort:skip
 | 
					    import yaml  # isort:skip
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    config = yaml.safe_load(config.read_text(encoding="utf8"))
 | 
					    config = yaml.safe_load(config_path.read_text(encoding="utf8"))
 | 
				
			||||||
    serve.run(config, host, port)
 | 
					    serve.run(config, host, port)
 | 
				
			||||||
 | 
				
			|||||||
@ -82,17 +82,17 @@ class PSSH:
 | 
				
			|||||||
                box = Box.parse(data)
 | 
					                box = Box.parse(data)
 | 
				
			||||||
            except (IOError, construct.ConstructError):  # not a box
 | 
					            except (IOError, construct.ConstructError):  # not a box
 | 
				
			||||||
                try:
 | 
					                try:
 | 
				
			||||||
                    cenc_header = WidevinePsshData()
 | 
					                    widevine_pssh_data = WidevinePsshData()
 | 
				
			||||||
                    cenc_header.ParseFromString(data)
 | 
					                    widevine_pssh_data.ParseFromString(data)
 | 
				
			||||||
                    cenc_header = cenc_header.SerializeToString()
 | 
					                    data_serialized = widevine_pssh_data.SerializeToString()
 | 
				
			||||||
                    if cenc_header != data:  # not actually a WidevinePsshData
 | 
					                    if data_serialized != data:  # not actually a WidevinePsshData
 | 
				
			||||||
                        raise DecodeError()
 | 
					                        raise DecodeError()
 | 
				
			||||||
                    box = Box.parse(Box.build(dict(
 | 
					                    box = Box.parse(Box.build(dict(
 | 
				
			||||||
                        type=b"pssh",
 | 
					                        type=b"pssh",
 | 
				
			||||||
                        version=0,
 | 
					                        version=0,
 | 
				
			||||||
                        flags=0,
 | 
					                        flags=0,
 | 
				
			||||||
                        system_ID=PSSH.SystemId.Widevine,
 | 
					                        system_ID=PSSH.SystemId.Widevine,
 | 
				
			||||||
                        init_data=cenc_header
 | 
					                        init_data=data_serialized
 | 
				
			||||||
                    )))
 | 
					                    )))
 | 
				
			||||||
                except DecodeError:  # not a widevine cenc header
 | 
					                except DecodeError:  # not a widevine cenc header
 | 
				
			||||||
                    if "</WRMHEADER>".encode("utf-16-le") in data:
 | 
					                    if "</WRMHEADER>".encode("utf-16-le") in data:
 | 
				
			||||||
@ -307,16 +307,16 @@ class PSSH:
 | 
				
			|||||||
        if self.system_id == PSSH.SystemId.Widevine:
 | 
					        if self.system_id == PSSH.SystemId.Widevine:
 | 
				
			||||||
            raise ValueError("This is already a Widevine PSSH")
 | 
					            raise ValueError("This is already a Widevine PSSH")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        cenc_header = WidevinePsshData()
 | 
					        widevine_pssh_data = WidevinePsshData()
 | 
				
			||||||
        cenc_header.algorithm = 1  # 0=Clear, 1=AES-CTR
 | 
					        widevine_pssh_data.algorithm = WidevinePsshData.Algorithm.Value("AESCTR")
 | 
				
			||||||
        cenc_header.key_ids[:] = [x.bytes for x in self.key_ids]
 | 
					        widevine_pssh_data.key_ids[:] = [x.bytes for x in self.key_ids]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if self.version == 1:
 | 
					        if self.version == 1:
 | 
				
			||||||
            # ensure both cenc header and box has same Key IDs
 | 
					            # ensure both cenc header and box has same Key IDs
 | 
				
			||||||
            # v1 uses both this and within init data for basically no reason
 | 
					            # v1 uses both this and within init data for basically no reason
 | 
				
			||||||
            self.__key_ids = self.key_ids
 | 
					            self.__key_ids = self.key_ids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.init_data = cenc_header.SerializeToString()
 | 
					        self.init_data = widevine_pssh_data.SerializeToString()
 | 
				
			||||||
        self.system_id = PSSH.SystemId.Widevine
 | 
					        self.system_id = PSSH.SystemId.Widevine
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def to_playready(
 | 
					    def to_playready(
 | 
				
			||||||
 | 
				
			|||||||
@ -86,10 +86,10 @@ class RemoteCdm(Cdm):
 | 
				
			|||||||
        server = r.headers.get("Server")
 | 
					        server = r.headers.get("Server")
 | 
				
			||||||
        if not server or "pywidevine serve" not in server.lower():
 | 
					        if not server or "pywidevine serve" not in server.lower():
 | 
				
			||||||
            raise ValueError(f"This Remote CDM API does not seem to be a pywidevine serve API ({server}).")
 | 
					            raise ValueError(f"This Remote CDM API does not seem to be a pywidevine serve API ({server}).")
 | 
				
			||||||
        server_version = re.search(r"pywidevine serve v([\d.]+)", server, re.IGNORECASE)
 | 
					        server_version_re = re.search(r"pywidevine serve v([\d.]+)", server, re.IGNORECASE)
 | 
				
			||||||
        if not server_version:
 | 
					        if not server_version_re:
 | 
				
			||||||
            raise ValueError("The pywidevine server API is not stating the version correctly, cannot continue.")
 | 
					            raise ValueError("The pywidevine server API is not stating the version correctly, cannot continue.")
 | 
				
			||||||
        server_version = server_version.group(1)
 | 
					        server_version = server_version_re.group(1)
 | 
				
			||||||
        if server_version < "1.4.3":
 | 
					        if server_version < "1.4.3":
 | 
				
			||||||
            raise ValueError(f"This pywidevine serve API version ({server_version}) is not supported.")
 | 
					            raise ValueError(f"This pywidevine serve API version ({server_version}) is not supported.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -1,8 +1,9 @@
 | 
				
			|||||||
import base64
 | 
					import base64
 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
from pathlib import Path
 | 
					from pathlib import Path
 | 
				
			||||||
from typing import Optional, Union
 | 
					from typing import Any, Optional, Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from aiohttp.typedefs import Handler
 | 
				
			||||||
from google.protobuf.message import DecodeError
 | 
					from google.protobuf.message import DecodeError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from pywidevine.pssh import PSSH
 | 
					from pywidevine.pssh import PSSH
 | 
				
			||||||
@ -26,8 +27,8 @@ from pywidevine.exceptions import (InvalidContext, InvalidInitData, InvalidLicen
 | 
				
			|||||||
routes = web.RouteTableDef()
 | 
					routes = web.RouteTableDef()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def _startup(app: web.Application):
 | 
					async def _startup(app: web.Application) -> None:
 | 
				
			||||||
    app["cdms"]: dict[tuple[str, str], Cdm] = {}
 | 
					    app["cdms"] = {}
 | 
				
			||||||
    app["config"]["devices"] = {
 | 
					    app["config"]["devices"] = {
 | 
				
			||||||
        path.stem: path
 | 
					        path.stem: path
 | 
				
			||||||
        for x in app["config"]["devices"]
 | 
					        for x in app["config"]["devices"]
 | 
				
			||||||
@ -38,7 +39,7 @@ async def _startup(app: web.Application):
 | 
				
			|||||||
            raise FileNotFoundError(f"Device file does not exist: {device}")
 | 
					            raise FileNotFoundError(f"Device file does not exist: {device}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def _cleanup(app: web.Application):
 | 
					async def _cleanup(app: web.Application) -> None:
 | 
				
			||||||
    app["cdms"].clear()
 | 
					    app["cdms"].clear()
 | 
				
			||||||
    del app["cdms"]
 | 
					    del app["cdms"]
 | 
				
			||||||
    app["config"].clear()
 | 
					    app["config"].clear()
 | 
				
			||||||
@ -46,7 +47,7 @@ async def _cleanup(app: web.Application):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@routes.get("/")
 | 
					@routes.get("/")
 | 
				
			||||||
async def ping(_) -> web.Response:
 | 
					async def ping(_: Any) -> web.Response:
 | 
				
			||||||
    return web.json_response({
 | 
					    return web.json_response({
 | 
				
			||||||
        "status": 200,
 | 
					        "status": 200,
 | 
				
			||||||
        "message": "Pong!"
 | 
					        "message": "Pong!"
 | 
				
			||||||
@ -211,13 +212,15 @@ async def get_service_certificate(request: web.Request) -> web.Response:
 | 
				
			|||||||
        }, status=400)
 | 
					        }, status=400)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if service_certificate:
 | 
					    if service_certificate:
 | 
				
			||||||
        service_certificate = base64.b64encode(service_certificate.SerializeToString()).decode()
 | 
					        service_certificate_b64 = base64.b64encode(service_certificate.SerializeToString()).decode()
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        service_certificate_b64 = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return web.json_response({
 | 
					    return web.json_response({
 | 
				
			||||||
        "status": 200,
 | 
					        "status": 200,
 | 
				
			||||||
        "message": "Successfully got the Service Certificate.",
 | 
					        "message": "Successfully got the Service Certificate.",
 | 
				
			||||||
        "data": {
 | 
					        "data": {
 | 
				
			||||||
            "service_certificate": service_certificate
 | 
					            "service_certificate": service_certificate_b64
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    })
 | 
					    })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -366,7 +369,7 @@ async def get_keys(request: web.Request) -> web.Response:
 | 
				
			|||||||
    session_id = bytes.fromhex(body["session_id"])
 | 
					    session_id = bytes.fromhex(body["session_id"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # get key type
 | 
					    # get key type
 | 
				
			||||||
    key_type = request.match_info["key_type"]
 | 
					    key_type: Optional[str] = request.match_info["key_type"]
 | 
				
			||||||
    if key_type == "ALL":
 | 
					    if key_type == "ALL":
 | 
				
			||||||
        key_type = None
 | 
					        key_type = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -414,26 +417,24 @@ async def get_keys(request: web.Request) -> web.Response:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@web.middleware
 | 
					@web.middleware
 | 
				
			||||||
async def authentication(request: web.Request, handler) -> web.Response:
 | 
					async def authentication(request: web.Request, handler: Handler) -> web.Response:
 | 
				
			||||||
    response = None
 | 
					 | 
				
			||||||
    if request.path != "/":
 | 
					 | 
				
			||||||
    secret_key = request.headers.get("X-Secret-Key")
 | 
					    secret_key = request.headers.get("X-Secret-Key")
 | 
				
			||||||
        if not secret_key:
 | 
					
 | 
				
			||||||
 | 
					    if request.path != "/" and not secret_key:
 | 
				
			||||||
        request.app.logger.debug(f"{request.remote} did not provide authorization.")
 | 
					        request.app.logger.debug(f"{request.remote} did not provide authorization.")
 | 
				
			||||||
        response = web.json_response({
 | 
					        response = web.json_response({
 | 
				
			||||||
            "status": "401",
 | 
					            "status": "401",
 | 
				
			||||||
            "message": "Secret Key is Empty."
 | 
					            "message": "Secret Key is Empty."
 | 
				
			||||||
        }, status=401)
 | 
					        }, status=401)
 | 
				
			||||||
        elif secret_key not in request.app["config"]["users"]:
 | 
					    elif request.path != "/" and secret_key not in request.app["config"]["users"]:
 | 
				
			||||||
        request.app.logger.debug(f"{request.remote} failed authentication with '{secret_key}'.")
 | 
					        request.app.logger.debug(f"{request.remote} failed authentication with '{secret_key}'.")
 | 
				
			||||||
        response = web.json_response({
 | 
					        response = web.json_response({
 | 
				
			||||||
            "status": "401",
 | 
					            "status": "401",
 | 
				
			||||||
            "message": "Secret Key is Invalid, the Key is case-sensitive."
 | 
					            "message": "Secret Key is Invalid, the Key is case-sensitive."
 | 
				
			||||||
        }, status=401)
 | 
					        }, status=401)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
    if response is None:
 | 
					 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            response = await handler(request)
 | 
					            response = await handler(request)  # type: ignore[assignment]
 | 
				
			||||||
        except web.HTTPException as e:
 | 
					        except web.HTTPException as e:
 | 
				
			||||||
            request.app.logger.error(f"An unexpected error has occurred, {e}")
 | 
					            request.app.logger.error(f"An unexpected error has occurred, {e}")
 | 
				
			||||||
            response = web.json_response({
 | 
					            response = web.json_response({
 | 
				
			||||||
@ -448,7 +449,7 @@ async def authentication(request: web.Request, handler) -> web.Response:
 | 
				
			|||||||
    return response
 | 
					    return response
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def run(config: dict, host: Optional[Union[str, web.HostSequence]] = None, port: Optional[int] = None):
 | 
					def run(config: dict, host: Optional[Union[str, web.HostSequence]] = None, port: Optional[int] = None) -> None:
 | 
				
			||||||
    app = web.Application(middlewares=[authentication])
 | 
					    app = web.Application(middlewares=[authentication])
 | 
				
			||||||
    app.on_startup.append(_startup)
 | 
					    app.on_startup.append(_startup)
 | 
				
			||||||
    app.on_cleanup.append(_cleanup)
 | 
					    app.on_cleanup.append(_cleanup)
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user