forked from ClickHouse/ClickHouse
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patharrayDotProduct.cpp
More file actions
422 lines (348 loc) · 16.2 KB
/
arrayDotProduct.cpp
File metadata and controls
422 lines (348 loc) · 16.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
#include <Columns/ColumnArray.h>
#include <Columns/ColumnVector.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionBinaryArithmetic.h>
#include <Functions/FunctionFactory.h>
#include <Functions/IFunction.h>
#include <Functions/castTypeToEither.h>
#include <Interpreters/Context_fwd.h>
#if USE_MULTITARGET_CODE
#include <immintrin.h>
#endif
namespace DB
{
/// Error codes thrown by this translation unit; they are only declared here (extern).
namespace ErrorCodes
{
/// Thrown when an argument is not an Array or its nested type is not a supported number type.
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
/// Thrown by executeImpl() when the result type is not one of the supported number types.
extern const int LOGICAL_ERROR;
/// Thrown when the two array arguments differ in size.
extern const int SIZES_OF_ARRAYS_DONT_MATCH;
}
/// Kernel plugged into FunctionArrayScalarProduct: it names the SQL function, infers the result type and
/// supplies the per-row accumulation primitives (State / accumulate / combine / finalize).
struct DotProduct
{
    static constexpr auto name = "arrayDotProduct";

    /// Derives the result type from the nested element types of the two arrays.
    /// Float32 x Float32 stays Float32; any other matched pair uses
    /// NumberTraits::ResultOfAdditionMultiplication<LeftType, RightType>.
    /// Throws ILLEGAL_TYPE_OF_ARGUMENT if castTypeToEither() does not match both types against the list below.
    static DataTypePtr getReturnType(const DataTypePtr & left, const DataTypePtr & right)
    {
        using Types = TypeList<DataTypeFloat32, DataTypeFloat64,
                               DataTypeUInt8, DataTypeUInt16, DataTypeUInt32, DataTypeUInt64,
                               DataTypeInt8, DataTypeInt16, DataTypeInt32, DataTypeInt64>;
        Types types;

        DataTypePtr result_type;
        bool valid = castTypeToEither(types, left.get(), [&](const auto & left_)
        {
            return castTypeToEither(types, right.get(), [&](const auto & right_)
            {
                using LeftType = typename std::decay_t<decltype(left_)>::FieldType;
                using RightType = typename std::decay_t<decltype(right_)>::FieldType;
                using ResultType = typename NumberTraits::ResultOfAdditionMultiplication<LeftType, RightType>::Type;

                /// Special case: keep single precision when both inputs are single precision.
                if constexpr (std::is_same_v<LeftType, Float32> && std::is_same_v<RightType, Float32>)
                    result_type = std::make_shared<DataTypeFloat32>();
                else
                    result_type = std::make_shared<DataTypeFromFieldType<ResultType>>();
                return true;
            });
        });

        if (!valid)
            throw Exception(
                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                "Arguments of function {} only support: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.", name);

        return result_type;
    }

    /// Running sum of products for one row.
    template <typename Type>
    struct State
    {
        Type sum = 0;
    };

    /// Adds x * y to the running sum.
    /// NOTE(review): NO_SANITIZE_UNDEFINED presumably silences UBSan for integer overflow of the product/sum - confirm.
    template <typename Type>
    static NO_SANITIZE_UNDEFINED void accumulate(State<Type> & state, Type x, Type y)
    {
        state.sum += x * y;
    }

    /// Folds the partial sum of `other_state` into `state`.
    template <typename Type>
    static NO_SANITIZE_UNDEFINED void combine(State<Type> & state, const State<Type> & other_state)
    {
        state.sum += other_state.sum;
    }

#if USE_MULTITARGET_CODE
    /// AVX-512 kernel for Float32 (16 lanes) and Float64 (8 lanes).
    ///
    /// Multiplies data_x[i ...] by data_y[i ...] chunk-wise with fused multiply-add and advances `i` by the
    /// chunk width for every processed chunk. The loop condition is a strict `i + n < i_max`, so at least one
    /// element (up to one full chunk) is always left for the caller's scalar tail loop.
    /// `state.sum` is overwritten (not added to) with the horizontal sum of the vector accumulator.
    template <typename Type>
    AVX512_FUNCTION_SPECIFIC_ATTRIBUTE static void accumulateCombine(
        const Type * __restrict data_x,
        const Type * __restrict data_y,
        size_t i_max,
        size_t & i,
        State<Type> & state)
    {
        static_assert(std::is_same_v<Type, Float32> || std::is_same_v<Type, Float64>,
                      "accumulateCombine supports only Float32 and Float64");

        static constexpr bool is_float32 = std::is_same_v<Type, Float32>;

        /// Use the register type that matches the element type. Keeping double-precision data in `__m512`
        /// only compiles when the compiler permits implicit bit-casts between vector types.
        using Vector = std::conditional_t<is_float32, __m512, __m512d>;

        constexpr size_t n = is_float32 ? 16 : 8;

        Vector sums;
        if constexpr (is_float32)
            sums = _mm512_setzero_ps();
        else
            sums = _mm512_setzero_pd();

        for (; i + n < i_max; i += n)
        {
            if constexpr (is_float32)
            {
                const Vector x = _mm512_loadu_ps(data_x + i);
                const Vector y = _mm512_loadu_ps(data_y + i);
                sums = _mm512_fmadd_ps(x, y, sums);
            }
            else
            {
                const Vector x = _mm512_loadu_pd(data_x + i);
                const Vector y = _mm512_loadu_pd(data_y + i);
                sums = _mm512_fmadd_pd(x, y, sums);
            }
        }

        if constexpr (is_float32)
            state.sum = _mm512_reduce_add_ps(sums);
        else
            state.sum = _mm512_reduce_add_pd(sums);
    }
#endif

    /// Returns the accumulated dot product.
    template <typename Type>
    static Type finalize(const State<Type> & state)
    {
        return state.sum;
    }
};
/// The implementation is modeled after the implementation of distance functions arrayL1Distance, arrayL2Distance, etc.
/// The main difference is that arrayDotProduct() infers the result type differently.
///
/// Kernel supplies `name`, `getReturnType()`, `State`, `accumulate()`, `combine()`, `finalize()` and,
/// when USE_MULTITARGET_CODE is set, `accumulateCombine()`.
template <typename Kernel>
class FunctionArrayScalarProduct : public IFunction
{
public:
    static constexpr auto name = Kernel::name;

    String getName() const override { return name; }
    static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionArrayScalarProduct>(); }
    size_t getNumberOfArguments() const override { return 2; }
    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
    bool useDefaultImplementationForConstants() const override { return true; }

    /// Both arguments must be Array(native number); the result type is delegated to the kernel.
    /// Throws ILLEGAL_TYPE_OF_ARGUMENT otherwise.
    DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
    {
        std::array<DataTypePtr, 2> nested_types;
        for (size_t i = 0; i < 2; ++i)
        {
            const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(arguments[i].get());
            if (!array_type)
                throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                                "Arguments for function {} must be of type Array", getName());

            const auto & nested_type = array_type->getNestedType();
            if (!isNativeNumber(nested_type))
                throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                                "Function {} cannot process values of type {}", getName(), nested_type->getName());

            nested_types[i] = nested_type;
        }

        return Kernel::getReturnType(nested_types[0], nested_types[1]);
    }

#define SUPPORTED_TYPES(ACTION) \
    ACTION(UInt8) \
    ACTION(UInt16) \
    ACTION(UInt32) \
    ACTION(UInt64) \
    ACTION(Int8) \
    ACTION(Int16) \
    ACTION(Int32) \
    ACTION(Int64) \
    ACTION(Float32) \
    ACTION(Float64)

    /// First dispatch level: on the result type. The nested types of the left and right arrays are
    /// dispatched by the private helpers below.
    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override
    {
        switch (result_type->getTypeId())
        {
        #define ON_TYPE(type) \
            case TypeIndex::type: \
                return executeWithResultType<type>(arguments, input_rows_count);

            SUPPORTED_TYPES(ON_TYPE)
        #undef ON_TYPE

            default:
                throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected result type {}", result_type->getName());
        }
    }

private:
    /// Second dispatch level: on the nested type of the first array.
    template <typename ResultType>
    ColumnPtr executeWithResultType(const ColumnsWithTypeAndName & arguments, size_t input_rows_count) const
    {
        DataTypePtr type_x = typeid_cast<const DataTypeArray *>(arguments[0].type.get())->getNestedType();

        switch (type_x->getTypeId())
        {
        #define ON_TYPE(type) \
            case TypeIndex::type: \
                return executeWithResultTypeAndLeftType<ResultType, type>(arguments, input_rows_count);

            SUPPORTED_TYPES(ON_TYPE)
        #undef ON_TYPE

            default:
                throw Exception(
                    ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                    "Arguments of function {} has nested type {}. "
                    "Supported types: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.",
                    getName(),
                    type_x->getName());
        }
    }

    /// Third dispatch level: on the nested type of the second array.
    template <typename ResultType, typename LeftType>
    ColumnPtr executeWithResultTypeAndLeftType(const ColumnsWithTypeAndName & arguments, size_t input_rows_count) const
    {
        DataTypePtr type_y = typeid_cast<const DataTypeArray *>(arguments[1].type.get())->getNestedType();

        switch (type_y->getTypeId())
        {
        #define ON_TYPE(type) \
            case TypeIndex::type: \
                return executeWithResultTypeAndLeftTypeAndRightType<ResultType, LeftType, type>(arguments[0].column, arguments[1].column, input_rows_count);

            SUPPORTED_TYPES(ON_TYPE)
        #undef ON_TYPE

            default:
                throw Exception(
                    ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                    "Arguments of function {} has nested type {}. "
                    "Supported types: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.",
                    getName(),
                    type_y->getName());
        }
    }

    /// Computes one dot product per row. If either argument is a ColumnConst the work is forwarded to
    /// executeWithLeftArgConst() with the constant argument first (the product is symmetric, so swapping
    /// the operands and their types is fine). Throws SIZES_OF_ARRAYS_DONT_MATCH if the offsets differ.
    template <typename ResultType, typename LeftType, typename RightType>
    ColumnPtr executeWithResultTypeAndLeftTypeAndRightType(ColumnPtr col_x, ColumnPtr col_y, size_t input_rows_count) const
    {
        if (typeid_cast<const ColumnConst *>(col_x.get()))
            return executeWithLeftArgConst<ResultType, LeftType, RightType>(col_x, col_y, input_rows_count);

        if (typeid_cast<const ColumnConst *>(col_y.get()))
            return executeWithLeftArgConst<ResultType, RightType, LeftType>(col_y, col_x, input_rows_count);

        const auto & array_x = *assert_cast<const ColumnArray *>(col_x.get());
        const auto & array_y = *assert_cast<const ColumnArray *>(col_y.get());

        const auto & data_x = typeid_cast<const ColumnVector<LeftType> &>(array_x.getData()).getData();
        const auto & data_y = typeid_cast<const ColumnVector<RightType> &>(array_y.getData()).getData();

        const auto & offsets_x = array_x.getOffsets();

        if (!array_x.hasEqualOffsets(array_y))
            throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "Array arguments for function {} must have equal sizes", getName());

        auto col_res = ColumnVector<ResultType>::create(input_rows_count);
        auto & result_data = col_res->getData();

        /// Number of independent partial sums used by the unrolled loop.
        static constexpr size_t VEC_SIZE = 4;

        ColumnArray::Offset current_offset = 0;
        for (size_t row = 0; row < input_rows_count; ++row)
        {
            const size_t array_size = offsets_x[row] - current_offset;

            size_t i = 0;

            /// Process chunks in vectorized manner
            typename Kernel::template State<ResultType> states[VEC_SIZE];
            for (; i + VEC_SIZE < array_size; i += VEC_SIZE)
            {
                for (size_t j = 0; j < VEC_SIZE; ++j)
                    Kernel::template accumulate<ResultType>(states[j], static_cast<ResultType>(data_x[current_offset + i + j]), static_cast<ResultType>(data_y[current_offset + i + j]));
            }

            typename Kernel::template State<ResultType> state;
            for (const auto & other_state : states)
                Kernel::template combine<ResultType>(state, other_state);

            /// Process the tail
            for (; i < array_size; ++i)
                Kernel::template accumulate<ResultType>(state, static_cast<ResultType>(data_x[current_offset + i]), static_cast<ResultType>(data_y[current_offset + i]));

            result_data[row] = Kernel::template finalize<ResultType>(state);

            current_offset = offsets_x[row];
        }

        return col_res;
    }

    /// Variant for a constant first argument: `col_x` must be a ColumnConst wrapping an array, `col_y` is
    /// materialized if it is constant too. Every array in `col_y` must have the size of the constant array,
    /// otherwise SIZES_OF_ARRAYS_DONT_MATCH is thrown.
    template <typename ResultType, typename LeftType, typename RightType>
    ColumnPtr executeWithLeftArgConst(ColumnPtr col_x, ColumnPtr col_y, size_t input_rows_count) const
    {
        col_x = assert_cast<const ColumnConst *>(col_x.get())->getDataColumnPtr();
        col_y = col_y->convertToFullColumnIfConst();

        const auto & array_x = *assert_cast<const ColumnArray *>(col_x.get());
        const auto & array_y = *assert_cast<const ColumnArray *>(col_y.get());

        const auto & data_x = typeid_cast<const ColumnVector<LeftType> &>(array_x.getData()).getData();
        const auto & data_y = typeid_cast<const ColumnVector<RightType> &>(array_y.getData()).getData();

        const auto & offsets_x = array_x.getOffsets();
        const auto & offsets_y = array_y.getOffsets();

        ColumnArray::Offset prev_offset = 0;
        for (auto offset_y : offsets_y)
        {
            if (offsets_x[0] != offset_y - prev_offset) [[unlikely]]
            {
                throw Exception(
                    ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH,
                    "Arguments of function {} have different array sizes: {} and {}",
                    getName(),
                    offsets_x[0],
                    offset_y - prev_offset);
            }
            prev_offset = offset_y;
        }

        auto col_res = ColumnVector<ResultType>::create(input_rows_count);
        auto & result = col_res->getData();

        /// Number of independent partial sums used by the unrolled fallback loop.
        static constexpr size_t VEC_SIZE = 4;

        ColumnArray::Offset current_offset = 0;
        for (size_t row = 0; row < input_rows_count; ++row)
        {
            const size_t array_size = offsets_x[0];

            typename Kernel::template State<ResultType> state;
            size_t i = 0;

            /// Set once the AVX-512 kernel has consumed the bulk of the row.
            bool processed_by_simd = false;

            /// SIMD optimization: process multiple elements in both input arrays at once.
            /// To avoid combinatorial explosion of SIMD kernels, focus on
            /// - the two most common input/output types (Float32 x Float32) --> Float32 and (Float64 x Float64) --> Float64 instead of 10 x
            ///   10 input types x 8 output types,
            /// - const/non-const inputs instead of non-const/non-const inputs
            /// - the most powerful SIMD instruction set (AVX-512F).
#if USE_MULTITARGET_CODE
            if constexpr ((std::is_same_v<ResultType, Float32> || std::is_same_v<ResultType, Float64>)
                            && std::is_same_v<ResultType, LeftType> && std::is_same_v<LeftType, RightType>)
            {
                if (isArchSupported(TargetArch::AVX512F))
                {
                    Kernel::template accumulateCombine<ResultType>(&data_x[0], &data_y[current_offset], array_size, i, state);
                    processed_by_simd = true;
                }
            }
#endif

            /// Unrolled fallback. It used to be compiled only when USE_MULTITARGET_CODE was off, which left
            /// multitarget builds with just the sequential tail loop for every type combination or CPU the
            /// SIMD kernel does not cover. Run it whenever the SIMD kernel did not.
            if (!processed_by_simd)
            {
                /// Process chunks in vectorized manner
                typename Kernel::template State<ResultType> states[VEC_SIZE];
                for (; i + VEC_SIZE < array_size; i += VEC_SIZE)
                {
                    for (size_t j = 0; j < VEC_SIZE; ++j)
                        Kernel::template accumulate<ResultType>(states[j], static_cast<ResultType>(data_x[i + j]), static_cast<ResultType>(data_y[current_offset + i + j]));
                }

                for (const auto & other_state : states)
                    Kernel::template combine<ResultType>(state, other_state);
            }

            /// Process the tail
            for (; i < array_size; ++i)
                Kernel::template accumulate<ResultType>(state, static_cast<ResultType>(data_x[i]), static_cast<ResultType>(data_y[current_offset + i]));

            result[row] = Kernel::template finalize<ResultType>(state);

            current_offset = offsets_y[row];
        }

        return col_res;
    }
};
/// arrayDotProduct() is the generic scalar-product function instantiated with the DotProduct kernel.
using FunctionArrayDotProduct = FunctionArrayScalarProduct<DotProduct>;

/// Registers arrayDotProduct() in the function factory together with its user-facing documentation.
REGISTER_FUNCTION(ArrayDotProduct)
{
    /// Metadata that needs no prose.
    FunctionDocumentation::Category doc_category = FunctionDocumentation::Category::Array;
    FunctionDocumentation::IntroducedIn doc_introduced_in = {23, 5};
    FunctionDocumentation::Syntax doc_syntax = "arrayDotProduct(v1, v2)";

    /// Free-text description. The raw string is emitted verbatim, so its lines stay unindented.
    FunctionDocumentation::Description doc_description = R"(
Returns the dot product of two arrays.
:::note
The sizes of the two vectors must be equal. Arrays and Tuples may also contain mixed element types.
:::
)";

    FunctionDocumentation::Arguments doc_arguments = {
        {"v1", "First vector.", {"Array((U)Int* | Float* | Decimal)", "Tuple((U)Int* | Float* | Decimal)"}},
        {"v2", "Second vector.", {"Array((U)Int* | Float* | Decimal)", "Tuple((U)Int* | Float* | Decimal)"}},
    };

    FunctionDocumentation::ReturnedValue doc_returned_value = {R"(
The dot product of the two vectors.
:::note
The return type is determined by the type of the arguments. If Arrays or Tuples contain mixed element types then the result type is the supertype.
:::
)", {"(U)Int*", "Float*", "Decimal"}};

    FunctionDocumentation::Examples doc_examples = {
        {"Array example", "SELECT arrayDotProduct([1, 2, 3], [4, 5, 6]) AS res, toTypeName(res);", "32 UInt16"},
        {"Tuple example", "SELECT dotProduct((1::UInt16, 2::UInt8, 3::Float32),(4::Int16, 5::Float32, 6::UInt8)) AS res, toTypeName(res);", "32 Float64"}
    };

    /// Field order: description, syntax, arguments, returned value, examples, introduced-in, category.
    FunctionDocumentation full_documentation
        = {doc_description, doc_syntax, doc_arguments, doc_returned_value, doc_examples, doc_introduced_in, doc_category};

    factory.registerFunction<FunctionArrayDotProduct>(full_documentation);
}
/// Creates an arrayDotProduct() function object.
/// Per the original note, this is used by TupleOrArrayFunction in Function/vectorFunctions.cpp.
FunctionPtr createFunctionArrayDotProduct(ContextPtr context_)
{
    return FunctionArrayDotProduct::create(context_);
}
}