JPEG图片编码格式分析

图片展示需要BGR模式的三维向量,图片的编码是把BGR图片编码成文件能存储的格式,解码则反之。目前常见的编码为jpg、png、gif等。新兴的如webp、heic。

BMP

从简单入手,BMP是最简单的编码方式,甚至数十行代码就能完成编码和解码简单的程序。

bmp由文件头和位图信息头组成

import struct
import numpy as np
 
BITMAP_FILE_HEADER_FMT = '<2sI4xI'
BITMAP_FILE_HEADER_SIZE = struct.calcsize(BITMAP_FILE_HEADER_FMT)
BITMAP_INFO_FMT = '<I2i2H6I'
BITMAP_INFO_SIZE = struct.calcsize(BITMAP_INFO_FMT)
 
 
class BmpHeader:
    def __init__(self):
        self.bf_type = None
        self.bf_size = 0
        self.bf_off_bits = 0
        self.bi_size = 0
        self.bi_width = 0
        self.bi_height = 0
        self.bi_planes = 1 # 颜色平面数
        self.bi_bit_count = 0
        self.bi_compression = 0
        self.bi_size_image = 0
        self.bi_x_pels_per_meter = 0
        self.bi_y_pels_per_meter = 0
        self.bi_clr_used = 0
        self.bi_clr_important = 0
 
 
class BmpDecoder:
    def __init__(self, data):
        self.__header = BmpHeader()
        self.__data = data
 
    def read_header(self):
        if self.__header.bf_type is not None:
            return self.__header
        # bmp信息头
        self.__header.bf_type, self.__header.bf_size,\
            self.__header.bf_off_bits = struct.unpack_from(BITMAP_FILE_HEADER_FMT, self.__data)
        if self.__header.bf_type != b'BM':
            return None
        # 位图信息头
        self.__header.bi_size, self.__header.bi_width, self.__header.bi_height, self.__header.bi_planes,\
            self.__header.bi_bit_count, self.__header.bi_compression, self.__header.bi_size_image,\
            self.__header.bi_x_pels_per_meter, self.__header.bi_y_pels_per_meter, self.__header.bi_clr_used,\
            self.__header.bi_clr_important = struct.unpack_from(BITMAP_INFO_FMT, self.__data, BITMAP_FILE_HEADER_SIZE)
        return self.__header
 
    def read_data(self):
        header = self.read_header()
        if header is None:
            return None
        # 目前只写了解析常见的24位或32位位图
        if header.bi_bit_count != 24 and header.bi_bit_count != 32:
            return None
        # 目前只写了RGB模式
        if header.bi_compression != 0:
            return None
        offset = header.bf_off_bits
        channel = int(header.bi_bit_count / 8)
        img = np.zeros([header.bi_height, header.bi_width, channel], np.uint8)
        y_axis = range(header.bi_height - 1, -1, -1) if header.bi_height > 0 else range(0, header.bi_height)
        for y in y_axis:
            for x in range(0, header.bi_width):
                plex = np.array(struct.unpack_from('<' + str(channel) + 'B', self.__data, offset), np.int8)
                img[y][x] = plex
                offset += channel
        return img
 
 
class BmpEncoder:
    def __init__(self, img):
        self.__img = img
 
    def write_data(self):
        image_height, image_width, channel = self.__img.shape
        # 只支持RGB或者RGBA图片
        if channel != 3 and channel != 4:
            return False
        header = BmpHeader()
        header.bf_type = b'BM'
        header.bi_bit_count = channel * 8
        header.bi_width = image_width
        header.bi_height = image_height
        header.bi_size = BITMAP_INFO_SIZE
        header.bf_off_bits = header.bi_size + BITMAP_FILE_HEADER_SIZE
        header.bf_size = header.bf_off_bits + image_height * image_width * channel
        buffer = bytearray(header.bf_size)
        # bmp信息头
        struct.pack_into(BITMAP_FILE_HEADER_FMT, buffer, 0, header.bf_type, header.bf_size, header.bf_off_bits)
        # 位图信息头
        struct.pack_into(BITMAP_INFO_FMT, buffer, BITMAP_FILE_HEADER_SIZE, header.bi_size, header.bi_width, header.bi_height,
                         header.bi_planes, header.bi_bit_count, header.bi_compression, header.bi_size_image,
                         header.bi_x_pels_per_meter, header.bi_y_pels_per_meter, header.bi_clr_used,
                         header.bi_clr_important)
        # 位图,一般都是纵坐标倒序模式
        offset = header.bf_off_bits
        for y in range(header.bi_height - 1, -1, -1):
            for x in range(header.bi_width):
                struct.pack_into('<' + str(channel) + 'B', buffer, offset, *self.__img[y][x])
                offset += channel
        return buffer

bmp图片的纵坐标是反过来的,如下图所示:

《JPEG图片编码格式分析》

 JPEG

JPEG是一种编码压缩方法,真正描述图片如何存储的是JFIF(JPEG File Interchange Format),但是普通交流中往往使用“JPEG文件”这种叫法。由于精力有限,只尝试了JPEG解码的步骤。

背景知识

DCT

离散余弦变换(discrete cosine transform),把信号从空域转换成频域,且具有较好的能量聚集。变换公式如下:

DCT:《JPEG图片编码格式分析》,其中《JPEG图片编码格式分析》

IDCT:《JPEG图片编码格式分析》,其中《JPEG图片编码格式分析》

可以阅读matlab的帮助文档离散余弦变换- MATLAB & Simulink- MathWorks 中国,或者一篇博客离散余弦变换(DCT)的来龙去脉_独孤呆博的博客-CSDN博客_二维离散余弦变换

哈夫曼编码

根据符号出现概率,使用较短的编码更频繁出现的符号。更详细的可以阅读详细图解哈夫曼Huffman编码树_无鞋童鞋的博客-CSDN博客_huffman编码树

色差信号

使用亮度和蓝色、红色的浓度偏移量描述图像信号的色彩空间,和RGB转换公式可阅读https://en.wikipedia.org/wiki/YCbCr。使用YCbCr是因为,人眼对于亮度对比的感知能力比色彩的感知能力要强,把亮度分量分离出来后,可以有针对性地使用不同的量化表、采样因子来达到不同的压缩率,且人眼感知不强。

读取JPEG文件Header

JPEG文件在制定规范时,定义文件是由marker和segment组成。marker都是以0xff开头,以非0x00结束。对应常用marker如下:

markervaluedescription
SOI0xFFD8图像开始(Start Of Scan)
APP00xFFE0存储图像参数
APP10xFFE1EXIF
APP20xFFE2
APP120xFFEC图片质量等信息
APP130xFFEDphptoshop存储的信息Photoshop Tags
SOF00xFFC0Start Of Frame,SOF0是baseline DCT
SOF20xFFC2Start Of Frame,SOF2是progressive DCT
DHT0xFFC4Define Huffman Table,定义哈夫曼编码表,可以有多个,具体重建哈夫曼树方法见下
DQT0xFFDBDefine Quantization Table,定义量化表,可以有多个。量化表能影响图片的压缩质量
DRI0xFFDDDefine Restart Interval,重置DC信号的间隔(每解码指定次MCU就重置DC信号)
SOS0xFFDAStart Of Scan
image data如果有0xFF的数据,会使用0xFF00表示,解码的时候需要注意
EOI0xFFD9End Of Image

更多marker可以参考exiftool的文档JPEG Tags

APP0

fieldsize(bytes)description
长度2包括这个字段为首的整个segment长度
标识符5图片编码方式,“JFIF\0″或者”JFXX\0“等,下面的字段均以JFIF为示例

JFIF

JFIF版本2第一个字节为主版本,第二个字节为次要版本(01 02表示1.02)
密度单位1

下列像素密度字段的单位
00:无单位;width:height像素宽高比 = Xdensity:Ydensity
01:每英寸像素(2.54厘米)
02:每厘米像素

x方向密度2水平像素密度。不得为零。
y方向密度2垂直像素密度。不得为零。
缩略图宽度1嵌入的RGB缩略图的水平像素数。可以为零。
缩略图高度1嵌入的RGB缩略图的垂直像素数。可以为零。
缩略图数据3n未压缩的24位RGB(每个颜色通道8位)光栅缩略图数据,顺序为R0、G0、B0、…Rn、Gn、Bn;其中n = Xthumbnail × Ythumbnail。

APP12

fieldsize(bytes)description
长度2包括这个字段为首的整个segment长度
标识符“Ducky”等

Ducky

fieldsize(bytes)description
tag2

0x0001:压缩质量,uint32

0x0002:评论,string

0x0003:版权,string

长度2

接下来的内容长度

内容

SOF0

fieldsize(bytes)description
长度2包括这个字段为首的整个segment长度
精度1每像素数据位数,一般为8bit
高度2图像高度
宽度2图像宽度
颜色分量数1

01:灰度图

03:YCbCr图,一般为这个

04:CMYK

颜色分量信息三色分量数*3

3字节分别为:

  1. 颜色分量ID
  2. 水平和垂直采样因子,水平为高4位,垂直为低4位
  3. 使用的量化表ID

DHT

field

size(bytes)

description

长度2包括这个字段为首的整个segment长度
表类型和表ID1

高4位为表类型,其中:0位DC,1为AC。

低4位为哈夫曼表ID

不同长度码字数量16码字长度为1-16位的时候各位的数量,详见下
各个码字对应的内容用于构建哈夫曼树

举例,若不同码字长度的记录如:

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

0151111110000000

各个长度码字对应内容如下:

bit_length

code

symbol

2000
30101
30112
31003
31014
31105
411106
5111107
61111108
711111109
81111111010
911111111011

DQT

field

size(bytes)

description

长度2包括这个字段为首的整个segment长度
精度和表ID1

高4位为精度,0为8位,1为16位。

低4位未量化表ID。

表数据64*(精度)8×8的量化表数据

DRI

field

size(bytes)

description

长度2包括这个字段为首的整个segment长度
重置DC间隔2每隔x个MCU就重置一次所有DC信号prev的值,此时读取的文件流会有RST标记,这个RST标记会随着重启次数而递增,如第一次重启标记为0xD0,第二次则为0xD1直到0xD7后会再回到0xD0如此循环

一个解码jpeg文件的示例

初始化

获取整个图片component的数量,图片宽高,构建量化表、哈夫曼表,获取到各个component对应哈夫曼表ID、量化表、垂直和水平采样因子。

采样因子

因为图片保存时候是按照YCbCr保存的,而人眼对亮度比较敏感,对于色度不太敏感,因此可以对色度浓度更小的抽样,亮度按照原样抽样。JPEG在解码时候一个解码单元为一个MCU(Minimum Coded Unit)。一个MCU的大小取决于各个component的采样因子。MCU的宽度为max(各个componet的水平采样因子) * 8,高度为max(各个component的垂直采样因子) * 8。常见的,如Y的采样因子为2×2,Cb、Cr的采样因子为1×1,则一个MCU内数据的分布则是Y1、Y2、Y3、Y4、Cb1、Cr1,6小块,每小块是8×8像素的编码单元,整个MCU大小为16×16,这个MCU的各个数据分布如下:

Y1Cb1Cr1Y2
Y3Y4

如果一个Y、Cb、Cr的采样因子均为1×1,则整个MCU大小为8×8。

《JPEG图片编码格式分析》

 采样因子比例常见有4:2:2, 4:1:1, 4:4:4。

哈夫曼表

哈夫曼表区分了之流和交流表,一般有4个哈夫曼表,分别是亮度DC,亮度AC,色度DC,色度AC表。哈夫曼表用于压缩最终的图像编码。如示例的图片的亮度DC表如下:

bit_length

code

symbol

2000
30101
30112
31003
31014
31105
411106
5111107
61111108
711111109
81111111010
911111111011

量化表

量化表一般为2个,亮度一个,色度一个。量化表用于编码时候把DCT转换后的矩阵除以这个量化表,使得高频部分尽量约等于0,这样在最终编码的时候四舍五入能把大部分高频部分都压缩成0,乐观情况下只有DC,AC全为0,只需要2bit就能表示剩余63个AC信号了!例如示例的DC量化表如下:

1611121412101614
1314181716192440
2624222224493537
2940585161605751
5655647292786468
87695556801098187
95981031041036277113
1211121001209210110399

解码一个MCU

解码MCU时,需要读取多个8×8的block。每个block均由1个DC信号和63个AC信号组成。解码时先读取DC信号,再读取AC信号。我们捋一下,对于baseline DCT编码的图片,目前我们各个颜色分量对应的ht,qt关系如下:

《JPEG图片编码格式分析》

也就是说,一个颜色分量有1个直流ht,1个交流ht,1个量化表。

解码一个block

解码DC信号

由于相邻的block之间DC信号的差异很小,DC信号使用DPCM(Differential Pulse Code Modulation)编码。读取时通过对应哈夫曼表找到对应代表符号,然后读取对应位数的数值,然后进行DPCM解码。例如例子中,第一个MCU的亮度component的4个block读取出来的值分别为-37,1,-1,-1,则最终4个block的DC信号分别为-37,-36,-37,-38。每个component之间的DCPM编码是独立的。

解码AC信号

AC信号的编码使用RLE(Run Length Encoding)编码。读取时通过对应哈夫曼表找到对应符号,对应符号的高4位未接下来有几个连续的0,低4位为这些0后面跟着的位数,然后读取对应位数的数值。特别地,当读取到0x00时,代表接下来这个block所有数字为0,当读取的值为0xf0时,代表接下来有16个0。

zig-zag解码

8×8矩阵在编码时并不是按顺序读取,因此解码的时候也需要还原,其读取顺序如下:

《JPEG图片编码格式分析》

zig-zag编码原因是,大部分能量(振幅)都集中在左上角(低频区域),为了编码AC信号时能更高的连续性,AC信号是RLE编码的,因此同一个数值连续数量越多压缩率越大。高频数据往往为0,而AC信号解码时遇到0则表示后续全部为0。

反量化

经过zigzag解码后,此时block的矩阵长这样:

-370000000
00000000
00000000
00000000
00000000
00000000
00000000
00000000

需要和量化表点乘一下,最终结果长这样:

-5920000000
00000000
00000000
00000000
00000000
00000000
00000000
00000000

只有-592一个直流信号,可见这个8×8的block在DCT变换前,64个像素颜色一样(也可能存在一点点差别,但是在量化时候四舍五入走了)。

IDCT

把block从频域转换为空域,计算时往往使用矩阵叉乘加快速度。结果补码转成原码,并把符号位去掉,最终结果:

54545454545454-54
5454545454545454
5454545454545454
5454545454545454
5454545454545454
5454545454545454
5454545454545454
5454545454545454

解码JPEG文件示例

import abc
import itertools
import struct
import sys
import typing
import cv2
import numpy as np
from scipy import fft, ndimage
 
FRAME_SOF0 = b'\xc0'  # Start Of Frame N
FRAME_SOF1 = b'\xc1'  # N indicates which compression process
FRAME_SOF2 = b'\xc2'  # Only SOF0-SOF2 are now in common use
FRAME_SOF3 = b'\xc3'
FRAME_SOF5 = b'\xc5'  # NB: codes C4 and CC are NOT SOF markers
FRAME_SOF6 = b'\xc6'
FRAME_SOF7 = b'\xc7'
FRAME_SOF8 = b'\xc8'
FRAME_SOF9 = b'\xc9'
FRAME_SOF10 = b'\xca'
FRAME_SOF11 = b'\xcb'
FRAME_SOF13 = b'\xcd'
FRAME_SOF14 = b'\xce'
FRAME_SOF15 = b'\xcf'
FRAME_SOI = b'\xd8'
FRAME_EOI = b'\xd9'  # End Of Image (end of datastream)
FRAME_SOS = b'\xda'  # Start Of Scan (begins compressed data)
FRAME_APP0 = b'\xe0'
FRAME_APP1 = b'\xe1'
FRAME_APP2 = b'\xe2'
FRAME_APP3 = b'\xe3'
FRAME_APP4 = b'\xe4'
FRAME_APP5 = b'\xe5'
FRAME_APP6 = b'\xe6'
FRAME_APP7 = b'\xe7'
FRAME_APP8 = b'\xe8'
FRAME_APP9 = b'\xe9'
FRAME_APP10 = b'\xea'
FRAME_APP11 = b'\xeb'
FRAME_APP12 = b'\xec'
FRAME_APP13 = b'\xed'
FRAME_APP14 = b'\xee'
FRAME_APP15 = b'\xef'
FRAME_DQT = b'\xdb'
FRAME_DRI = b'\xdd'
FRAME_DHT = b'\xc4'
 
ZIGZAGINVERSE = np.array([[0, 1, 5, 6, 14, 15, 27, 28],
                          [2, 4, 7, 13, 16, 26, 29, 42],
                          [3, 8, 12, 17, 25, 30, 41, 43],
                          [9, 11, 18, 24, 31, 40, 44, 53],
                          [10, 19, 23, 32, 39, 45, 52, 54],
                          [20, 22, 33, 38, 46, 51, 55, 60],
                          [21, 34, 37, 47, 50, 56, 59, 61],
                          [35, 36, 48, 49, 57, 58, 62, 63]])
ZIGZAGFLATINVERSE = ZIGZAGINVERSE.flatten()
ZIGZAGFLAT = np.argsort(ZIGZAGFLATINVERSE)
 
 
def zigzag_single(block):
    return block.flatten()[ZIGZAGFLAT].reshape([8, 8])
 
 
def inverse_zigzag_single(array):
    return array.flatten()[ZIGZAGFLATINVERSE].reshape([8, 8])
 
 
def decode_number(bit_length, bits):
    """
    补码转原码
    :param bit_length:
    :param bits:
    :return:
    """
    b = 2 ** (bit_length - 1)
    if bits >= b:
        return bits
    else:
        return bits - (2 * b - 1)
 
 
class BitStream:
    def __init__(self, data):
        self.data = data
        self.pos = 0
 
    def get_bit(self):
        byte = self.data[self.pos >> 3]
        offset = 7 - (self.pos & 0x07)
        self.pos += 1
        return (byte >> offset) & 0x01
 
    def get_bits(self, n):
        val = 0
        for _ in range(n):
            val = val << 1 | self.get_bit()
        return val
 
 
class HuffmanCodec:
    def __init__(self):
        self._root = []
        self._symbol = []
 
    def _travel_root(self, root, code, bit_length, result):
        if len(root) == 0:
            return
        if isinstance(root[0], list):
            self._travel_root(root[0], code << 1, bit_length + 1, result)
        else:
            result.append((code << 1, bit_length, root[0]))
        if len(root) == 2:
            if isinstance(root[1], list):
                self._travel_root(root[1], code << 1 | 1, bit_length + 1, result)
            else:
                result.append((code << 1 | 1, bit_length, root[1]))
 
    def print_code_table(self, out=sys.stdout):
        code = 0
        result = []
        self._travel_root(self._root, code, 1, result)
        columns = list(zip(*itertools.chain(
            [('Bits', 'Code', 'Value', 'Symbol')],
            ((str(v[1]), bin(v[0])[2:].rjust(v[1], '0'), str(v[0]), repr(v[2])) for v in result)
        )))
        widths = tuple(max(len(s) for s in col) for col in columns)
        template = '{0:>%d} {1:%d} {2:>%d} {3}\n' % widths[:3]
        for row in zip(*columns):
            out.write(template.format(*row))
 
    def _add_symbol(self, root, symbol, pos):
        if isinstance(root, list):
            if pos == 0:
                if len(root) < 2:
                    root.append(symbol)
                    return True
                return False
            for i in [0, 1]:
                if len(root) == i:
                    root.append([])
                if self._add_symbol(root[i], symbol, pos - 1):
                    return True
        return False
 
    def build_from_bits(self, bits_length_seq, symbol_seq):
        self._symbol = symbol_seq
        symbol_index = 0  # 当前使用到symbol列表的下标
        for i, count in enumerate(bits_length_seq):
            for j in range(count):
                self._add_symbol(self._root, symbol_seq[symbol_index], i)
                symbol_index += 1
 
    def _find(self, bit_stream: BitStream):
        node = self._root
        while isinstance(node, list):
            node = node[bit_stream.get_bit()]
        return node
 
    def get_symbol(self, bit_stream) -> int:
        while True:
            res = self._find(bit_stream)
            if bit_stream == 0:
                return 0
            elif res != -1:
                return res
 
 
class JpegMarker:
    identify = ''
 
    def __init__(self, marker):
        self.__marker = marker
 
    def marker(self):
        return self.__marker
 
    @abc.abstractmethod
    def encode(self):
        pass
 
    @abc.abstractmethod
    def decode(self, data):
        pass
 
 
class JpegApp0Marker(JpegMarker):
 
    def __init__(self, marker):
        super().__init__(marker)
        self.__main_version = 0
        self.__sub_version = 0
        self.__density = 0
        self.__x_density = 0
        self.__y_density = 0
        self.__x_thumbnail = 0
        self.__y_thumbnail = 0
 
    def encode(self):
        pass
 
    def decode(self, data):
        offset = 0
        while data[offset] != 0 and offset < len(data) - 1:
            self.identify += chr(data[offset])
            offset += 1
        self.__main_version, self.__sub_version, self.__density, self.__x_density, self.__y_density, \
        self.__x_thumbnail, self.__y_thumbnail = struct.unpack_from('>3B2H2B', data, offset + 1)
 
 
class JpegApp1Marker(JpegMarker):
 
    def encode(self):
        pass
 
    def decode(self, data):
        pass
 
 
class JpegApp12Marker(JpegMarker):
 
    def __init__(self, marker):
        super().__init__(marker)
        self.quality = None
        self.comment = None
        self.copyright = None
 
    def encode(self):
        pass
 
    def decode(self, data):
        data_size = len(data)
        tag_name = ''
        offset = 0
        while data[offset] != 0 and offset < data_size:
            tag_name += chr(data[offset])
            offset += 1
        if tag_name == 'Ducky':
            while offset < data_size:
                tag, = struct.unpack_from('2s', data, offset)
                if tag == b'\x00\x00':
                    break
                offset += 2
                value_size, = struct.unpack_from('>H', data, offset)
                offset += 2
                if tag == b'\x00\x01':
                    # quality
                    self.quality, = struct.unpack_from('>I', data, offset)
                if tag == 2:
                    # comment
                    self.comment, = struct.unpack_from(str(value_size) + 's', data, offset)
                if tag == 3:
                    # copyright
                    self.copyright, = struct.unpack_from(str(value_size) + 's', data, offset)
                offset += value_size
 
 
class JpegApp13Marker(JpegMarker):
 
    def encode(self):
        pass
 
    def decode(self, data):
        print(data)
 
 
class JpegApp14Marker(JpegMarker):
 
    def __init__(self, marker):
        super().__init__(marker)
        self.dct_encode_version = None
        self.flags_0 = None
        self.flags_1 = None
        self.color_transform = None
 
    def encode(self):
        pass
 
    def decode(self, data):
        tag_name, version, flag0, flag1, color = struct.unpack('>5sxB2HB', data)
        if tag_name == b'Adobe':
            self.dct_encode_version = version
            self.flags_0 = flag0
            self.flags_1 = flag1
            self.color_transform = color
 
 
class JpegDQTMarker(JpegMarker):
    def __init__(self, marker):
        super().__init__(marker)
        self.table_id = None
        self.data_matrix = None
 
    def encode(self):
        pass
 
    def decode(self, data):
        tmp, = struct.unpack_from('B', data)
        precision = (tmp & 0xF0) >> 4  # 0: 8位, 1: 16位
        self.table_id = tmp & 0x0F
        data = struct.unpack_from(str(64 * (precision + 1)) + 'B', data, 1)
        self.data_matrix = np.array(data, np.uint8).reshape([8 * (precision + 1), 8 * (precision + 1)])
 
 
class JpegSof0Marker(JpegMarker):
 
    def __init__(self, marker):
        super().__init__(marker)
        self.precision = 0
        self.width = 0
        self.height = 0
        self.channel_count = 0
        self.channel_info = {}
 
    def encode(self):
        pass
 
    def decode(self, data):
        self.precision, self.width, self.height, self.channel_count = struct.unpack_from('>B2HB', data)
        offset = 6
        for i in range(self.channel_count):
            channel_id, sampling_factor, dqt_id = struct.unpack_from('>3B', data, offset)
            self.channel_info[channel_id] = {
                'horizontal_factor': (sampling_factor & 0xF0) >> 4,
                'vertical_factor': sampling_factor & 0x0F,
                'dqt_id': dqt_id
            }
            offset += 3
 
 
class JpegDhtMarker(JpegMarker):
 
    def __init__(self, marker):
        super().__init__(marker)
        self.table_id = 0
        self.type = None
        self.bits_length_seq = None
        self.symbol_seq = None
 
    def encode(self):
        pass
 
    def decode(self, data):
        tmp, = struct.unpack_from('B', data)
        self.type = (tmp & 0xF0) >> 4
        self.table_id = tmp & 0x0F
        self.bits_length_seq = struct.unpack_from('16B', data, 1)
        self.symbol_seq = struct.unpack_from(str(sum(self.bits_length_seq)) + 'B', data, 17)
 
 
class JpegDriMarker(JpegMarker):
    def __init__(self, marker):
        super().__init__(marker)
        self.restart_interval = 0
 
    def encode(self):
        pass
 
    def decode(self, data):
        self.restart_interval, = struct.unpack_from('>H', data)
 
 
class JpegImageMarker(JpegMarker):
    def __init__(self, marker):
        super().__init__(marker)
        self.channel_count = None  # 1灰度图 3YCrCb 4 CMYK
        self.huffman_map = {}
        self.image_data = None
 
    def encode(self):
        pass
 
    def decode(self, data):
        info_size, = struct.unpack_from('>H', data)
        self.image_data = self._remove_ff00(data[info_size:])
        self.channel_count, = struct.unpack_from('>B', data, 2)
        offset = 3
        for i in range(0, self.channel_count):
            channel_id, huffman_table = struct.unpack_from('>2B', data, offset)
            offset += 2
            self.huffman_map[channel_id] = {
                'AC': (huffman_table & 0xF0) >> 4,
                'DC': huffman_table & 0x0F,
            }
 
    @staticmethod
    def _remove_ff00(data):
        data_size = len(data)
        ret = []
        i = 0
        while i < data_size - 1:
            b, bnext = struct.unpack_from('2B', data, i)
            if b == 0xff and bnext == 0x00:
                ret.append(data[i])
                i += 2
            else:
                ret.append(data[i])
                i += 1
        ret.append(data[-1])
        return ret
 
 
class JpegMarkerList:
    def __init__(self):
        self.__markers = []
 
    def add(self, marker):
        """
        添加marker
        :param None|JpegMarker marker:
        :return:
        """
        if marker is not None:
            self.__markers.append(marker)
 
    def dump(self):
        for marker in self.__markers:
            print(marker.__dict__)
 
    def get_markers(self, marker_identify) -> typing.List[JpegMarker]:
        ret = []
        for marker in self.__markers:
            if marker_identify == marker.marker():
                ret.append(marker)
        return ret
 
    def get_image(self) -> typing.Union[JpegImageMarker, None]:
        markers = self.get_markers(FRAME_SOS)
        if len(markers) != 1:
            return None
        else:
            return markers[0]
 
    def get_quantization_table(self) -> typing.Union[typing.Dict[int, np.ndarray], None]:
        markers = self.get_markers(FRAME_DQT)
        if len(markers) == 0:
            return None
        quantization_table = {}
        for marker in markers:
            quantization_table[marker.table_id] = marker.data_matrix
        return quantization_table
 
    def get_image_info(self) -> typing.Union[JpegSof0Marker, None]:
        markers = self.get_markers(FRAME_SOF0)
        if len(markers) != 1:
            return None
        return markers[0]
 
    def get_restart_interval(self) -> int:
        markers = self.get_markers(FRAME_DRI)
        if len(markers) != 1:
            return 0
        return markers[0].restart_interval
 
    def get_huffman_table(self) -> typing.Union[typing.Dict[str, typing.Dict[int, HuffmanCodec]], None]:
        markers = self.get_markers(FRAME_DHT)
        if len(markers) == 0:
            return None
        huffman_table = {
            'DC': {},
            'AC': {},
        }
        for marker in markers:
            codec = HuffmanCodec()
            codec.build_from_bits(marker.bits_length_seq, marker.symbol_seq)
            if marker.type == 0:
                huffman_table['DC'][marker.table_id] = codec
            else:
                huffman_table['AC'][marker.table_id] = codec
        return huffman_table
 
 
class JpegDecoder:
 
    def __init__(self, data):
        self.__data = data
        self.__offset = 0
 
    def _read_bytes(self, fmt='c'):
        ret = struct.unpack_from(fmt, self.__data, self.__offset)
        self.__offset += struct.calcsize(fmt)
        return ret[0]
 
    @staticmethod
    def _is_sof_marker(marker):
        return marker == FRAME_SOF0 or marker == FRAME_SOF1 or marker == FRAME_SOF2 or marker == FRAME_SOF3 \
               or marker == FRAME_SOF5 or marker == FRAME_SOF6 or marker == FRAME_SOF7 or marker == FRAME_SOF8 \
               or marker == FRAME_SOF9 or marker == FRAME_SOF10 or marker == FRAME_SOF11 or marker == FRAME_SOF13 \
               or marker == FRAME_SOF14 or marker == FRAME_SOF15
 
    @staticmethod
    def _is_app_marker(marker):
        return marker == FRAME_APP0 or marker == FRAME_APP1 or marker == FRAME_APP2 or marker == FRAME_APP3 \
               or marker == FRAME_APP4 or marker == FRAME_APP5 or marker == FRAME_APP6 or marker == FRAME_APP7 \
               or marker == FRAME_APP8 or marker == FRAME_APP9 or marker == FRAME_APP10 or marker == FRAME_APP11 \
               or marker == FRAME_APP12 or marker == FRAME_APP13 or marker == FRAME_APP14 or marker == FRAME_APP15
 
    def _read_marker_size(self):
        size = self._read_bytes('>H')
        return size
 
    def _read_app_marker(self, marker):
        marker_size = self._read_marker_size()
        if marker_size < 2:
            return None
        if marker == FRAME_APP0:
            ret = JpegApp0Marker(marker)
        elif marker == FRAME_APP1:
            ret = JpegApp1Marker(marker)
        elif marker == FRAME_APP12:
            ret = JpegApp12Marker(marker)
        elif marker == FRAME_APP13:
            ret = JpegApp13Marker(marker)
        elif marker == FRAME_APP14:
            ret = JpegApp14Marker(marker)
        else:
            ret = JpegMarker(marker)
        ret.decode(self.__data[self.__offset:self.__offset + marker_size - 2])
        self.__offset += marker_size - 2
        return ret
 
    def _read_sof_marker(self, marker):
        marker_size = self._read_marker_size()
        if marker_size < 2:
            return None
        if marker == FRAME_SOF0:
            ret = JpegSof0Marker(marker)
        else:
            ret = JpegMarker(marker)
        ret.decode(self.__data[self.__offset:self.__offset + marker_size - 2])
        self.__offset += marker_size - 2
        return ret
 
    def _read_dqt_marker(self):
        marker_size = self._read_marker_size()
        if marker_size < 2:
            return None
        ret = JpegDQTMarker(FRAME_DQT)
        ret.decode(self.__data[self.__offset:self.__offset + marker_size - 2])
        self.__offset += marker_size - 2
        return ret
 
    def _read_dht_marker(self):
        marker_size = self._read_marker_size()
        if marker_size < 2:
            return None
        ret = JpegDhtMarker(FRAME_DHT)
        ret.decode(self.__data[self.__offset:self.__offset + marker_size - 2])
        self.__offset += marker_size - 2
        return ret
 
    def _read_dri_marker(self):
        marker_size = self._read_marker_size()
        ret = JpegDriMarker(FRAME_DRI)
        ret.decode(self.__data[self.__offset:self.__offset + marker_size - 2])
        self.__offset += marker_size - 2
        return ret
 
    def _read_image(self):
        marker = JpegImageMarker(FRAME_SOS)
        if self.__data[-2:] != b'\xff' + FRAME_EOI:
            return None
        marker.decode(self.__data[self.__offset:-2])
        self.__offset = len(self.__data) - 2
        return marker
 
    def read_markers(self):
        marker_list = JpegMarkerList()
        while self.__offset < len(self.__data):
            if self._read_bytes() != b'\xff':
                return None
            marker = self._read_bytes()
            if self._is_app_marker(marker):
                marker_list.add(self._read_app_marker(marker))
            elif self._is_sof_marker(marker):
                marker_list.add(self._read_sof_marker(marker))
            elif marker == FRAME_DQT:
                marker_list.add(self._read_dqt_marker())
            elif marker == FRAME_DHT:
                marker_list.add(self._read_dht_marker())
            elif marker == FRAME_SOS:
                marker_list.add(self._read_image())
            elif marker == FRAME_DRI:
                marker_list.add(self._read_dri_marker())
            elif marker == FRAME_EOI:
                return marker_list
            elif marker == FRAME_SOI:
                continue
            else:
                return None
 
    def read_data(self):
        markers = self.read_markers()
        image_marker = markers.get_image()
        # 仅处理YCrCb
        if image_marker.channel_count != 3:
            return None
        # 获取sof0,得到图片大小和各通道信息
        image_info = markers.get_image_info()
        # 仅处理YCrCb
        if image_info.channel_count != 3:
            return None
        # 获取dqt表
        quantization_table = markers.get_quantization_table()
        # 哈夫曼编码表
        huffman_table = markers.get_huffman_table()
        # 初始化之流信号prev表
        dc_prev = {}
        for channel_id in image_info.channel_info.keys():
            dc_prev[channel_id] = 0
        # restart interval
        restart_interval = markers.get_restart_interval()
        restart_interval_to_go = restart_interval
        restart_count = 0
        # 开始解码
        image_stream = BitStream(image_marker.image_data)
        img = np.zeros([image_info.width + 16, image_info.height + 16, 3], np.uint8)
        for y in range(0, image_info.width, 16):
            for x in range(0, image_info.height, 16):
                for channel_id, channel_info in image_info.channel_info.items():
                    channel_huffman_map = image_marker.huffman_map[channel_id]
                    dc_codec = huffman_table['DC'][channel_huffman_map['DC']]
                    ac_codec = huffman_table['AC'][channel_huffman_map['AC']]
                    quantization = quantization_table[channel_info['dqt_id']]
                    for y_block in range(channel_info['vertical_factor']):
                        for x_block in range(channel_info['horizontal_factor']):
                            if restart_interval > 0 and restart_interval_to_go == 0:
                                marker = image_stream.get_bits(8)
                                if (marker & 0x07) != (restart_count & 0x07):
                                    raise UserWarning('error restart marker')
                                restart_count += 1
                                restart_interval_to_go = restart_interval
                                for i in image_info.channel_info.keys():
                                    dc_prev[i] = 0
                            idct, dc_prev[channel_id] = self.process_huffman_data_unit(image_stream,
                                                                                       dc_codec,
                                                                                       ac_codec,
                                                                                       quantization,
                                                                                       dc_prev[channel_id])
                            if channel_info['vertical_factor'] == 2 and channel_info['horizontal_factor'] == 2:
                                img[y + y_block * 8: y + (y_block + 1) * 8, x + x_block * 8: x + (x_block + 1) * 8,
                                    channel_id - 1] = idct
                            elif channel_info['vertical_factor'] == 2 and channel_info['horizontal_factor'] == 1:
                                tmp = np.zeros([8, 16], np.uint8)
                                tmp[:, ::2] = idct
                                tmp[:, 1::2] = idct
                                img[y + y_block * 8: y + (y_block + 1) * 8, x: x+16, channel_id - 1] = tmp
                            elif channel_info['vertical_factor'] == 1 and channel_info['horizontal_factor'] == 2:
                                tmp = np.zeros([16, 8], np.uint8)
                                tmp[::2, :] = idct
                                tmp[1::2, :] = idct
                                img[y: y + 16, x + x_block * 8: x + (x_block + 1) * 8, channel_id - 1] = tmp
                            else:
                                img[y: y + 16, x: x + 16, channel_id - 1] = ndimage.interpolation.zoom(idct, 2)
                            restart_interval_to_go -= 1
        # YCbCr 转 BGR
        for y in range(0, image_info.width):
            for x in range(0, image_info.height):
                color_b = img[y, x, 0] + 1.772 * (img[y, x, 1] - 128)
                color_g = img[y, x, 0] - 0.344136 * (img[y, x, 1] - 128) - 0.714136 * (img[y, x, 2] - 128)
                color_r = img[y, x, 0] + 1.402 * (img[y, x, 2] - 128)
                img[y, x, :] = [color_b, color_g, color_r]
        return img[0:image_info.width, 0:image_info.height, :]
 
    @staticmethod
    def process_huffman_data_unit(image_stream: BitStream,
                                  dc_codec: HuffmanCodec,
                                  ac_codec: HuffmanCodec,
                                  quantization: np.ndarray,
                                  prev_dc_symbol: int):
        dct = np.zeros([64])
        # 解码直流部分
        huff_symbol = dc_codec.get_symbol(image_stream)
        if huff_symbol > 0:
            bits = image_stream.get_bits(huff_symbol)
            prev_dc_symbol += decode_number(huff_symbol, bits)
            dct[0] = prev_dc_symbol
        else:
            dct[0] = prev_dc_symbol
 
        # 解码交流部分
        index = 1
        while index < 64:
            huff_symbol = ac_codec.get_symbol(image_stream)
            if huff_symbol == 0:
                # EOF
                break
            if huff_symbol > 15:
                index += huff_symbol >> 4
                huff_symbol &= 0x0f
            dct[index] = decode_number(huff_symbol, image_stream.get_bits(huff_symbol))
            index += 1
        dct = dct.reshape([8, 8])
        dct = inverse_zigzag_single(dct)
        dct = dct * quantization
        idct = fft.idct(fft.idct(dct.T, norm='ortho').T, norm='ortho')
        idct = idct + 0x80
        idct = idct.astype(np.uint8)
        return idct, prev_dc_symbol
 
 
if __name__ == '__main__':
    with open('test.jpg', 'rb') as fp:
        decoder = JpegDecoder(fp.read())
        img = decoder.read_data()
        cv2.imshow('img', img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

 参考文档

https://zh.wikipedia.org/wiki/JPEG%E6%96%87%E4%BB%B6%E4%BA%A4%E6%8D%A2%E6%A0%BC%E5%BC%8F

https://zh.wikipedia.org/wiki/%E8%89%B2%E5%BA%A6%E6%8A%BD%E6%A0%B7

离散余弦变换(DCT)的来龙去脉_独孤呆博的博客-CSDN博客_二维离散余弦变换

    原文作者:_小B
    原文地址: https://blog.csdn.net/picone/article/details/123788128
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞