1 #include <mbgl/style/expression/match.hpp>
2 #include <mbgl/style/expression/check_subtype.hpp>
3 #include <mbgl/style/expression/parsing_context.hpp>
4 #include <mbgl/util/string.hpp>
5
6 namespace mbgl {
7 namespace style {
8 namespace expression {
9
10 template <typename T>
eachChild(const std::function<void (const Expression &)> & visit) const11 void Match<T>::eachChild(const std::function<void(const Expression&)>& visit) const {
12 visit(*input);
13 for (const std::pair<T, std::shared_ptr<Expression>>& branch : branches) {
14 visit(*branch.second);
15 }
16 visit(*otherwise);
17 }
18
19 template <typename T>
operator ==(const Expression & e) const20 bool Match<T>::operator==(const Expression& e) const {
21 if (e.getKind() == Kind::Match) {
22 auto rhs = static_cast<const Match*>(&e);
23 return (*input == *(rhs->input) &&
24 *otherwise == *(rhs->otherwise) &&
25 Expression::childrenEqual(branches, rhs->branches));
26 }
27 return false;
28 }
29
30 template <typename T>
possibleOutputs() const31 std::vector<optional<Value>> Match<T>::possibleOutputs() const {
32 std::vector<optional<Value>> result;
33 for (const auto& branch : branches) {
34 for (auto& output : branch.second->possibleOutputs()) {
35 result.push_back(std::move(output));
36 }
37 }
38 for (auto& output : otherwise->possibleOutputs()) {
39 result.push_back(std::move(output));
40 }
41 return result;
42 }
43
44 template <typename T>
serialize() const45 mbgl::Value Match<T>::serialize() const {
46 std::vector<mbgl::Value> serialized;
47 serialized.emplace_back(getOperator());
48 serialized.emplace_back(input->serialize());
49
50 // Sort so serialization has an arbitrary defined order, even though branch order doesn't affect evaluation
51 std::map<T, std::shared_ptr<Expression>> sortedBranches(branches.begin(), branches.end());
52
53 // Group branches by unique match expression to support condensed serializations
54 // of the form [case1, case2, ...] -> matchExpression
55 std::map<Expression*, size_t> outputLookup;
56 std::vector<std::pair<Expression*, std::vector<mbgl::Value>>> groupedByOutput;
57 for (auto& entry : sortedBranches) {
58 auto outputIndex = outputLookup.find(entry.second.get());
59 if (outputIndex == outputLookup.end()) {
60 // First time seeing this output, add it to the end of the grouped list
61 outputLookup[entry.second.get()] = groupedByOutput.size();
62 groupedByOutput.emplace_back(entry.second.get(), std::vector<mbgl::Value>{{entry.first}});
63 } else {
64 // We've seen this expression before, add the label to that output's group
65 groupedByOutput[outputIndex->second].second.emplace_back(entry.first);
66 }
67 };
68
69 for (auto& entry : groupedByOutput) {
70 entry.second.size() == 1
71 ? serialized.emplace_back(entry.second[0]) // Only a single label matches this output expression
72 : serialized.emplace_back(entry.second); // Array of literal labels pointing to this output expression
73 serialized.emplace_back(entry.first->serialize()); // The output expression itself
74 }
75
76 serialized.emplace_back(otherwise->serialize());
77 return serialized;
78 }
79
80
evaluate(const EvaluationContext & params) const81 template<> EvaluationResult Match<std::string>::evaluate(const EvaluationContext& params) const {
82 const EvaluationResult inputValue = input->evaluate(params);
83 if (!inputValue) {
84 return inputValue.error();
85 }
86
87 if (!inputValue->is<std::string>()) {
88 return otherwise->evaluate(params);
89 }
90
91 auto it = branches.find(inputValue->get<std::string>());
92 if (it != branches.end()) {
93 return (*it).second->evaluate(params);
94 }
95
96 return otherwise->evaluate(params);
97 }
98
evaluate(const EvaluationContext & params) const99 template<> EvaluationResult Match<int64_t>::evaluate(const EvaluationContext& params) const {
100 const EvaluationResult inputValue = input->evaluate(params);
101 if (!inputValue) {
102 return inputValue.error();
103 }
104
105 if (!inputValue->is<double>()) {
106 return otherwise->evaluate(params);
107 }
108
109 const auto numeric = inputValue->get<double>();
110 int64_t rounded = std::floor(numeric);
111 if (numeric == rounded) {
112 auto it = branches.find(rounded);
113 if (it != branches.end()) {
114 return (*it).second->evaluate(params);
115 }
116 }
117
118 return otherwise->evaluate(params);
119 }
120
121 template class Match<int64_t>;
122 template class Match<std::string>;
123
124 using InputType = variant<int64_t, std::string>;
125
126 using namespace mbgl::style::conversion;
parseInputValue(const Convertible & input,ParsingContext & parentContext,std::size_t index,optional<type::Type> & inputType)127 optional<InputType> parseInputValue(const Convertible& input, ParsingContext& parentContext, std::size_t index, optional<type::Type>& inputType) {
128 using namespace mbgl::style::conversion;
129 optional<InputType> result;
130 optional<type::Type> type;
131
132 auto value = toValue(input);
133
134 if (value) {
135 value->match(
136 [&] (uint64_t n) {
137 if (!Value::isSafeInteger(n)) {
138 parentContext.error("Branch labels must be integers no larger than " + util::toString(Value::maxSafeInteger()) + ".", index);
139 } else {
140 type = {type::Number};
141 result = optional<InputType>{static_cast<int64_t>(n)};
142 }
143 },
144 [&] (int64_t n) {
145 if (!Value::isSafeInteger(n)) {
146 parentContext.error("Branch labels must be integers no larger than " + util::toString(Value::maxSafeInteger()) + ".", index);
147 } else {
148 type = {type::Number};
149 result = optional<InputType>{n};
150 }
151 },
152 [&] (double n) {
153 if (!Value::isSafeInteger(n)) {
154 parentContext.error("Branch labels must be integers no larger than " + util::toString(Value::maxSafeInteger()) + ".", index);
155 } else if (n != std::floor(n)) {
156 parentContext.error("Numeric branch labels must be integer values.", index);
157 } else {
158 type = {type::Number};
159 result = optional<InputType>{static_cast<int64_t>(n)};
160 }
161 },
162 [&] (const std::string& s) {
163 type = {type::String};
164 result = {s};
165 },
166 [&] (const auto&) {
167 parentContext.error("Branch labels must be numbers or strings.", index);
168 }
169 );
170 } else {
171 parentContext.error("Branch labels must be numbers or strings.", index);
172 }
173
174 if (!type) {
175 return result;
176 }
177
178 if (!inputType) {
179 inputType = type;
180 } else {
181 optional<std::string> err = type::checkSubtype(*inputType, *type);
182 if (err) {
183 parentContext.error(*err, index);
184 return optional<InputType>();
185 }
186 }
187
188 return result;
189 }
190
191 template <typename T>
create(type::Type outputType,std::unique_ptr<Expression> input,std::vector<std::pair<std::vector<InputType>,std::unique_ptr<Expression>>> branches,std::unique_ptr<Expression> otherwise,ParsingContext & ctx)192 static ParseResult create(type::Type outputType,
193 std::unique_ptr<Expression>input,
194 std::vector<std::pair<std::vector<InputType>,
195 std::unique_ptr<Expression>>> branches,
196 std::unique_ptr<Expression> otherwise,
197 ParsingContext& ctx) {
198 typename Match<T>::Branches typedBranches;
199
200 std::size_t index = 2;
201
202 typedBranches.reserve(branches.size());
203 for (std::pair<std::vector<InputType>,
204 std::unique_ptr<Expression>>& pair : branches) {
205 std::shared_ptr<Expression> result = std::move(pair.second);
206 for (const InputType& label : pair.first) {
207 const auto& typedLabel = label.template get<T>();
208 if (typedBranches.find(typedLabel) != typedBranches.end()) {
209 ctx.error("Branch labels must be unique.", index);
210 return ParseResult();
211 }
212 typedBranches.emplace(typedLabel, result);
213 }
214
215 index += 2;
216 }
217 return ParseResult(std::make_unique<Match<T>>(
218 outputType,
219 std::move(input),
220 std::move(typedBranches),
221 std::move(otherwise)
222 ));
223 }
224
parseMatch(const Convertible & value,ParsingContext & ctx)225 ParseResult parseMatch(const Convertible& value, ParsingContext& ctx) {
226 assert(isArray(value));
227 auto length = arrayLength(value);
228 if (length < 5) {
229 ctx.error(
230 "Expected at least 4 arguments, but found only " + util::toString(length - 1) + "."
231 );
232 return ParseResult();
233 }
234
235 // Expect odd-length array: ["match", input, 2 * (n pairs)..., otherwise]
236 if (length % 2 != 1) {
237 ctx.error("Expected an even number of arguments.");
238 return ParseResult();
239 }
240
241 optional<type::Type> inputType;
242 optional<type::Type> outputType;
243 if (ctx.getExpected() && *ctx.getExpected() != type::Value) {
244 outputType = ctx.getExpected();
245 }
246
247 std::vector<std::pair<std::vector<InputType>,
248 std::unique_ptr<Expression>>> branches;
249
250 branches.reserve((length - 3) / 2);
251 for (size_t i = 2; i + 1 < length; i += 2) {
252 const auto& label = arrayMember(value, i);
253
254 std::vector<InputType> labels;
255 // Match pair inputs are provided as either a literal value or a
256 // raw JSON array of string / number / boolean values.
257 if (isArray(label)) {
258 auto groupLength = arrayLength(label);
259 if (groupLength == 0) {
260 ctx.error("Expected at least one branch label.", i);
261 return ParseResult();
262 }
263
264 labels.reserve(groupLength);
265 for (size_t j = 0; j < groupLength; j++) {
266 const optional<InputType> inputValue = parseInputValue(arrayMember(label, j), ctx, i, inputType);
267 if (!inputValue) {
268 return ParseResult();
269 }
270 labels.push_back(*inputValue);
271 }
272 } else {
273 const optional<InputType> inputValue = parseInputValue(label, ctx, i, inputType);
274 if (!inputValue) {
275 return ParseResult();
276 }
277 labels.push_back(*inputValue);
278 }
279
280 ParseResult output = ctx.parse(arrayMember(value, i + 1), i + 1, outputType);
281 if (!output) {
282 return ParseResult();
283 }
284
285 if (!outputType) {
286 outputType = (*output)->getType();
287 }
288
289 branches.push_back(std::make_pair(std::move(labels), std::move(*output)));
290 }
291
292 auto input = ctx.parse(arrayMember(value, 1), 1, {type::Value});
293 if (!input) {
294 return ParseResult();
295 }
296
297 auto otherwise = ctx.parse(arrayMember(value, length - 1), length - 1, outputType);
298 if (!otherwise) {
299 return ParseResult();
300 }
301
302 assert(inputType && outputType);
303
304 optional<std::string> err;
305 if ((*input)->getType() != type::Value && (err = type::checkSubtype(*inputType, (*input)->getType()))) {
306 ctx.error(*err, 1);
307 return ParseResult();
308 }
309
310 return inputType->match(
311 [&](const type::NumberType&) {
312 return create<int64_t>(*outputType, std::move(*input), std::move(branches), std::move(*otherwise), ctx);
313 },
314 [&](const type::StringType&) {
315 return create<std::string>(*outputType, std::move(*input), std::move(branches), std::move(*otherwise), ctx);
316 },
317 [&](const auto&) {
318 // unreachable: inputType is set by parseInputValue(), which only
319 // accepts string and (integer) numeric values.
320 assert(false);
321 return ParseResult();
322 }
323 );
324 }
325
326 } // namespace expression
327 } // namespace style
328 } // namespace mbgl
329