1*4882a593Smuzhiyun // SPDX-License-Identifier: GPL-2.0-only
2*4882a593Smuzhiyun /*
3*4882a593Smuzhiyun * Cryptographic API.
4*4882a593Smuzhiyun *
5*4882a593Smuzhiyun * Copyright (c) 2017-present, Facebook, Inc.
6*4882a593Smuzhiyun */
7*4882a593Smuzhiyun #include <linux/crypto.h>
8*4882a593Smuzhiyun #include <linux/init.h>
9*4882a593Smuzhiyun #include <linux/interrupt.h>
10*4882a593Smuzhiyun #include <linux/mm.h>
11*4882a593Smuzhiyun #include <linux/module.h>
12*4882a593Smuzhiyun #include <linux/net.h>
13*4882a593Smuzhiyun #include <linux/vmalloc.h>
14*4882a593Smuzhiyun #include <linux/zstd.h>
15*4882a593Smuzhiyun #include <crypto/internal/scompress.h>
16*4882a593Smuzhiyun
17*4882a593Smuzhiyun
/* Compression level handed to ZSTD_getParams() by zstd_params(). */
enum { ZSTD_DEF_LEVEL = 3 };
19*4882a593Smuzhiyun
20*4882a593Smuzhiyun struct zstd_ctx {
21*4882a593Smuzhiyun ZSTD_CCtx *cctx;
22*4882a593Smuzhiyun ZSTD_DCtx *dctx;
23*4882a593Smuzhiyun void *cwksp;
24*4882a593Smuzhiyun void *dwksp;
25*4882a593Smuzhiyun };
26*4882a593Smuzhiyun
zstd_params(void)27*4882a593Smuzhiyun static ZSTD_parameters zstd_params(void)
28*4882a593Smuzhiyun {
29*4882a593Smuzhiyun return ZSTD_getParams(ZSTD_DEF_LEVEL, 0, 0);
30*4882a593Smuzhiyun }
31*4882a593Smuzhiyun
zstd_comp_init(struct zstd_ctx * ctx)32*4882a593Smuzhiyun static int zstd_comp_init(struct zstd_ctx *ctx)
33*4882a593Smuzhiyun {
34*4882a593Smuzhiyun int ret = 0;
35*4882a593Smuzhiyun const ZSTD_parameters params = zstd_params();
36*4882a593Smuzhiyun const size_t wksp_size = ZSTD_CCtxWorkspaceBound(params.cParams);
37*4882a593Smuzhiyun
38*4882a593Smuzhiyun ctx->cwksp = vzalloc(wksp_size);
39*4882a593Smuzhiyun if (!ctx->cwksp) {
40*4882a593Smuzhiyun ret = -ENOMEM;
41*4882a593Smuzhiyun goto out;
42*4882a593Smuzhiyun }
43*4882a593Smuzhiyun
44*4882a593Smuzhiyun ctx->cctx = ZSTD_initCCtx(ctx->cwksp, wksp_size);
45*4882a593Smuzhiyun if (!ctx->cctx) {
46*4882a593Smuzhiyun ret = -EINVAL;
47*4882a593Smuzhiyun goto out_free;
48*4882a593Smuzhiyun }
49*4882a593Smuzhiyun out:
50*4882a593Smuzhiyun return ret;
51*4882a593Smuzhiyun out_free:
52*4882a593Smuzhiyun vfree(ctx->cwksp);
53*4882a593Smuzhiyun goto out;
54*4882a593Smuzhiyun }
55*4882a593Smuzhiyun
zstd_decomp_init(struct zstd_ctx * ctx)56*4882a593Smuzhiyun static int zstd_decomp_init(struct zstd_ctx *ctx)
57*4882a593Smuzhiyun {
58*4882a593Smuzhiyun int ret = 0;
59*4882a593Smuzhiyun const size_t wksp_size = ZSTD_DCtxWorkspaceBound();
60*4882a593Smuzhiyun
61*4882a593Smuzhiyun ctx->dwksp = vzalloc(wksp_size);
62*4882a593Smuzhiyun if (!ctx->dwksp) {
63*4882a593Smuzhiyun ret = -ENOMEM;
64*4882a593Smuzhiyun goto out;
65*4882a593Smuzhiyun }
66*4882a593Smuzhiyun
67*4882a593Smuzhiyun ctx->dctx = ZSTD_initDCtx(ctx->dwksp, wksp_size);
68*4882a593Smuzhiyun if (!ctx->dctx) {
69*4882a593Smuzhiyun ret = -EINVAL;
70*4882a593Smuzhiyun goto out_free;
71*4882a593Smuzhiyun }
72*4882a593Smuzhiyun out:
73*4882a593Smuzhiyun return ret;
74*4882a593Smuzhiyun out_free:
75*4882a593Smuzhiyun vfree(ctx->dwksp);
76*4882a593Smuzhiyun goto out;
77*4882a593Smuzhiyun }
78*4882a593Smuzhiyun
zstd_comp_exit(struct zstd_ctx * ctx)79*4882a593Smuzhiyun static void zstd_comp_exit(struct zstd_ctx *ctx)
80*4882a593Smuzhiyun {
81*4882a593Smuzhiyun vfree(ctx->cwksp);
82*4882a593Smuzhiyun ctx->cwksp = NULL;
83*4882a593Smuzhiyun ctx->cctx = NULL;
84*4882a593Smuzhiyun }
85*4882a593Smuzhiyun
zstd_decomp_exit(struct zstd_ctx * ctx)86*4882a593Smuzhiyun static void zstd_decomp_exit(struct zstd_ctx *ctx)
87*4882a593Smuzhiyun {
88*4882a593Smuzhiyun vfree(ctx->dwksp);
89*4882a593Smuzhiyun ctx->dwksp = NULL;
90*4882a593Smuzhiyun ctx->dctx = NULL;
91*4882a593Smuzhiyun }
92*4882a593Smuzhiyun
/*
 * __zstd_init - set up both the compression and decompression halves of @ctx.
 *
 * If the decompression half fails, the already-initialised compression half
 * is torn down. The caller is then left with nothing to release.
 *
 * Return: 0 on success or the error from the failing init step.
 */
static int __zstd_init(void *ctx)
{
	int err = zstd_comp_init(ctx);

	if (!err) {
		err = zstd_decomp_init(ctx);
		if (err)
			zstd_comp_exit(ctx);
	}

	return err;
}
105*4882a593Smuzhiyun
zstd_alloc_ctx(struct crypto_scomp * tfm)106*4882a593Smuzhiyun static void *zstd_alloc_ctx(struct crypto_scomp *tfm)
107*4882a593Smuzhiyun {
108*4882a593Smuzhiyun int ret;
109*4882a593Smuzhiyun struct zstd_ctx *ctx;
110*4882a593Smuzhiyun
111*4882a593Smuzhiyun ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
112*4882a593Smuzhiyun if (!ctx)
113*4882a593Smuzhiyun return ERR_PTR(-ENOMEM);
114*4882a593Smuzhiyun
115*4882a593Smuzhiyun ret = __zstd_init(ctx);
116*4882a593Smuzhiyun if (ret) {
117*4882a593Smuzhiyun kfree(ctx);
118*4882a593Smuzhiyun return ERR_PTR(ret);
119*4882a593Smuzhiyun }
120*4882a593Smuzhiyun
121*4882a593Smuzhiyun return ctx;
122*4882a593Smuzhiyun }
123*4882a593Smuzhiyun
/* zstd_init - ->cra_init hook: initialise the context embedded in @tfm. */
static int zstd_init(struct crypto_tfm *tfm)
{
	return __zstd_init(crypto_tfm_ctx(tfm));
}
130*4882a593Smuzhiyun
/* __zstd_exit - release both workspaces held by @ctx. */
static void __zstd_exit(void *ctx)
{
	struct zstd_ctx *zctx = ctx;

	zstd_comp_exit(zctx);
	zstd_decomp_exit(zctx);
}
136*4882a593Smuzhiyun
/*
 * zstd_free_ctx - scomp ->free_ctx hook.
 *
 * Releases both workspaces, then frees the context via kfree_sensitive(),
 * which zeroes it before freeing.
 */
static void zstd_free_ctx(struct crypto_scomp *tfm, void *ctx)
{
	struct zstd_ctx *zctx = ctx;

	__zstd_exit(zctx);
	kfree_sensitive(zctx);
}
142*4882a593Smuzhiyun
/*
 * zstd_exit - ->cra_exit hook: release the workspaces of the context embedded
 * in @tfm. The context memory itself belongs to the tfm.
 */
static void zstd_exit(struct crypto_tfm *tfm)
{
	__zstd_exit(crypto_tfm_ctx(tfm));
}
149*4882a593Smuzhiyun
/*
 * __zstd_compress - compress @src into @dst using the CCtx in @ctx.
 * @dlen: in: capacity of @dst; out: number of bytes written (success only)
 *
 * Return: 0 on success, -EINVAL if zstd reports any error (including an
 * undersized @dst).
 */
static int __zstd_compress(const u8 *src, unsigned int slen,
			   u8 *dst, unsigned int *dlen, void *ctx)
{
	struct zstd_ctx *zctx = ctx;
	const ZSTD_parameters params = zstd_params();
	const size_t written = ZSTD_compressCCtx(zctx->cctx, dst, *dlen,
						 src, slen, params);

	if (ZSTD_isError(written))
		return -EINVAL;

	*dlen = written;
	return 0;
}
163*4882a593Smuzhiyun
/* zstd_compress - legacy ->coa_compress hook; state lives in the tfm ctx. */
static int zstd_compress(struct crypto_tfm *tfm, const u8 *src,
			 unsigned int slen, u8 *dst, unsigned int *dlen)
{
	return __zstd_compress(src, slen, dst, dlen, crypto_tfm_ctx(tfm));
}
171*4882a593Smuzhiyun
/*
 * zstd_scompress - scomp ->compress hook.
 *
 * @tfm is unused: the working state is the @ctx from zstd_alloc_ctx().
 */
static int zstd_scompress(struct crypto_scomp *tfm, const u8 *src,
			  unsigned int slen, u8 *dst, unsigned int *dlen,
			  void *ctx)
{
	struct zstd_ctx *zctx = ctx;

	return __zstd_compress(src, slen, dst, dlen, zctx);
}
178*4882a593Smuzhiyun
/*
 * __zstd_decompress - decompress @src into @dst using the DCtx in @ctx.
 * @dlen: in: capacity of @dst; out: number of bytes produced (success only)
 *
 * Return: 0 on success, -EINVAL if zstd reports any error.
 */
static int __zstd_decompress(const u8 *src, unsigned int slen,
			     u8 *dst, unsigned int *dlen, void *ctx)
{
	struct zstd_ctx *zctx = ctx;
	const size_t produced = ZSTD_decompressDCtx(zctx->dctx, dst, *dlen,
						    src, slen);

	if (ZSTD_isError(produced))
		return -EINVAL;

	*dlen = produced;
	return 0;
}
191*4882a593Smuzhiyun
/* zstd_decompress - legacy ->coa_decompress hook; state lives in the tfm ctx. */
static int zstd_decompress(struct crypto_tfm *tfm, const u8 *src,
			   unsigned int slen, u8 *dst, unsigned int *dlen)
{
	return __zstd_decompress(src, slen, dst, dlen, crypto_tfm_ctx(tfm));
}
199*4882a593Smuzhiyun
/*
 * zstd_sdecompress - scomp ->decompress hook.
 *
 * @tfm is unused: the working state is the @ctx from zstd_alloc_ctx().
 */
static int zstd_sdecompress(struct crypto_scomp *tfm, const u8 *src,
			    unsigned int slen, u8 *dst, unsigned int *dlen,
			    void *ctx)
{
	struct zstd_ctx *zctx = ctx;

	return __zstd_decompress(src, slen, dst, dlen, zctx);
}
206*4882a593Smuzhiyun
207*4882a593Smuzhiyun static struct crypto_alg alg = {
208*4882a593Smuzhiyun .cra_name = "zstd",
209*4882a593Smuzhiyun .cra_driver_name = "zstd-generic",
210*4882a593Smuzhiyun .cra_flags = CRYPTO_ALG_TYPE_COMPRESS,
211*4882a593Smuzhiyun .cra_ctxsize = sizeof(struct zstd_ctx),
212*4882a593Smuzhiyun .cra_module = THIS_MODULE,
213*4882a593Smuzhiyun .cra_init = zstd_init,
214*4882a593Smuzhiyun .cra_exit = zstd_exit,
215*4882a593Smuzhiyun .cra_u = { .compress = {
216*4882a593Smuzhiyun .coa_compress = zstd_compress,
217*4882a593Smuzhiyun .coa_decompress = zstd_decompress } }
218*4882a593Smuzhiyun };
219*4882a593Smuzhiyun
220*4882a593Smuzhiyun static struct scomp_alg scomp = {
221*4882a593Smuzhiyun .alloc_ctx = zstd_alloc_ctx,
222*4882a593Smuzhiyun .free_ctx = zstd_free_ctx,
223*4882a593Smuzhiyun .compress = zstd_scompress,
224*4882a593Smuzhiyun .decompress = zstd_sdecompress,
225*4882a593Smuzhiyun .base = {
226*4882a593Smuzhiyun .cra_name = "zstd",
227*4882a593Smuzhiyun .cra_driver_name = "zstd-scomp",
228*4882a593Smuzhiyun .cra_module = THIS_MODULE,
229*4882a593Smuzhiyun }
230*4882a593Smuzhiyun };
231*4882a593Smuzhiyun
zstd_mod_init(void)232*4882a593Smuzhiyun static int __init zstd_mod_init(void)
233*4882a593Smuzhiyun {
234*4882a593Smuzhiyun int ret;
235*4882a593Smuzhiyun
236*4882a593Smuzhiyun ret = crypto_register_alg(&alg);
237*4882a593Smuzhiyun if (ret)
238*4882a593Smuzhiyun return ret;
239*4882a593Smuzhiyun
240*4882a593Smuzhiyun ret = crypto_register_scomp(&scomp);
241*4882a593Smuzhiyun if (ret)
242*4882a593Smuzhiyun crypto_unregister_alg(&alg);
243*4882a593Smuzhiyun
244*4882a593Smuzhiyun return ret;
245*4882a593Smuzhiyun }
246*4882a593Smuzhiyun
zstd_mod_fini(void)247*4882a593Smuzhiyun static void __exit zstd_mod_fini(void)
248*4882a593Smuzhiyun {
249*4882a593Smuzhiyun crypto_unregister_alg(&alg);
250*4882a593Smuzhiyun crypto_unregister_scomp(&scomp);
251*4882a593Smuzhiyun }
252*4882a593Smuzhiyun
253*4882a593Smuzhiyun subsys_initcall(zstd_mod_init);
254*4882a593Smuzhiyun module_exit(zstd_mod_fini);
255*4882a593Smuzhiyun
256*4882a593Smuzhiyun MODULE_LICENSE("GPL");
257*4882a593Smuzhiyun MODULE_DESCRIPTION("Zstd Compression Algorithm");
258*4882a593Smuzhiyun MODULE_ALIAS_CRYPTO("zstd");
259