#include <mbgl/style/expression/match.hpp>
#include <mbgl/style/expression/check_subtype.hpp>
#include <mbgl/style/expression/parsing_context.hpp>
#include <mbgl/util/string.hpp>

#include <cmath>
#include <limits>
#include <map>
5 
6 namespace mbgl {
7 namespace style {
8 namespace expression {
9 
10 template <typename T>
eachChild(const std::function<void (const Expression &)> & visit) const11 void Match<T>::eachChild(const std::function<void(const Expression&)>& visit) const {
12     visit(*input);
13     for (const std::pair<T, std::shared_ptr<Expression>>& branch : branches) {
14         visit(*branch.second);
15     }
16     visit(*otherwise);
17 }
18 
19 template <typename T>
operator ==(const Expression & e) const20 bool Match<T>::operator==(const Expression& e) const {
21     if (e.getKind() == Kind::Match) {
22         auto rhs = static_cast<const Match*>(&e);
23         return (*input == *(rhs->input) &&
24                 *otherwise == *(rhs->otherwise) &&
25                 Expression::childrenEqual(branches, rhs->branches));
26     }
27     return false;
28 }
29 
30 template <typename T>
possibleOutputs() const31 std::vector<optional<Value>> Match<T>::possibleOutputs() const {
32     std::vector<optional<Value>> result;
33     for (const auto& branch : branches) {
34         for (auto& output : branch.second->possibleOutputs()) {
35             result.push_back(std::move(output));
36         }
37     }
38     for (auto& output : otherwise->possibleOutputs()) {
39         result.push_back(std::move(output));
40     }
41     return result;
42 }
43 
44 template <typename T>
serialize() const45 mbgl::Value Match<T>::serialize() const {
46     std::vector<mbgl::Value> serialized;
47     serialized.emplace_back(getOperator());
48     serialized.emplace_back(input->serialize());
49 
50     // Sort so serialization has an arbitrary defined order, even though branch order doesn't affect evaluation
51     std::map<T, std::shared_ptr<Expression>> sortedBranches(branches.begin(), branches.end());
52 
53     // Group branches by unique match expression to support condensed serializations
54     // of the form [case1, case2, ...] -> matchExpression
55     std::map<Expression*, size_t> outputLookup;
56     std::vector<std::pair<Expression*, std::vector<mbgl::Value>>> groupedByOutput;
57     for (auto& entry : sortedBranches) {
58         auto outputIndex = outputLookup.find(entry.second.get());
59         if (outputIndex == outputLookup.end()) {
60             // First time seeing this output, add it to the end of the grouped list
61             outputLookup[entry.second.get()] = groupedByOutput.size();
62             groupedByOutput.emplace_back(entry.second.get(), std::vector<mbgl::Value>{{entry.first}});
63         } else {
64             // We've seen this expression before, add the label to that output's group
65             groupedByOutput[outputIndex->second].second.emplace_back(entry.first);
66         }
67     };
68 
69     for (auto& entry : groupedByOutput) {
70         entry.second.size() == 1
71             ? serialized.emplace_back(entry.second[0])       // Only a single label matches this output expression
72             : serialized.emplace_back(entry.second);         // Array of literal labels pointing to this output expression
73         serialized.emplace_back(entry.first->serialize());   // The output expression itself
74     }
75 
76     serialized.emplace_back(otherwise->serialize());
77     return serialized;
78 }
79 
80 
evaluate(const EvaluationContext & params) const81 template<> EvaluationResult Match<std::string>::evaluate(const EvaluationContext& params) const {
82     const EvaluationResult inputValue = input->evaluate(params);
83     if (!inputValue) {
84         return inputValue.error();
85     }
86 
87     if (!inputValue->is<std::string>()) {
88         return otherwise->evaluate(params);
89     }
90 
91     auto it = branches.find(inputValue->get<std::string>());
92     if (it != branches.end()) {
93         return (*it).second->evaluate(params);
94     }
95 
96     return otherwise->evaluate(params);
97 }
98 
evaluate(const EvaluationContext & params) const99 template<> EvaluationResult Match<int64_t>::evaluate(const EvaluationContext& params) const {
100     const EvaluationResult inputValue = input->evaluate(params);
101     if (!inputValue) {
102         return inputValue.error();
103     }
104 
105     if (!inputValue->is<double>()) {
106         return otherwise->evaluate(params);
107     }
108 
109     const auto numeric = inputValue->get<double>();
110     int64_t rounded = std::floor(numeric);
111     if (numeric == rounded) {
112         auto it = branches.find(rounded);
113         if (it != branches.end()) {
114             return (*it).second->evaluate(params);
115         }
116     }
117 
118     return otherwise->evaluate(params);
119 }
120 
// Explicit instantiations for the two label types handled in this file:
// integer labels (matched against numeric input) and string labels.
template class Match<int64_t>;
template class Match<std::string>;

// A single parsed branch label, before it is converted to the label type
// (int64_t or std::string) of the Match<T> that create<T>() builds.
using InputType = variant<int64_t, std::string>;

// The parsing helpers below use Convertible, toValue, isArray, arrayLength and
// arrayMember unqualified; presumably they are declared in this namespace.
using namespace mbgl::style::conversion;
parseInputValue(const Convertible & input,ParsingContext & parentContext,std::size_t index,optional<type::Type> & inputType)127 optional<InputType> parseInputValue(const Convertible& input, ParsingContext& parentContext, std::size_t index, optional<type::Type>& inputType) {
128     using namespace mbgl::style::conversion;
129     optional<InputType> result;
130     optional<type::Type> type;
131 
132     auto value = toValue(input);
133 
134     if (value) {
135         value->match(
136             [&] (uint64_t n) {
137                 if (!Value::isSafeInteger(n)) {
138                     parentContext.error("Branch labels must be integers no larger than " + util::toString(Value::maxSafeInteger()) + ".", index);
139                 } else {
140                     type = {type::Number};
141                     result = optional<InputType>{static_cast<int64_t>(n)};
142                 }
143             },
144             [&] (int64_t n) {
145                 if (!Value::isSafeInteger(n)) {
146                     parentContext.error("Branch labels must be integers no larger than " + util::toString(Value::maxSafeInteger()) + ".", index);
147                 } else {
148                     type = {type::Number};
149                     result = optional<InputType>{n};
150                 }
151             },
152             [&] (double n) {
153                 if (!Value::isSafeInteger(n)) {
154                     parentContext.error("Branch labels must be integers no larger than " + util::toString(Value::maxSafeInteger()) + ".", index);
155                 } else if (n != std::floor(n)) {
156                     parentContext.error("Numeric branch labels must be integer values.", index);
157                 } else {
158                     type = {type::Number};
159                     result = optional<InputType>{static_cast<int64_t>(n)};
160                 }
161             },
162             [&] (const std::string& s) {
163                 type = {type::String};
164                 result = {s};
165             },
166             [&] (const auto&) {
167                 parentContext.error("Branch labels must be numbers or strings.", index);
168             }
169         );
170     } else {
171         parentContext.error("Branch labels must be numbers or strings.", index);
172     }
173 
174     if (!type) {
175         return result;
176     }
177 
178     if (!inputType) {
179         inputType = type;
180     } else {
181         optional<std::string> err = type::checkSubtype(*inputType, *type);
182         if (err) {
183             parentContext.error(*err, index);
184             return optional<InputType>();
185         }
186     }
187 
188     return result;
189 }
190 
191 template <typename T>
create(type::Type outputType,std::unique_ptr<Expression> input,std::vector<std::pair<std::vector<InputType>,std::unique_ptr<Expression>>> branches,std::unique_ptr<Expression> otherwise,ParsingContext & ctx)192 static ParseResult create(type::Type outputType,
193                           std::unique_ptr<Expression>input,
194                           std::vector<std::pair<std::vector<InputType>,
195                                                 std::unique_ptr<Expression>>> branches,
196                           std::unique_ptr<Expression> otherwise,
197                           ParsingContext& ctx) {
198     typename Match<T>::Branches typedBranches;
199 
200     std::size_t index = 2;
201 
202     typedBranches.reserve(branches.size());
203     for (std::pair<std::vector<InputType>,
204                    std::unique_ptr<Expression>>& pair : branches) {
205         std::shared_ptr<Expression> result = std::move(pair.second);
206         for (const InputType& label : pair.first) {
207             const auto& typedLabel = label.template get<T>();
208             if (typedBranches.find(typedLabel) != typedBranches.end()) {
209                 ctx.error("Branch labels must be unique.", index);
210                 return ParseResult();
211             }
212             typedBranches.emplace(typedLabel, result);
213         }
214 
215         index += 2;
216     }
217     return ParseResult(std::make_unique<Match<T>>(
218         outputType,
219         std::move(input),
220         std::move(typedBranches),
221         std::move(otherwise)
222     ));
223 }
224 
parseMatch(const Convertible & value,ParsingContext & ctx)225 ParseResult parseMatch(const Convertible& value, ParsingContext& ctx) {
226     assert(isArray(value));
227     auto length = arrayLength(value);
228     if (length < 5) {
229         ctx.error(
230             "Expected at least 4 arguments, but found only " + util::toString(length - 1) + "."
231         );
232         return ParseResult();
233     }
234 
235     // Expect odd-length array: ["match", input, 2 * (n pairs)..., otherwise]
236     if (length % 2 != 1) {
237         ctx.error("Expected an even number of arguments.");
238         return ParseResult();
239     }
240 
241     optional<type::Type> inputType;
242     optional<type::Type> outputType;
243     if (ctx.getExpected() && *ctx.getExpected() != type::Value) {
244         outputType = ctx.getExpected();
245     }
246 
247     std::vector<std::pair<std::vector<InputType>,
248                           std::unique_ptr<Expression>>> branches;
249 
250     branches.reserve((length - 3) / 2);
251     for (size_t i = 2; i + 1 < length; i += 2) {
252         const auto& label = arrayMember(value, i);
253 
254         std::vector<InputType> labels;
255         // Match pair inputs are provided as either a literal value or a
256         // raw JSON array of string / number / boolean values.
257         if (isArray(label)) {
258             auto groupLength = arrayLength(label);
259             if (groupLength == 0) {
260                 ctx.error("Expected at least one branch label.", i);
261                 return ParseResult();
262             }
263 
264             labels.reserve(groupLength);
265             for (size_t j = 0; j < groupLength; j++) {
266                 const optional<InputType> inputValue = parseInputValue(arrayMember(label, j), ctx, i, inputType);
267                 if (!inputValue) {
268                     return ParseResult();
269                 }
270                 labels.push_back(*inputValue);
271             }
272         } else {
273             const optional<InputType> inputValue = parseInputValue(label, ctx, i, inputType);
274             if (!inputValue) {
275                 return ParseResult();
276             }
277             labels.push_back(*inputValue);
278         }
279 
280         ParseResult output = ctx.parse(arrayMember(value, i + 1), i + 1, outputType);
281         if (!output) {
282             return ParseResult();
283         }
284 
285         if (!outputType) {
286             outputType = (*output)->getType();
287         }
288 
289         branches.push_back(std::make_pair(std::move(labels), std::move(*output)));
290     }
291 
292     auto input = ctx.parse(arrayMember(value, 1), 1, {type::Value});
293     if (!input) {
294         return ParseResult();
295     }
296 
297     auto otherwise = ctx.parse(arrayMember(value, length - 1), length - 1, outputType);
298     if (!otherwise) {
299         return ParseResult();
300     }
301 
302     assert(inputType && outputType);
303 
304     optional<std::string> err;
305     if ((*input)->getType() != type::Value && (err = type::checkSubtype(*inputType, (*input)->getType()))) {
306         ctx.error(*err, 1);
307         return ParseResult();
308     }
309 
310     return inputType->match(
311         [&](const type::NumberType&) {
312             return create<int64_t>(*outputType, std::move(*input), std::move(branches), std::move(*otherwise), ctx);
313         },
314         [&](const type::StringType&) {
315             return create<std::string>(*outputType, std::move(*input), std::move(branches), std::move(*otherwise), ctx);
316         },
317         [&](const auto&) {
318             // unreachable: inputType is set by parseInputValue(), which only
319             // accepts string and (integer) numeric values.
320             assert(false);
321             return ParseResult();
322         }
323     );
324 }
325 
326 } // namespace expression
327 } // namespace style
328 } // namespace mbgl
329