ml28xx/oki.py
2026-02-22 00:52:27 +07:00

247 lines
9.4 KiB
Python

from dataclasses import dataclass
from typing_extensions import override
from construct_typed import DataclassMixin, DataclassStruct, csfield
from typing import IO, Any, Self
from construct import *
from tap import Tap
class _EnhancedDataclassMixin(DataclassMixin):
    """``DataclassMixin`` with class-level build/parse shortcuts.

    Every shortcut wraps the class in a ``DataclassStruct`` (see ``format``)
    and forwards to the struct method of the same name, passing any extra
    keyword arguments through unchanged.
    """

    @classmethod
    def format(cls):
        """Return a ``DataclassStruct`` wrapping this class."""
        return DataclassStruct(cls)

    @classmethod
    def build(cls, obj: Self, **kw: Any): # pyright: ignore[reportExplicitAny, reportAny]
        """Build ``obj`` with this class's struct and return the result."""
        struct = cls.format()
        return struct.build(obj, **kw)

    @classmethod
    def parse(cls, data: bytes|bytearray, **kw: Any):# pyright: ignore[reportExplicitAny, reportAny]
        """Parse an in-memory buffer with this class's struct."""
        struct = cls.format()
        return struct.parse(data, **kw)

    @classmethod
    def parse_file(cls, file: str, **kw: Any):# pyright: ignore[reportExplicitAny, reportAny]
        """Parse the file named ``file`` with this class's struct."""
        struct = cls.format()
        return struct.parse_file(file, **kw)

    @classmethod
    def parse_stream(cls, stream: IO[bytes], **kw: Any):# pyright: ignore[reportExplicitAny, reportAny]
        """Parse from a binary stream with this class's struct."""
        struct = cls.format()
        return struct.parse_stream(stream, **kw)
# Wave file
@dataclass
class WAV(_EnhancedDataclassMixin):
    """RIFF container holding a single WAVE form.

    Layout: the constant ``RIFF`` tag, then a 32-bit little-endian size
    prefix, then the ``WAVE`` body (constant ``WAVE`` tag followed by as many
    chunks as fit in the prefixed region).
    """

    @dataclass
    class WAVE(_EnhancedDataclassMixin):
        """Body of the RIFF container: the ``WAVE`` tag plus its chunks."""

        @dataclass
        class FMT(_EnhancedDataclassMixin):
            """Payload of the ``fmt `` chunk.

            ``bytes_rate`` and ``block_align`` are recomputed on build from
            ``sample_rate``, ``channels`` and ``bits_per_samples``.
            """

            audio_fmt: int = csfield(Hex(Int16ul))
            channels: int = csfield(Int16ul)
            sample_rate: int = csfield(Int32ul)
            bytes_rate: int = csfield(Rebuild(Int32ul, this.sample_rate * this.channels * (this.bits_per_samples // 8)))
            block_align: int = csfield(Rebuild(Int16ul, this.channels * (this.bits_per_samples // 8)))
            bits_per_samples: int = csfield(Int16ul)

        @dataclass
        class Chunk(_EnhancedDataclassMixin):
            """One RIFF chunk: 4-character ASCII tag plus size-prefixed payload."""

            type: str = csfield(PaddedString(4, "ascii"))
            # RIFF chunks are word-aligned: a payload of odd length is
            # followed by one zero pad byte that is NOT counted in the size
            # field. The size prefix is 4 bytes, so aligning prefix + payload
            # to 2 bytes yields exactly that pad. Without it, an odd-sized
            # payload (e.g. 8-bit mono with an odd sample count) produces a
            # malformed file and shifts every following chunk when parsing.
            # NOTE(review): when parsing, the pad byte is now expected to be
            # present after an odd-sized chunk, including the last one.
            data: bytes = csfield(Aligned(2, Prefixed(Int32ul, GreedyBytes)))

        magic: bytes = csfield(Const(b"WAVE"))
        chunks: list[Chunk] = csfield(GreedyRange(Chunk.format()))

    magic: bytes = csfield(Const(b"RIFF"))
    wave: WAVE = csfield(Prefixed(Int32ul, WAVE.format()))
# ADPCM2 parameters
# Step sizes, indexed by the per-channel step index.
STEP_TABLES = [
    16, 17, 18, 20, 21, 23, 25, 27, 29, 31,
    34, 37, 40, 43, 46, 50, 54, 59, 63, 69,
    74, 80, 86, 93, 101, 109, 118, 127, 138, 149,
    161, 173, 187, 202, 219, 236, 255, 275, 298, 321,
    347, 375, 405, 437, 472, 510, 551, 595, 643, 694,
    750, 810, 875, 945, 1020, 1102, 1190, 1286, 1388, 1500,
    1620, 1749, 1889, 2040, 2204, 2380, 2570, 2776, 2998, 3238,
    3497, 3777, 4079, 4406, 4758, 5139, 5550, 5994, 6474, 6991,
    7551, 8155, 8807, 9512, 10273, 11095, 11982, 12941, 13976, 15095,
    16302,
]
# Step-index adjustment, indexed by the magnitude bits of a 4-bit code.
STEP_INDEX = [-2, -2, -2, -2, 2, 6, 9, 11]
# Step-index adjustment, indexed by the magnitude bit of a 2-bit code.
STEP_INDEX_2BIT = [-2, 3]


# ADPCM2 decoder (used in ML2871)
class OKIAdpcm2Decode():
    """Stateful ADPCM2 decoder emitting signed 16-bit little-endian PCM.

    ``delta`` holds the last output sample per channel and ``step_index`` the
    current position in ``STEP_TABLES`` per channel. Both persist between
    ``decode`` calls. Successive codes are assigned to channels round-robin.
    """

    def __init__(self, bits: int=4, channels: int=1):
        """Create a decoder for ``bits``-wide codes (2 or 4) and 1 or 2 channels."""
        assert bits in [2, 4]
        assert channels in [1, 2]
        self.__decode_bits = bits
        self.__channels = channels
        self.__decoder_idx = 0
        self.delta: list[int] = [0] * channels
        self.step_index: list[int] = [0] * channels

    def __expand_sample(self, nibble: int):
        """Decode one code for the current channel, then move to the next channel."""
        width = self.__decode_bits
        ch = self.__decoder_idx
        sign_bit = 1 << (width - 1)
        magnitude_bits = nibble & (sign_bit - 1)
        step = STEP_TABLES[self.step_index[ch]]

        if width == 4:
            # step/8 is always added; bits 0..2 add step/4, step/2 and step.
            amount = step >> 3
            for bit, shift in ((1, 2), (2, 1), (4, 0)):
                if nibble & bit:
                    amount += step >> shift
            index_moves = STEP_INDEX
        else:
            # step/2 is always added; bit 0 adds a full step.
            amount = (step >> 1) + (step if nibble & 1 else 0)
            index_moves = STEP_INDEX_2BIT

        if nibble & sign_bit:
            amount = -amount

        moved = self.step_index[ch] + index_moves[magnitude_bits]
        self.step_index[ch] = max(0, min(moved, len(STEP_TABLES) - 1))

        # Accumulate and clamp to the signed 16-bit range.
        sample = max(-32768, min(self.delta[ch] + amount, 32767))
        self.delta[ch] = sample
        self.__decoder_idx = (ch + 1) % self.__channels
        return sample

    def decode(self, data: bytes):
        """Decode ``data`` and return the samples as 16-bit little-endian PCM.

        Codes are taken from each byte starting with the most significant
        bits: four codes per byte in 2-bit mode, otherwise two.
        """
        if self.__decode_bits == 2:
            shifts, code_mask = (6, 4, 2, 0), 0x3
        else:
            shifts, code_mask = (4, 0), 0xf
        pcm = bytearray()
        for byte in data:
            for shift in shifts:
                sample = self.__expand_sample((byte >> shift) & code_mask)
                pcm += sample.to_bytes(2, "little", signed=True)
        return bytes(pcm)
# ADPCM1 parameters
# Step sizes, indexed by the per-channel step index.
STEP_TABLES_VOX = [
    16, 17, 19, 21, 23, 25, 28, 31, 34, 37,
    41, 45, 50, 55, 60, 66, 73, 80, 88, 97,
    107, 118, 130, 143, 157, 173, 190, 209, 230, 253,
    279, 307, 337, 371, 408, 449, 494, 544, 598, 658,
    724, 796, 876, 963, 1060, 1166, 1282, 1411, 1552,
]
# Step-index adjustment, indexed by the magnitude bits of a code.
STEP_INDEX_VOX = [-1, -1, -1, -1, 2, 4, 6, 8]


# ADPCM1 decoder (VOX)
class OKIAdpcm1Decode():
    """Stateful ADPCM1 (VOX) decoder emitting signed 16-bit little-endian PCM.

    ``delta`` holds the running sample per channel, clamped to -2048..2047,
    and ``step_index`` the current position in ``STEP_TABLES_VOX`` per
    channel. Both persist between ``decode`` calls. Successive codes are
    assigned to channels round-robin.
    """

    def __init__(self, bits: int=4, channels: int=1):
        """Create a decoder; only 4-bit codes and 1 or 2 channels are accepted."""
        assert bits in [4]
        assert channels in [1, 2]
        self.__decode_bits = bits
        self.__channels = channels
        self.__decoder_idx = 0
        self.delta: list[int] = [0] * channels
        self.step_index: list[int] = [0] * channels

    def __expand_sample(self, nibble: int):
        """Decode one code for the current channel, then move to the next channel."""
        ch = self.__decoder_idx
        sign_bit = 1 << (self.__decode_bits - 1)
        magnitude_bits = nibble & (sign_bit - 1)
        step = STEP_TABLES_VOX[self.step_index[ch]]

        # Shift form: step/8 always, plus step/4, step/2 and step for bits 0..2.
        # NOTE: sox, libsndfile and FFmpeg decode VOX with the IMA multiply
        # form, ((((nibble & 7) * 2) + 1) * step) >> 3, while the official
        # implementation uses the IMA shift form implemented here.
        amount = step >> 3
        for bit, shift in ((1, 2), (2, 1), (4, 0)):
            if nibble & bit:
                amount += step >> shift
        if nibble & sign_bit:
            amount = -amount

        moved = self.step_index[ch] + STEP_INDEX_VOX[magnitude_bits]
        self.step_index[ch] = max(0, min(moved, len(STEP_TABLES_VOX) - 1))

        # Accumulate and clamp to -2048..2047.
        sample = max(-2048, min(self.delta[ch] + amount, 2047))
        self.delta[ch] = sample
        self.__decoder_idx = (ch + 1) % self.__channels
        # Scale up by 16 for the 16-bit output.
        return sample << 4

    def decode(self, data: bytes):
        """Decode ``data`` (two codes per byte, high nibble first) to 16-bit little-endian PCM."""
        pcm = bytearray()
        for byte in data:
            for code in (byte >> 4, byte & 0xf):
                pcm += self.__expand_sample(code).to_bytes(2, "little", signed=True)
        return bytes(pcm)
# OKI ADP (Audi) format
@dataclass
class OkiADP(_EnhancedDataclassMixin):
    """OKI ADP ("Audi") file: a fixed header followed by size-prefixed audio data.

    Multi-byte header fields are big-endian. ``decode`` converts the stored
    audio data to PCM bytes.
    """

    magic: bytes = csfield(Const(b"Audi"))
    type: int = csfield(Hex(Int8ub))
    codec: int = csfield(Hex(Int8ub))
    channels: int = csfield(Hex(Int8ub))
    bits_per_samples: int = csfield(Hex(Int8ub))
    sample_rate: int = csfield(Int16ub)
    # Two padding bytes between the header fields and the audio data.
    _pad: None = csfield(Padding(2))
    # Raw audio payload, prefixed by its 32-bit big-endian length.
    _audio_data: bytes = csfield(Prefixed(Int32ub, GreedyBytes))

    @staticmethod
    def __byswap(data: bytes):
        """Swap the two bytes of every 16-bit word; a trailing odd byte is dropped."""
        even_len = (len(data) >> 1) << 1
        swapped = bytearray(even_len)
        swapped[0::2] = data[1:even_len:2]
        swapped[1::2] = data[0:even_len:2]
        return bytes(swapped)

    def decode(self):
        """Return the audio data converted to PCM bytes.

        * ``type`` 0 with ``codec`` 0 (LPCM, signed): 4-bit data goes through
          ``OKIAdpcm1Decode``, 16-bit data is byte-swapped, and 8-bit data has
          the top bit of every byte flipped (presumably signed to unsigned).
        * Everything else goes through ``OKIAdpcm2Decode``.

        Raises ``AssertionError`` for an unsupported ``type`` or an unsupported
        codec / bits-per-sample combination.
        """
        assert self.type in [0x0, 0x80], "Only PCM and ML28xx ADP files were supported"
        bits = self.bits_per_samples
        payload = self._audio_data

        is_lpcm = self.type == 0x00 and self.codec == 0x00
        if not is_lpcm:
            assert (self.codec == 0 and bits == 2) or (self.codec == 1 and bits == 4), "Bad codec and bits per sample combinations"
            return OKIAdpcm2Decode(bits, self.channels).decode(payload)

        # LPCM (Signed)
        assert (self.codec == 0 and bits in [4, 8, 16]), "Bad codec and bits per sample combinations"
        if bits == 4:
            return OKIAdpcm1Decode(bits, self.channels).decode(payload)
        if bits == 16:
            return self.__byswap(payload)
        return bytes(sample ^ 0x80 for sample in payload)
# Command-line arguments: a required input path and an optional output path.
# NOTE(review): Tap appears to use the trailing comments on the field lines as
# --help text, so treat them as user-facing strings — confirm before editing.
class Args(Tap):
    in_file: str # Input file
    out_file: str = "" # Output file (leave blank for playback)
    @override
    def configure(self) -> None:
        # Register both fields as positional arguments. nargs="?" makes
        # out_file optional, so it keeps its "" default when omitted.
        self.add_argument("in_file")
        self.add_argument("out_file", nargs="?")
def _play(adpc: OkiADP) -> None:
    """Decode ``adpc`` and play it through PyAudio.

    8-bit files are played as unsigned 8-bit, everything else as signed
    16-bit. The stream and the PyAudio instance are released even if
    playback fails.
    """
    # Imported here so that converting to a WAV file does not require PyAudio.
    import time

    import pyaudio

    eight_bit = adpc.bits_per_samples == 8
    # Decode before opening the device so a bad file never leaves it open.
    pcm = adpc.decode()
    chunk_size = (1024 if eight_bit else 2048) * adpc.channels

    snd = pyaudio.PyAudio()
    try:
        stream = snd.open(
            format=pyaudio.paUInt8 if eight_bit else pyaudio.paInt16,
            channels=adpc.channels,
            rate=adpc.sample_rate,
            output=True,
        )
        try:
            # Slice by offset instead of repeatedly re-slicing the remainder,
            # which copied the whole tail of the buffer on every iteration.
            for start in range(0, len(pcm), chunk_size):
                stream.write(pcm[start:start + chunk_size])
            # Give the last chunk time to play before the stream is closed.
            time.sleep(0.5)
        finally:
            stream.close()
    finally:
        snd.terminate()


def _write_wav(adpc: OkiADP, path: str) -> None:
    """Decode ``adpc`` and write it to ``path`` as a PCM WAV file.

    8-bit files stay 8-bit; everything else is written as 16-bit.
    """
    sample_bits = 8 if adpc.bits_per_samples == 8 else 16
    fmt = WAV.WAVE.FMT(1, adpc.channels, adpc.sample_rate, bits_per_samples=sample_bits)
    wave_out = WAV(wave=WAV.WAVE(chunks=[]))
    wave_out.wave.chunks.append(WAV.WAVE.Chunk("fmt ", WAV.WAVE.FMT.build(fmt)))
    wave_out.wave.chunks.append(WAV.WAVE.Chunk("data", adpc.decode()))
    # Build everything first so a failure cannot leave a truncated file behind,
    # then write through a context manager so the handle is always closed.
    payload = WAV.build(wave_out)
    with open(path, "wb") as out:
        out.write(payload)


if __name__ == "__main__":
    args = Args().parse_args()
    adpc = OkiADP.parse_file(args.in_file)
    if args.out_file:
        _write_wav(adpc, args.out_file)
    else:
        _play(adpc)