xref: /OK3568_Linux_fs/external/rknpu2/examples/3rdparty/cnpy/cnpy.cpp (revision 4882a59341e53eb6f0b4789bf948001014eff981)
1*4882a593Smuzhiyun // Copyright (C) 2011  Carl Rogers
2*4882a593Smuzhiyun // Released under MIT License
3*4882a593Smuzhiyun // license available in LICENSE file, or at http://www.opensource.org/licenses/mit-license.php
4*4882a593Smuzhiyun 
#include "cnpy.h"

#include <stdint.h>

#include <algorithm>
#include <cassert>
#include <complex>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <memory>
#include <regex>
#include <stdexcept>
16*4882a593Smuzhiyun 
BigEndianTest(int size)17*4882a593Smuzhiyun char cnpy::BigEndianTest(int size)
18*4882a593Smuzhiyun {
19*4882a593Smuzhiyun   if (size == 1)
20*4882a593Smuzhiyun     return '|';
21*4882a593Smuzhiyun   int x = 1;
22*4882a593Smuzhiyun   return (((char*)&x)[0]) ? '<' : '>';
23*4882a593Smuzhiyun }
24*4882a593Smuzhiyun 
map_type(const std::type_info & t)25*4882a593Smuzhiyun char cnpy::map_type(const std::type_info& t)
26*4882a593Smuzhiyun {
27*4882a593Smuzhiyun   if (t == typeid(float))
28*4882a593Smuzhiyun     return 'f';
29*4882a593Smuzhiyun   if (t == typeid(double))
30*4882a593Smuzhiyun     return 'f';
31*4882a593Smuzhiyun   if (t == typeid(long double))
32*4882a593Smuzhiyun     return 'f';
33*4882a593Smuzhiyun 
34*4882a593Smuzhiyun   if (t == typeid(int))
35*4882a593Smuzhiyun     return 'i';
36*4882a593Smuzhiyun   if (t == typeid(char))
37*4882a593Smuzhiyun     return 'i';
38*4882a593Smuzhiyun   if (t == typeid(signed char))
39*4882a593Smuzhiyun     return 'i';
40*4882a593Smuzhiyun   if (t == typeid(short))
41*4882a593Smuzhiyun     return 'i';
42*4882a593Smuzhiyun   if (t == typeid(long))
43*4882a593Smuzhiyun     return 'i';
44*4882a593Smuzhiyun   if (t == typeid(long long))
45*4882a593Smuzhiyun     return 'i';
46*4882a593Smuzhiyun 
47*4882a593Smuzhiyun   if (t == typeid(unsigned char))
48*4882a593Smuzhiyun     return 'u';
49*4882a593Smuzhiyun   if (t == typeid(unsigned short))
50*4882a593Smuzhiyun     return 'u';
51*4882a593Smuzhiyun   if (t == typeid(unsigned long))
52*4882a593Smuzhiyun     return 'u';
53*4882a593Smuzhiyun   if (t == typeid(unsigned long long))
54*4882a593Smuzhiyun     return 'u';
55*4882a593Smuzhiyun   if (t == typeid(unsigned int))
56*4882a593Smuzhiyun     return 'u';
57*4882a593Smuzhiyun 
58*4882a593Smuzhiyun   if (t == typeid(bool))
59*4882a593Smuzhiyun     return 'b';
60*4882a593Smuzhiyun 
61*4882a593Smuzhiyun   if (t == typeid(std::complex<float>))
62*4882a593Smuzhiyun     return 'c';
63*4882a593Smuzhiyun   if (t == typeid(std::complex<double>))
64*4882a593Smuzhiyun     return 'c';
65*4882a593Smuzhiyun   if (t == typeid(std::complex<long double>))
66*4882a593Smuzhiyun     return 'c';
67*4882a593Smuzhiyun 
68*4882a593Smuzhiyun   else
69*4882a593Smuzhiyun     return '?';
70*4882a593Smuzhiyun }
71*4882a593Smuzhiyun 
72*4882a593Smuzhiyun template <>
operator +=(std::vector<char> & lhs,const std::string rhs)73*4882a593Smuzhiyun std::vector<char>& cnpy::operator+=(std::vector<char>& lhs, const std::string rhs)
74*4882a593Smuzhiyun {
75*4882a593Smuzhiyun   lhs.insert(lhs.end(), rhs.begin(), rhs.end());
76*4882a593Smuzhiyun   return lhs;
77*4882a593Smuzhiyun }
78*4882a593Smuzhiyun 
79*4882a593Smuzhiyun template <>
operator +=(std::vector<char> & lhs,const char * rhs)80*4882a593Smuzhiyun std::vector<char>& cnpy::operator+=(std::vector<char>& lhs, const char* rhs)
81*4882a593Smuzhiyun {
82*4882a593Smuzhiyun   // write in little endian
83*4882a593Smuzhiyun   size_t len = strlen(rhs);
84*4882a593Smuzhiyun   lhs.reserve(len);
85*4882a593Smuzhiyun   for (size_t byte = 0; byte < len; byte++) {
86*4882a593Smuzhiyun     lhs.push_back(rhs[byte]);
87*4882a593Smuzhiyun   }
88*4882a593Smuzhiyun   return lhs;
89*4882a593Smuzhiyun }
90*4882a593Smuzhiyun 
parse_npy_header(unsigned char * buffer,size_t & word_size,std::vector<size_t> & shape,bool & fortran_order,std::string & typeName)91*4882a593Smuzhiyun void cnpy::parse_npy_header(unsigned char* buffer, size_t& word_size, std::vector<size_t>& shape, bool& fortran_order,
92*4882a593Smuzhiyun                             std::string& typeName)
93*4882a593Smuzhiyun {
94*4882a593Smuzhiyun   // std::string magic_string(buffer,6);
95*4882a593Smuzhiyun   uint8_t     major_version = *reinterpret_cast<uint8_t*>(buffer + 6);
96*4882a593Smuzhiyun   uint8_t     minor_version = *reinterpret_cast<uint8_t*>(buffer + 7);
97*4882a593Smuzhiyun   uint16_t    header_len    = *reinterpret_cast<uint16_t*>(buffer + 8);
98*4882a593Smuzhiyun   std::string header(reinterpret_cast<char*>(buffer + 9), header_len);
99*4882a593Smuzhiyun 
100*4882a593Smuzhiyun   size_t loc1, loc2;
101*4882a593Smuzhiyun 
102*4882a593Smuzhiyun   // fortran order
103*4882a593Smuzhiyun   loc1          = header.find("fortran_order") + 16;
104*4882a593Smuzhiyun   fortran_order = (header.substr(loc1, 4) == "True" ? true : false);
105*4882a593Smuzhiyun   if (fortran_order)
106*4882a593Smuzhiyun     throw std::runtime_error("npy input file: 'fortran_order' must be false, use: arr2 = np.ascontiguousarray(arr1)");
107*4882a593Smuzhiyun 
108*4882a593Smuzhiyun   // shape
109*4882a593Smuzhiyun   loc1 = header.find("(");
110*4882a593Smuzhiyun   loc2 = header.find(")");
111*4882a593Smuzhiyun 
112*4882a593Smuzhiyun   std::regex  num_regex("[0-9][0-9]*");
113*4882a593Smuzhiyun   std::smatch sm;
114*4882a593Smuzhiyun   shape.clear();
115*4882a593Smuzhiyun 
116*4882a593Smuzhiyun   std::string str_shape = header.substr(loc1 + 1, loc2 - loc1 - 1);
117*4882a593Smuzhiyun   while (std::regex_search(str_shape, sm, num_regex)) {
118*4882a593Smuzhiyun     shape.push_back(std::stoi(sm[0].str()));
119*4882a593Smuzhiyun     str_shape = sm.suffix().str();
120*4882a593Smuzhiyun   }
121*4882a593Smuzhiyun 
122*4882a593Smuzhiyun   // endian, word size, data type
123*4882a593Smuzhiyun   // byte order code | stands for not applicable.
124*4882a593Smuzhiyun   // not sure when this applies except for byte array
125*4882a593Smuzhiyun   loc1              = header.find("descr") + 9;
126*4882a593Smuzhiyun   bool littleEndian = (header[loc1] == '<' || header[loc1] == '|' ? true : false);
127*4882a593Smuzhiyun   assert(littleEndian);
128*4882a593Smuzhiyun 
129*4882a593Smuzhiyun   // char type = header[loc1+1];
130*4882a593Smuzhiyun   // assert(type == map_type(T));
131*4882a593Smuzhiyun 
132*4882a593Smuzhiyun   std::string str_ws = header.substr(loc1 + 2);
133*4882a593Smuzhiyun   loc2               = str_ws.find("'");
134*4882a593Smuzhiyun   word_size          = atoi(str_ws.substr(0, loc2).c_str());
135*4882a593Smuzhiyun   if (header.substr(loc1 + 1, 1) == "i") {
136*4882a593Smuzhiyun     typeName = "int";
137*4882a593Smuzhiyun   } else if (header.substr(loc1 + 1, 1) == "u") {
138*4882a593Smuzhiyun     typeName = "uint";
139*4882a593Smuzhiyun   } else if (header.substr(loc1 + 1, 1) == "f") {
140*4882a593Smuzhiyun     typeName = "float";
141*4882a593Smuzhiyun   }
142*4882a593Smuzhiyun   typeName = typeName + std::to_string(word_size * 8);
143*4882a593Smuzhiyun }
144*4882a593Smuzhiyun 
parse_npy_header(FILE * fp,size_t & word_size,std::vector<size_t> & shape,bool & fortran_order,std::string & typeName)145*4882a593Smuzhiyun void cnpy::parse_npy_header(FILE* fp, size_t& word_size, std::vector<size_t>& shape, bool& fortran_order,
146*4882a593Smuzhiyun                             std::string& typeName)
147*4882a593Smuzhiyun {
148*4882a593Smuzhiyun   char   buffer[256];
149*4882a593Smuzhiyun   size_t res = fread(buffer, sizeof(char), 11, fp);
150*4882a593Smuzhiyun   if (res != 11)
151*4882a593Smuzhiyun     throw std::runtime_error("parse_npy_header: failed fread");
152*4882a593Smuzhiyun   std::string header = fgets(buffer, 256, fp);
153*4882a593Smuzhiyun   assert(header[header.size() - 1] == '\n');
154*4882a593Smuzhiyun 
155*4882a593Smuzhiyun   size_t loc1, loc2;
156*4882a593Smuzhiyun 
157*4882a593Smuzhiyun   // fortran order
158*4882a593Smuzhiyun   loc1 = header.find("fortran_order");
159*4882a593Smuzhiyun   if (loc1 == std::string::npos)
160*4882a593Smuzhiyun     throw std::runtime_error("parse_npy_header: failed to find header keyword: 'fortran_order'");
161*4882a593Smuzhiyun   loc1 += 16;
162*4882a593Smuzhiyun   fortran_order = (header.substr(loc1, 4) == "True" ? true : false);
163*4882a593Smuzhiyun   if (fortran_order)
164*4882a593Smuzhiyun     throw std::runtime_error("npy input file: 'fortran_order' must be false, use: arr2 = np.ascontiguousarray(arr1)");
165*4882a593Smuzhiyun 
166*4882a593Smuzhiyun   // shape
167*4882a593Smuzhiyun   loc1 = header.find("(");
168*4882a593Smuzhiyun   loc2 = header.find(")");
169*4882a593Smuzhiyun   if (loc1 == std::string::npos || loc2 == std::string::npos)
170*4882a593Smuzhiyun     throw std::runtime_error("parse_npy_header: failed to find header keyword: '(' or ')'");
171*4882a593Smuzhiyun 
172*4882a593Smuzhiyun   std::regex  num_regex("[0-9][0-9]*");
173*4882a593Smuzhiyun   std::smatch sm;
174*4882a593Smuzhiyun   shape.clear();
175*4882a593Smuzhiyun 
176*4882a593Smuzhiyun   std::string str_shape = header.substr(loc1 + 1, loc2 - loc1 - 1);
177*4882a593Smuzhiyun   while (std::regex_search(str_shape, sm, num_regex)) {
178*4882a593Smuzhiyun     shape.push_back(std::stoi(sm[0].str()));
179*4882a593Smuzhiyun     str_shape = sm.suffix().str();
180*4882a593Smuzhiyun   }
181*4882a593Smuzhiyun 
182*4882a593Smuzhiyun   // endian, word size, data type
183*4882a593Smuzhiyun   // byte order code | stands for not applicable.
184*4882a593Smuzhiyun   // not sure when this applies except for byte array
185*4882a593Smuzhiyun   loc1 = header.find("descr");
186*4882a593Smuzhiyun   if (loc1 == std::string::npos)
187*4882a593Smuzhiyun     throw std::runtime_error("parse_npy_header: failed to find header keyword: 'descr'");
188*4882a593Smuzhiyun   loc1 += 9;
189*4882a593Smuzhiyun   bool littleEndian = (header[loc1] == '<' || header[loc1] == '|' ? true : false);
190*4882a593Smuzhiyun   assert(littleEndian);
191*4882a593Smuzhiyun 
192*4882a593Smuzhiyun   // char type = header[loc1+1];
193*4882a593Smuzhiyun   // assert(type == map_type(T));
194*4882a593Smuzhiyun 
195*4882a593Smuzhiyun   std::string str_ws = header.substr(loc1 + 2);
196*4882a593Smuzhiyun   loc2               = str_ws.find("'");
197*4882a593Smuzhiyun   word_size          = atoi(str_ws.substr(0, loc2).c_str());
198*4882a593Smuzhiyun   if (header.substr(loc1 + 1, 1) == "i") {
199*4882a593Smuzhiyun     typeName = "int";
200*4882a593Smuzhiyun   } else if (header.substr(loc1 + 1, 1) == "u") {
201*4882a593Smuzhiyun     typeName = "uint";
202*4882a593Smuzhiyun   } else if (header.substr(loc1 + 1, 1) == "f") {
203*4882a593Smuzhiyun     typeName = "float";
204*4882a593Smuzhiyun   }
205*4882a593Smuzhiyun   typeName = typeName + std::to_string(word_size * 8);
206*4882a593Smuzhiyun }
207*4882a593Smuzhiyun 
parse_zip_footer(FILE * fp,uint16_t & nrecs,size_t & global_header_size,size_t & global_header_offset)208*4882a593Smuzhiyun void cnpy::parse_zip_footer(FILE* fp, uint16_t& nrecs, size_t& global_header_size, size_t& global_header_offset)
209*4882a593Smuzhiyun {
210*4882a593Smuzhiyun   std::vector<char> footer(22);
211*4882a593Smuzhiyun   fseek(fp, -22, SEEK_END);
212*4882a593Smuzhiyun   size_t res = fread(&footer[0], sizeof(char), 22, fp);
213*4882a593Smuzhiyun   if (res != 22)
214*4882a593Smuzhiyun     throw std::runtime_error("parse_zip_footer: failed fread");
215*4882a593Smuzhiyun 
216*4882a593Smuzhiyun   uint16_t disk_no, disk_start, nrecs_on_disk, comment_len;
217*4882a593Smuzhiyun   disk_no              = *(uint16_t*)&footer[4];
218*4882a593Smuzhiyun   disk_start           = *(uint16_t*)&footer[6];
219*4882a593Smuzhiyun   nrecs_on_disk        = *(uint16_t*)&footer[8];
220*4882a593Smuzhiyun   nrecs                = *(uint16_t*)&footer[10];
221*4882a593Smuzhiyun   global_header_size   = *(uint32_t*)&footer[12];
222*4882a593Smuzhiyun   global_header_offset = *(uint32_t*)&footer[16];
223*4882a593Smuzhiyun   comment_len          = *(uint16_t*)&footer[20];
224*4882a593Smuzhiyun 
225*4882a593Smuzhiyun   assert(disk_no == 0);
226*4882a593Smuzhiyun   assert(disk_start == 0);
227*4882a593Smuzhiyun   assert(nrecs_on_disk == nrecs);
228*4882a593Smuzhiyun   assert(comment_len == 0);
229*4882a593Smuzhiyun }
230*4882a593Smuzhiyun 
load_the_npy_file(FILE * fp)231*4882a593Smuzhiyun cnpy::NpyArray load_the_npy_file(FILE* fp)
232*4882a593Smuzhiyun {
233*4882a593Smuzhiyun   std::vector<size_t> shape;
234*4882a593Smuzhiyun   size_t              word_size;
235*4882a593Smuzhiyun   std::string         typeName;
236*4882a593Smuzhiyun   bool                fortran_order;
237*4882a593Smuzhiyun   cnpy::parse_npy_header(fp, word_size, shape, fortran_order, typeName);
238*4882a593Smuzhiyun 
239*4882a593Smuzhiyun   cnpy::NpyArray arr(shape, word_size, fortran_order, typeName);
240*4882a593Smuzhiyun   size_t         nread = fread(arr.data<char>(), 1, arr.num_bytes(), fp);
241*4882a593Smuzhiyun   if (nread != arr.num_bytes())
242*4882a593Smuzhiyun     throw std::runtime_error("load_the_npy_file: failed fread");
243*4882a593Smuzhiyun   return arr;
244*4882a593Smuzhiyun }
245*4882a593Smuzhiyun 
load_the_npz_array(FILE * fp,uint32_t compr_bytes,uint32_t uncompr_bytes)246*4882a593Smuzhiyun cnpy::NpyArray load_the_npz_array(FILE* fp, uint32_t compr_bytes, uint32_t uncompr_bytes)
247*4882a593Smuzhiyun {
248*4882a593Smuzhiyun   std::vector<unsigned char> buffer_compr(compr_bytes);
249*4882a593Smuzhiyun   std::vector<unsigned char> buffer_uncompr(uncompr_bytes);
250*4882a593Smuzhiyun   size_t                     nread = fread(&buffer_compr[0], 1, compr_bytes, fp);
251*4882a593Smuzhiyun   if (nread != compr_bytes)
252*4882a593Smuzhiyun     throw std::runtime_error("load_the_npy_file: failed fread");
253*4882a593Smuzhiyun 
254*4882a593Smuzhiyun #if 0
255*4882a593Smuzhiyun   int      err;
256*4882a593Smuzhiyun   z_stream d_stream;
257*4882a593Smuzhiyun 
258*4882a593Smuzhiyun   d_stream.zalloc   = Z_NULL;
259*4882a593Smuzhiyun   d_stream.zfree    = Z_NULL;
260*4882a593Smuzhiyun   d_stream.opaque   = Z_NULL;
261*4882a593Smuzhiyun   d_stream.avail_in = 0;
262*4882a593Smuzhiyun   d_stream.next_in  = Z_NULL;
263*4882a593Smuzhiyun   err               = inflateInit2(&d_stream, -MAX_WBITS);
264*4882a593Smuzhiyun 
265*4882a593Smuzhiyun   d_stream.avail_in  = compr_bytes;
266*4882a593Smuzhiyun   d_stream.next_in   = &buffer_compr[0];
267*4882a593Smuzhiyun   d_stream.avail_out = uncompr_bytes;
268*4882a593Smuzhiyun   d_stream.next_out  = &buffer_uncompr[0];
269*4882a593Smuzhiyun 
270*4882a593Smuzhiyun   err = inflate(&d_stream, Z_FINISH);
271*4882a593Smuzhiyun   err = inflateEnd(&d_stream);
272*4882a593Smuzhiyun #endif
273*4882a593Smuzhiyun 
274*4882a593Smuzhiyun   std::vector<size_t> shape;
275*4882a593Smuzhiyun   size_t              word_size;
276*4882a593Smuzhiyun   bool                fortran_order;
277*4882a593Smuzhiyun   std::string         typeName;
278*4882a593Smuzhiyun   cnpy::parse_npy_header(&buffer_uncompr[0], word_size, shape, fortran_order, typeName);
279*4882a593Smuzhiyun 
280*4882a593Smuzhiyun   cnpy::NpyArray array(shape, word_size, fortran_order, typeName);
281*4882a593Smuzhiyun 
282*4882a593Smuzhiyun   size_t offset = uncompr_bytes - array.num_bytes();
283*4882a593Smuzhiyun   memcpy(array.data<unsigned char>(), &buffer_uncompr[0] + offset, array.num_bytes());
284*4882a593Smuzhiyun 
285*4882a593Smuzhiyun   return array;
286*4882a593Smuzhiyun }
287*4882a593Smuzhiyun 
npz_load(std::string fname)288*4882a593Smuzhiyun cnpy::npz_t cnpy::npz_load(std::string fname)
289*4882a593Smuzhiyun {
290*4882a593Smuzhiyun   FILE* fp = fopen(fname.c_str(), "rb");
291*4882a593Smuzhiyun 
292*4882a593Smuzhiyun   if (!fp) {
293*4882a593Smuzhiyun     throw std::runtime_error("npz_load: Error! Unable to open file " + fname + "!");
294*4882a593Smuzhiyun   }
295*4882a593Smuzhiyun 
296*4882a593Smuzhiyun   cnpy::npz_t arrays;
297*4882a593Smuzhiyun 
298*4882a593Smuzhiyun   while (1) {
299*4882a593Smuzhiyun     std::vector<char> local_header(30);
300*4882a593Smuzhiyun     size_t            headerres = fread(&local_header[0], sizeof(char), 30, fp);
301*4882a593Smuzhiyun     if (headerres != 30)
302*4882a593Smuzhiyun       throw std::runtime_error("npz_load: failed fread");
303*4882a593Smuzhiyun 
304*4882a593Smuzhiyun     // if we've reached the global header, stop reading
305*4882a593Smuzhiyun     if (local_header[2] != 0x03 || local_header[3] != 0x04)
306*4882a593Smuzhiyun       break;
307*4882a593Smuzhiyun 
308*4882a593Smuzhiyun     // read in the variable name
309*4882a593Smuzhiyun     uint16_t    name_len = *(uint16_t*)&local_header[26];
310*4882a593Smuzhiyun     std::string varname(name_len, ' ');
311*4882a593Smuzhiyun     size_t      vname_res = fread(&varname[0], sizeof(char), name_len, fp);
312*4882a593Smuzhiyun     if (vname_res != name_len)
313*4882a593Smuzhiyun       throw std::runtime_error("npz_load: failed fread");
314*4882a593Smuzhiyun 
315*4882a593Smuzhiyun     // erase the lagging .npy
316*4882a593Smuzhiyun     varname.erase(varname.end() - 4, varname.end());
317*4882a593Smuzhiyun 
318*4882a593Smuzhiyun     // read in the extra field
319*4882a593Smuzhiyun     uint16_t extra_field_len = *(uint16_t*)&local_header[28];
320*4882a593Smuzhiyun     if (extra_field_len > 0) {
321*4882a593Smuzhiyun       std::vector<char> buff(extra_field_len);
322*4882a593Smuzhiyun       size_t            efield_res = fread(&buff[0], sizeof(char), extra_field_len, fp);
323*4882a593Smuzhiyun       if (efield_res != extra_field_len)
324*4882a593Smuzhiyun         throw std::runtime_error("npz_load: failed fread");
325*4882a593Smuzhiyun     }
326*4882a593Smuzhiyun 
327*4882a593Smuzhiyun     uint16_t compr_method  = *reinterpret_cast<uint16_t*>(&local_header[0] + 8);
328*4882a593Smuzhiyun     uint32_t compr_bytes   = *reinterpret_cast<uint32_t*>(&local_header[0] + 18);
329*4882a593Smuzhiyun     uint32_t uncompr_bytes = *reinterpret_cast<uint32_t*>(&local_header[0] + 22);
330*4882a593Smuzhiyun 
331*4882a593Smuzhiyun     if (compr_method == 0) {
332*4882a593Smuzhiyun       arrays[varname] = load_the_npy_file(fp);
333*4882a593Smuzhiyun     } else {
334*4882a593Smuzhiyun       arrays[varname] = load_the_npz_array(fp, compr_bytes, uncompr_bytes);
335*4882a593Smuzhiyun     }
336*4882a593Smuzhiyun   }
337*4882a593Smuzhiyun 
338*4882a593Smuzhiyun   fclose(fp);
339*4882a593Smuzhiyun   return arrays;
340*4882a593Smuzhiyun }
341*4882a593Smuzhiyun 
npz_load(std::string fname,std::string varname)342*4882a593Smuzhiyun cnpy::NpyArray cnpy::npz_load(std::string fname, std::string varname)
343*4882a593Smuzhiyun {
344*4882a593Smuzhiyun   FILE* fp = fopen(fname.c_str(), "rb");
345*4882a593Smuzhiyun 
346*4882a593Smuzhiyun   if (!fp)
347*4882a593Smuzhiyun     throw std::runtime_error("npz_load: Unable to open file " + fname);
348*4882a593Smuzhiyun 
349*4882a593Smuzhiyun   while (1) {
350*4882a593Smuzhiyun     std::vector<char> local_header(30);
351*4882a593Smuzhiyun     size_t            header_res = fread(&local_header[0], sizeof(char), 30, fp);
352*4882a593Smuzhiyun     if (header_res != 30)
353*4882a593Smuzhiyun       throw std::runtime_error("npz_load: failed fread");
354*4882a593Smuzhiyun 
355*4882a593Smuzhiyun     // if we've reached the global header, stop reading
356*4882a593Smuzhiyun     if (local_header[2] != 0x03 || local_header[3] != 0x04)
357*4882a593Smuzhiyun       break;
358*4882a593Smuzhiyun 
359*4882a593Smuzhiyun     // read in the variable name
360*4882a593Smuzhiyun     uint16_t    name_len = *(uint16_t*)&local_header[26];
361*4882a593Smuzhiyun     std::string vname(name_len, ' ');
362*4882a593Smuzhiyun     size_t      vname_res = fread(&vname[0], sizeof(char), name_len, fp);
363*4882a593Smuzhiyun     if (vname_res != name_len)
364*4882a593Smuzhiyun       throw std::runtime_error("npz_load: failed fread");
365*4882a593Smuzhiyun     vname.erase(vname.end() - 4, vname.end()); // erase the lagging .npy
366*4882a593Smuzhiyun 
367*4882a593Smuzhiyun     // read in the extra field
368*4882a593Smuzhiyun     uint16_t extra_field_len = *(uint16_t*)&local_header[28];
369*4882a593Smuzhiyun     fseek(fp, extra_field_len, SEEK_CUR); // skip past the extra field
370*4882a593Smuzhiyun 
371*4882a593Smuzhiyun     uint16_t compr_method  = *reinterpret_cast<uint16_t*>(&local_header[0] + 8);
372*4882a593Smuzhiyun     uint32_t compr_bytes   = *reinterpret_cast<uint32_t*>(&local_header[0] + 18);
373*4882a593Smuzhiyun     uint32_t uncompr_bytes = *reinterpret_cast<uint32_t*>(&local_header[0] + 22);
374*4882a593Smuzhiyun 
375*4882a593Smuzhiyun     if (vname == varname) {
376*4882a593Smuzhiyun       NpyArray array = (compr_method == 0) ? load_the_npy_file(fp) : load_the_npz_array(fp, compr_bytes, uncompr_bytes);
377*4882a593Smuzhiyun       fclose(fp);
378*4882a593Smuzhiyun       return array;
379*4882a593Smuzhiyun     } else {
380*4882a593Smuzhiyun       // skip past the data
381*4882a593Smuzhiyun       uint32_t size = *(uint32_t*)&local_header[22];
382*4882a593Smuzhiyun       fseek(fp, size, SEEK_CUR);
383*4882a593Smuzhiyun     }
384*4882a593Smuzhiyun   }
385*4882a593Smuzhiyun 
386*4882a593Smuzhiyun   fclose(fp);
387*4882a593Smuzhiyun 
388*4882a593Smuzhiyun   // if we get here, we haven't found the variable in the file
389*4882a593Smuzhiyun   throw std::runtime_error("npz_load: Variable name " + varname + " not found in " + fname);
390*4882a593Smuzhiyun }
391*4882a593Smuzhiyun 
npy_load(std::string fname)392*4882a593Smuzhiyun cnpy::NpyArray cnpy::npy_load(std::string fname)
393*4882a593Smuzhiyun {
394*4882a593Smuzhiyun   FILE* fp = fopen(fname.c_str(), "rb");
395*4882a593Smuzhiyun 
396*4882a593Smuzhiyun   if (!fp)
397*4882a593Smuzhiyun     throw std::runtime_error("npy_load: Unable to open file " + fname);
398*4882a593Smuzhiyun 
399*4882a593Smuzhiyun   NpyArray arr = load_the_npy_file(fp);
400*4882a593Smuzhiyun 
401*4882a593Smuzhiyun   fclose(fp);
402*4882a593Smuzhiyun   return arr;
403*4882a593Smuzhiyun }
404