1*4882a593Smuzhiyun // SPDX-License-Identifier: GPL-2.0
2*4882a593Smuzhiyun
3*4882a593Smuzhiyun #include <string.h>
4*4882a593Smuzhiyun
5*4882a593Smuzhiyun #include "util/compress.h"
6*4882a593Smuzhiyun #include "util/debug.h"
7*4882a593Smuzhiyun
zstd_init(struct zstd_data * data,int level)8*4882a593Smuzhiyun int zstd_init(struct zstd_data *data, int level)
9*4882a593Smuzhiyun {
10*4882a593Smuzhiyun size_t ret;
11*4882a593Smuzhiyun
12*4882a593Smuzhiyun data->dstream = ZSTD_createDStream();
13*4882a593Smuzhiyun if (data->dstream == NULL) {
14*4882a593Smuzhiyun pr_err("Couldn't create decompression stream.\n");
15*4882a593Smuzhiyun return -1;
16*4882a593Smuzhiyun }
17*4882a593Smuzhiyun
18*4882a593Smuzhiyun ret = ZSTD_initDStream(data->dstream);
19*4882a593Smuzhiyun if (ZSTD_isError(ret)) {
20*4882a593Smuzhiyun pr_err("Failed to initialize decompression stream: %s\n", ZSTD_getErrorName(ret));
21*4882a593Smuzhiyun return -1;
22*4882a593Smuzhiyun }
23*4882a593Smuzhiyun
24*4882a593Smuzhiyun if (!level)
25*4882a593Smuzhiyun return 0;
26*4882a593Smuzhiyun
27*4882a593Smuzhiyun data->cstream = ZSTD_createCStream();
28*4882a593Smuzhiyun if (data->cstream == NULL) {
29*4882a593Smuzhiyun pr_err("Couldn't create compression stream.\n");
30*4882a593Smuzhiyun return -1;
31*4882a593Smuzhiyun }
32*4882a593Smuzhiyun
33*4882a593Smuzhiyun ret = ZSTD_initCStream(data->cstream, level);
34*4882a593Smuzhiyun if (ZSTD_isError(ret)) {
35*4882a593Smuzhiyun pr_err("Failed to initialize compression stream: %s\n", ZSTD_getErrorName(ret));
36*4882a593Smuzhiyun return -1;
37*4882a593Smuzhiyun }
38*4882a593Smuzhiyun
39*4882a593Smuzhiyun return 0;
40*4882a593Smuzhiyun }
41*4882a593Smuzhiyun
zstd_fini(struct zstd_data * data)42*4882a593Smuzhiyun int zstd_fini(struct zstd_data *data)
43*4882a593Smuzhiyun {
44*4882a593Smuzhiyun if (data->dstream) {
45*4882a593Smuzhiyun ZSTD_freeDStream(data->dstream);
46*4882a593Smuzhiyun data->dstream = NULL;
47*4882a593Smuzhiyun }
48*4882a593Smuzhiyun
49*4882a593Smuzhiyun if (data->cstream) {
50*4882a593Smuzhiyun ZSTD_freeCStream(data->cstream);
51*4882a593Smuzhiyun data->cstream = NULL;
52*4882a593Smuzhiyun }
53*4882a593Smuzhiyun
54*4882a593Smuzhiyun return 0;
55*4882a593Smuzhiyun }
56*4882a593Smuzhiyun
zstd_compress_stream_to_records(struct zstd_data * data,void * dst,size_t dst_size,void * src,size_t src_size,size_t max_record_size,size_t process_header (void * record,size_t increment))57*4882a593Smuzhiyun size_t zstd_compress_stream_to_records(struct zstd_data *data, void *dst, size_t dst_size,
58*4882a593Smuzhiyun void *src, size_t src_size, size_t max_record_size,
59*4882a593Smuzhiyun size_t process_header(void *record, size_t increment))
60*4882a593Smuzhiyun {
61*4882a593Smuzhiyun size_t ret, size, compressed = 0;
62*4882a593Smuzhiyun ZSTD_inBuffer input = { src, src_size, 0 };
63*4882a593Smuzhiyun ZSTD_outBuffer output;
64*4882a593Smuzhiyun void *record;
65*4882a593Smuzhiyun
66*4882a593Smuzhiyun while (input.pos < input.size) {
67*4882a593Smuzhiyun record = dst;
68*4882a593Smuzhiyun size = process_header(record, 0);
69*4882a593Smuzhiyun compressed += size;
70*4882a593Smuzhiyun dst += size;
71*4882a593Smuzhiyun dst_size -= size;
72*4882a593Smuzhiyun output = (ZSTD_outBuffer){ dst, (dst_size > max_record_size) ?
73*4882a593Smuzhiyun max_record_size : dst_size, 0 };
74*4882a593Smuzhiyun ret = ZSTD_compressStream(data->cstream, &output, &input);
75*4882a593Smuzhiyun ZSTD_flushStream(data->cstream, &output);
76*4882a593Smuzhiyun if (ZSTD_isError(ret)) {
77*4882a593Smuzhiyun pr_err("failed to compress %ld bytes: %s\n",
78*4882a593Smuzhiyun (long)src_size, ZSTD_getErrorName(ret));
79*4882a593Smuzhiyun memcpy(dst, src, src_size);
80*4882a593Smuzhiyun return src_size;
81*4882a593Smuzhiyun }
82*4882a593Smuzhiyun size = output.pos;
83*4882a593Smuzhiyun size = process_header(record, size);
84*4882a593Smuzhiyun compressed += size;
85*4882a593Smuzhiyun dst += size;
86*4882a593Smuzhiyun dst_size -= size;
87*4882a593Smuzhiyun }
88*4882a593Smuzhiyun
89*4882a593Smuzhiyun return compressed;
90*4882a593Smuzhiyun }
91*4882a593Smuzhiyun
zstd_decompress_stream(struct zstd_data * data,void * src,size_t src_size,void * dst,size_t dst_size)92*4882a593Smuzhiyun size_t zstd_decompress_stream(struct zstd_data *data, void *src, size_t src_size,
93*4882a593Smuzhiyun void *dst, size_t dst_size)
94*4882a593Smuzhiyun {
95*4882a593Smuzhiyun size_t ret;
96*4882a593Smuzhiyun ZSTD_inBuffer input = { src, src_size, 0 };
97*4882a593Smuzhiyun ZSTD_outBuffer output = { dst, dst_size, 0 };
98*4882a593Smuzhiyun
99*4882a593Smuzhiyun while (input.pos < input.size) {
100*4882a593Smuzhiyun ret = ZSTD_decompressStream(data->dstream, &output, &input);
101*4882a593Smuzhiyun if (ZSTD_isError(ret)) {
102*4882a593Smuzhiyun pr_err("failed to decompress (B): %zd -> %zd, dst_size %zd : %s\n",
103*4882a593Smuzhiyun src_size, output.size, dst_size, ZSTD_getErrorName(ret));
104*4882a593Smuzhiyun break;
105*4882a593Smuzhiyun }
106*4882a593Smuzhiyun output.dst = dst + output.pos;
107*4882a593Smuzhiyun output.size = dst_size - output.pos;
108*4882a593Smuzhiyun }
109*4882a593Smuzhiyun
110*4882a593Smuzhiyun return output.pos;
111*4882a593Smuzhiyun }
112