xref: /OK3568_Linux_fs/kernel/crypto/zstd.c (revision 4882a59341e53eb6f0b4789bf948001014eff981)
1*4882a593Smuzhiyun // SPDX-License-Identifier: GPL-2.0-only
2*4882a593Smuzhiyun /*
3*4882a593Smuzhiyun  * Cryptographic API.
4*4882a593Smuzhiyun  *
5*4882a593Smuzhiyun  * Copyright (c) 2017-present, Facebook, Inc.
6*4882a593Smuzhiyun  */
7*4882a593Smuzhiyun #include <linux/crypto.h>
8*4882a593Smuzhiyun #include <linux/init.h>
9*4882a593Smuzhiyun #include <linux/interrupt.h>
10*4882a593Smuzhiyun #include <linux/mm.h>
11*4882a593Smuzhiyun #include <linux/module.h>
12*4882a593Smuzhiyun #include <linux/net.h>
13*4882a593Smuzhiyun #include <linux/vmalloc.h>
14*4882a593Smuzhiyun #include <linux/zstd.h>
15*4882a593Smuzhiyun #include <crypto/internal/scompress.h>
16*4882a593Smuzhiyun 
17*4882a593Smuzhiyun 
18*4882a593Smuzhiyun #define ZSTD_DEF_LEVEL	3
19*4882a593Smuzhiyun 
20*4882a593Smuzhiyun struct zstd_ctx {
21*4882a593Smuzhiyun 	ZSTD_CCtx *cctx;
22*4882a593Smuzhiyun 	ZSTD_DCtx *dctx;
23*4882a593Smuzhiyun 	void *cwksp;
24*4882a593Smuzhiyun 	void *dwksp;
25*4882a593Smuzhiyun };
26*4882a593Smuzhiyun 
zstd_params(void)27*4882a593Smuzhiyun static ZSTD_parameters zstd_params(void)
28*4882a593Smuzhiyun {
29*4882a593Smuzhiyun 	return ZSTD_getParams(ZSTD_DEF_LEVEL, 0, 0);
30*4882a593Smuzhiyun }
31*4882a593Smuzhiyun 
zstd_comp_init(struct zstd_ctx * ctx)32*4882a593Smuzhiyun static int zstd_comp_init(struct zstd_ctx *ctx)
33*4882a593Smuzhiyun {
34*4882a593Smuzhiyun 	int ret = 0;
35*4882a593Smuzhiyun 	const ZSTD_parameters params = zstd_params();
36*4882a593Smuzhiyun 	const size_t wksp_size = ZSTD_CCtxWorkspaceBound(params.cParams);
37*4882a593Smuzhiyun 
38*4882a593Smuzhiyun 	ctx->cwksp = vzalloc(wksp_size);
39*4882a593Smuzhiyun 	if (!ctx->cwksp) {
40*4882a593Smuzhiyun 		ret = -ENOMEM;
41*4882a593Smuzhiyun 		goto out;
42*4882a593Smuzhiyun 	}
43*4882a593Smuzhiyun 
44*4882a593Smuzhiyun 	ctx->cctx = ZSTD_initCCtx(ctx->cwksp, wksp_size);
45*4882a593Smuzhiyun 	if (!ctx->cctx) {
46*4882a593Smuzhiyun 		ret = -EINVAL;
47*4882a593Smuzhiyun 		goto out_free;
48*4882a593Smuzhiyun 	}
49*4882a593Smuzhiyun out:
50*4882a593Smuzhiyun 	return ret;
51*4882a593Smuzhiyun out_free:
52*4882a593Smuzhiyun 	vfree(ctx->cwksp);
53*4882a593Smuzhiyun 	goto out;
54*4882a593Smuzhiyun }
55*4882a593Smuzhiyun 
zstd_decomp_init(struct zstd_ctx * ctx)56*4882a593Smuzhiyun static int zstd_decomp_init(struct zstd_ctx *ctx)
57*4882a593Smuzhiyun {
58*4882a593Smuzhiyun 	int ret = 0;
59*4882a593Smuzhiyun 	const size_t wksp_size = ZSTD_DCtxWorkspaceBound();
60*4882a593Smuzhiyun 
61*4882a593Smuzhiyun 	ctx->dwksp = vzalloc(wksp_size);
62*4882a593Smuzhiyun 	if (!ctx->dwksp) {
63*4882a593Smuzhiyun 		ret = -ENOMEM;
64*4882a593Smuzhiyun 		goto out;
65*4882a593Smuzhiyun 	}
66*4882a593Smuzhiyun 
67*4882a593Smuzhiyun 	ctx->dctx = ZSTD_initDCtx(ctx->dwksp, wksp_size);
68*4882a593Smuzhiyun 	if (!ctx->dctx) {
69*4882a593Smuzhiyun 		ret = -EINVAL;
70*4882a593Smuzhiyun 		goto out_free;
71*4882a593Smuzhiyun 	}
72*4882a593Smuzhiyun out:
73*4882a593Smuzhiyun 	return ret;
74*4882a593Smuzhiyun out_free:
75*4882a593Smuzhiyun 	vfree(ctx->dwksp);
76*4882a593Smuzhiyun 	goto out;
77*4882a593Smuzhiyun }
78*4882a593Smuzhiyun 
zstd_comp_exit(struct zstd_ctx * ctx)79*4882a593Smuzhiyun static void zstd_comp_exit(struct zstd_ctx *ctx)
80*4882a593Smuzhiyun {
81*4882a593Smuzhiyun 	vfree(ctx->cwksp);
82*4882a593Smuzhiyun 	ctx->cwksp = NULL;
83*4882a593Smuzhiyun 	ctx->cctx = NULL;
84*4882a593Smuzhiyun }
85*4882a593Smuzhiyun 
zstd_decomp_exit(struct zstd_ctx * ctx)86*4882a593Smuzhiyun static void zstd_decomp_exit(struct zstd_ctx *ctx)
87*4882a593Smuzhiyun {
88*4882a593Smuzhiyun 	vfree(ctx->dwksp);
89*4882a593Smuzhiyun 	ctx->dwksp = NULL;
90*4882a593Smuzhiyun 	ctx->dctx = NULL;
91*4882a593Smuzhiyun }
92*4882a593Smuzhiyun 
__zstd_init(void * ctx)93*4882a593Smuzhiyun static int __zstd_init(void *ctx)
94*4882a593Smuzhiyun {
95*4882a593Smuzhiyun 	int ret;
96*4882a593Smuzhiyun 
97*4882a593Smuzhiyun 	ret = zstd_comp_init(ctx);
98*4882a593Smuzhiyun 	if (ret)
99*4882a593Smuzhiyun 		return ret;
100*4882a593Smuzhiyun 	ret = zstd_decomp_init(ctx);
101*4882a593Smuzhiyun 	if (ret)
102*4882a593Smuzhiyun 		zstd_comp_exit(ctx);
103*4882a593Smuzhiyun 	return ret;
104*4882a593Smuzhiyun }
105*4882a593Smuzhiyun 
zstd_alloc_ctx(struct crypto_scomp * tfm)106*4882a593Smuzhiyun static void *zstd_alloc_ctx(struct crypto_scomp *tfm)
107*4882a593Smuzhiyun {
108*4882a593Smuzhiyun 	int ret;
109*4882a593Smuzhiyun 	struct zstd_ctx *ctx;
110*4882a593Smuzhiyun 
111*4882a593Smuzhiyun 	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
112*4882a593Smuzhiyun 	if (!ctx)
113*4882a593Smuzhiyun 		return ERR_PTR(-ENOMEM);
114*4882a593Smuzhiyun 
115*4882a593Smuzhiyun 	ret = __zstd_init(ctx);
116*4882a593Smuzhiyun 	if (ret) {
117*4882a593Smuzhiyun 		kfree(ctx);
118*4882a593Smuzhiyun 		return ERR_PTR(ret);
119*4882a593Smuzhiyun 	}
120*4882a593Smuzhiyun 
121*4882a593Smuzhiyun 	return ctx;
122*4882a593Smuzhiyun }
123*4882a593Smuzhiyun 
zstd_init(struct crypto_tfm * tfm)124*4882a593Smuzhiyun static int zstd_init(struct crypto_tfm *tfm)
125*4882a593Smuzhiyun {
126*4882a593Smuzhiyun 	struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
127*4882a593Smuzhiyun 
128*4882a593Smuzhiyun 	return __zstd_init(ctx);
129*4882a593Smuzhiyun }
130*4882a593Smuzhiyun 
__zstd_exit(void * ctx)131*4882a593Smuzhiyun static void __zstd_exit(void *ctx)
132*4882a593Smuzhiyun {
133*4882a593Smuzhiyun 	zstd_comp_exit(ctx);
134*4882a593Smuzhiyun 	zstd_decomp_exit(ctx);
135*4882a593Smuzhiyun }
136*4882a593Smuzhiyun 
zstd_free_ctx(struct crypto_scomp * tfm,void * ctx)137*4882a593Smuzhiyun static void zstd_free_ctx(struct crypto_scomp *tfm, void *ctx)
138*4882a593Smuzhiyun {
139*4882a593Smuzhiyun 	__zstd_exit(ctx);
140*4882a593Smuzhiyun 	kfree_sensitive(ctx);
141*4882a593Smuzhiyun }
142*4882a593Smuzhiyun 
zstd_exit(struct crypto_tfm * tfm)143*4882a593Smuzhiyun static void zstd_exit(struct crypto_tfm *tfm)
144*4882a593Smuzhiyun {
145*4882a593Smuzhiyun 	struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
146*4882a593Smuzhiyun 
147*4882a593Smuzhiyun 	__zstd_exit(ctx);
148*4882a593Smuzhiyun }
149*4882a593Smuzhiyun 
__zstd_compress(const u8 * src,unsigned int slen,u8 * dst,unsigned int * dlen,void * ctx)150*4882a593Smuzhiyun static int __zstd_compress(const u8 *src, unsigned int slen,
151*4882a593Smuzhiyun 			   u8 *dst, unsigned int *dlen, void *ctx)
152*4882a593Smuzhiyun {
153*4882a593Smuzhiyun 	size_t out_len;
154*4882a593Smuzhiyun 	struct zstd_ctx *zctx = ctx;
155*4882a593Smuzhiyun 	const ZSTD_parameters params = zstd_params();
156*4882a593Smuzhiyun 
157*4882a593Smuzhiyun 	out_len = ZSTD_compressCCtx(zctx->cctx, dst, *dlen, src, slen, params);
158*4882a593Smuzhiyun 	if (ZSTD_isError(out_len))
159*4882a593Smuzhiyun 		return -EINVAL;
160*4882a593Smuzhiyun 	*dlen = out_len;
161*4882a593Smuzhiyun 	return 0;
162*4882a593Smuzhiyun }
163*4882a593Smuzhiyun 
zstd_compress(struct crypto_tfm * tfm,const u8 * src,unsigned int slen,u8 * dst,unsigned int * dlen)164*4882a593Smuzhiyun static int zstd_compress(struct crypto_tfm *tfm, const u8 *src,
165*4882a593Smuzhiyun 			 unsigned int slen, u8 *dst, unsigned int *dlen)
166*4882a593Smuzhiyun {
167*4882a593Smuzhiyun 	struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
168*4882a593Smuzhiyun 
169*4882a593Smuzhiyun 	return __zstd_compress(src, slen, dst, dlen, ctx);
170*4882a593Smuzhiyun }
171*4882a593Smuzhiyun 
zstd_scompress(struct crypto_scomp * tfm,const u8 * src,unsigned int slen,u8 * dst,unsigned int * dlen,void * ctx)172*4882a593Smuzhiyun static int zstd_scompress(struct crypto_scomp *tfm, const u8 *src,
173*4882a593Smuzhiyun 			  unsigned int slen, u8 *dst, unsigned int *dlen,
174*4882a593Smuzhiyun 			  void *ctx)
175*4882a593Smuzhiyun {
176*4882a593Smuzhiyun 	return __zstd_compress(src, slen, dst, dlen, ctx);
177*4882a593Smuzhiyun }
178*4882a593Smuzhiyun 
__zstd_decompress(const u8 * src,unsigned int slen,u8 * dst,unsigned int * dlen,void * ctx)179*4882a593Smuzhiyun static int __zstd_decompress(const u8 *src, unsigned int slen,
180*4882a593Smuzhiyun 			     u8 *dst, unsigned int *dlen, void *ctx)
181*4882a593Smuzhiyun {
182*4882a593Smuzhiyun 	size_t out_len;
183*4882a593Smuzhiyun 	struct zstd_ctx *zctx = ctx;
184*4882a593Smuzhiyun 
185*4882a593Smuzhiyun 	out_len = ZSTD_decompressDCtx(zctx->dctx, dst, *dlen, src, slen);
186*4882a593Smuzhiyun 	if (ZSTD_isError(out_len))
187*4882a593Smuzhiyun 		return -EINVAL;
188*4882a593Smuzhiyun 	*dlen = out_len;
189*4882a593Smuzhiyun 	return 0;
190*4882a593Smuzhiyun }
191*4882a593Smuzhiyun 
zstd_decompress(struct crypto_tfm * tfm,const u8 * src,unsigned int slen,u8 * dst,unsigned int * dlen)192*4882a593Smuzhiyun static int zstd_decompress(struct crypto_tfm *tfm, const u8 *src,
193*4882a593Smuzhiyun 			   unsigned int slen, u8 *dst, unsigned int *dlen)
194*4882a593Smuzhiyun {
195*4882a593Smuzhiyun 	struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
196*4882a593Smuzhiyun 
197*4882a593Smuzhiyun 	return __zstd_decompress(src, slen, dst, dlen, ctx);
198*4882a593Smuzhiyun }
199*4882a593Smuzhiyun 
zstd_sdecompress(struct crypto_scomp * tfm,const u8 * src,unsigned int slen,u8 * dst,unsigned int * dlen,void * ctx)200*4882a593Smuzhiyun static int zstd_sdecompress(struct crypto_scomp *tfm, const u8 *src,
201*4882a593Smuzhiyun 			    unsigned int slen, u8 *dst, unsigned int *dlen,
202*4882a593Smuzhiyun 			    void *ctx)
203*4882a593Smuzhiyun {
204*4882a593Smuzhiyun 	return __zstd_decompress(src, slen, dst, dlen, ctx);
205*4882a593Smuzhiyun }
206*4882a593Smuzhiyun 
207*4882a593Smuzhiyun static struct crypto_alg alg = {
208*4882a593Smuzhiyun 	.cra_name		= "zstd",
209*4882a593Smuzhiyun 	.cra_driver_name	= "zstd-generic",
210*4882a593Smuzhiyun 	.cra_flags		= CRYPTO_ALG_TYPE_COMPRESS,
211*4882a593Smuzhiyun 	.cra_ctxsize		= sizeof(struct zstd_ctx),
212*4882a593Smuzhiyun 	.cra_module		= THIS_MODULE,
213*4882a593Smuzhiyun 	.cra_init		= zstd_init,
214*4882a593Smuzhiyun 	.cra_exit		= zstd_exit,
215*4882a593Smuzhiyun 	.cra_u			= { .compress = {
216*4882a593Smuzhiyun 	.coa_compress		= zstd_compress,
217*4882a593Smuzhiyun 	.coa_decompress		= zstd_decompress } }
218*4882a593Smuzhiyun };
219*4882a593Smuzhiyun 
220*4882a593Smuzhiyun static struct scomp_alg scomp = {
221*4882a593Smuzhiyun 	.alloc_ctx		= zstd_alloc_ctx,
222*4882a593Smuzhiyun 	.free_ctx		= zstd_free_ctx,
223*4882a593Smuzhiyun 	.compress		= zstd_scompress,
224*4882a593Smuzhiyun 	.decompress		= zstd_sdecompress,
225*4882a593Smuzhiyun 	.base			= {
226*4882a593Smuzhiyun 		.cra_name	= "zstd",
227*4882a593Smuzhiyun 		.cra_driver_name = "zstd-scomp",
228*4882a593Smuzhiyun 		.cra_module	 = THIS_MODULE,
229*4882a593Smuzhiyun 	}
230*4882a593Smuzhiyun };
231*4882a593Smuzhiyun 
zstd_mod_init(void)232*4882a593Smuzhiyun static int __init zstd_mod_init(void)
233*4882a593Smuzhiyun {
234*4882a593Smuzhiyun 	int ret;
235*4882a593Smuzhiyun 
236*4882a593Smuzhiyun 	ret = crypto_register_alg(&alg);
237*4882a593Smuzhiyun 	if (ret)
238*4882a593Smuzhiyun 		return ret;
239*4882a593Smuzhiyun 
240*4882a593Smuzhiyun 	ret = crypto_register_scomp(&scomp);
241*4882a593Smuzhiyun 	if (ret)
242*4882a593Smuzhiyun 		crypto_unregister_alg(&alg);
243*4882a593Smuzhiyun 
244*4882a593Smuzhiyun 	return ret;
245*4882a593Smuzhiyun }
246*4882a593Smuzhiyun 
zstd_mod_fini(void)247*4882a593Smuzhiyun static void __exit zstd_mod_fini(void)
248*4882a593Smuzhiyun {
249*4882a593Smuzhiyun 	crypto_unregister_alg(&alg);
250*4882a593Smuzhiyun 	crypto_unregister_scomp(&scomp);
251*4882a593Smuzhiyun }
252*4882a593Smuzhiyun 
253*4882a593Smuzhiyun subsys_initcall(zstd_mod_init);
254*4882a593Smuzhiyun module_exit(zstd_mod_fini);
255*4882a593Smuzhiyun 
256*4882a593Smuzhiyun MODULE_LICENSE("GPL");
257*4882a593Smuzhiyun MODULE_DESCRIPTION("Zstd Compression Algorithm");
258*4882a593Smuzhiyun MODULE_ALIAS_CRYPTO("zstd");
259