diff --git a/libavcodec/jpegxl_parser.c b/libavcodec/jpegxl_parser.c index bbd7338a61..dde36b0d6e 100644 --- a/libavcodec/jpegxl_parser.c +++ b/libavcodec/jpegxl_parser.c @@ -46,6 +46,8 @@ #define JXL_FLAG_USE_LF_FRAME 32 #define JXL_FLAG_SKIP_ADAPTIVE_LF_SMOOTH 128 +#define MAX_PREFIX_ALPHABET_SIZE (1u << 15) + #define clog1p(x) (ff_log2(x) + !!(x)) #define unpack_signed(x) (((x) & 1 ? -(x)-1 : (x))/2) #define div_ceil(x, y) (((x) - 1) / (y) + 1) @@ -724,16 +726,17 @@ static int read_vlc_prefix(GetBitContext *gb, JXLEntropyDecoder *dec, JXLSymbolD if (ret < 0) goto end; - buf = av_calloc(1, 262148); // 32768 * 8 + 4 + buf = av_mallocz(dist->alphabet_size * (2 * sizeof(int8_t) + sizeof(int16_t) + sizeof(uint32_t)) + + sizeof(uint32_t)); if (!buf) { ret = AVERROR(ENOMEM); goto end; } level2_lens = (int8_t *)buf; - level2_lens_s = (int8_t *)(buf + 32768); - level2_syms = (int16_t *)(buf + 65536); - level2_codecounts = (uint32_t *)(buf + 131072); + level2_lens_s = (int8_t *)(buf + dist->alphabet_size * sizeof(int8_t)); + level2_syms = (int16_t *)(buf + dist->alphabet_size * (2 * sizeof(int8_t))); + level2_codecounts = (uint32_t *)(buf + dist->alphabet_size * (2 * sizeof(int8_t) + sizeof(int16_t))); total_code = 0; for (int i = 0; i < dist->alphabet_size; i++) { @@ -742,6 +745,10 @@ static int read_vlc_prefix(GetBitContext *gb, JXLEntropyDecoder *dec, JXLSymbolD int extra = 3 + get_bits(gb, 2); if (repeat_count_prev) extra = 4 * (repeat_count_prev - 2) - repeat_count_prev + extra; + if (i + extra > dist->alphabet_size) { + ret = AVERROR_INVALIDDATA; + goto end; + } for (int j = 0; j < extra; j++) level2_lens[i + j] = prev; total_code += (32768 >> prev) * extra; @@ -772,8 +779,10 @@ static int read_vlc_prefix(GetBitContext *gb, JXLEntropyDecoder *dec, JXLSymbolD } } - if (total_code != 32768 && level2_codecounts[0] < dist->alphabet_size - 1) - return AVERROR_INVALIDDATA; + if (total_code != 32768 && level2_codecounts[0] < dist->alphabet_size - 1) { + ret = AVERROR_INVALIDDATA; + goto end; + } for (int i = 1; i < dist->alphabet_size + 1; i++) level2_codecounts[i] += level2_codecounts[i - 1]; @@ -848,6 +857,8 @@ static int read_distribution_bundle(GetBitContext *gb, JXLEntropyDecoder *dec, if (get_bits1(gb)) { int n = get_bits(gb, 4); dist->alphabet_size = 1 + (1 << n) + get_bitsz(gb, n); + if (dist->alphabet_size > MAX_PREFIX_ALPHABET_SIZE) + return AVERROR_INVALIDDATA; } else { dist->alphabet_size = 1; }