Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions cpp/src/gandiva/annotator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,25 @@ void Annotator::PrepareBuffersForField(const FieldDescriptor& desc,
}
}

// Validates the columns of `record_batch` against the input fields registered
// in in_name_to_desc_.
//
// A column whose name has no entry in in_name_to_desc_ is ignored. For every
// other column, the type id of the column must equal the type id of the
// registered descriptor; only the ids are compared here.
//
// Returns Status::OK() when every checked column agrees, otherwise an
// ExecutionError describing the first mismatching field.
const Status Annotator::CheckEvalBatchFieldType(
    const arrow::RecordBatch& record_batch) const {
  for (int col = 0; col < record_batch.num_columns(); ++col) {
    const std::string& column_name = record_batch.column_name(col);
    const auto entry = in_name_to_desc_.find(column_name);
    if (entry != in_name_to_desc_.end()) {
      // This column is referenced by the expression, so its type must match.
      const auto& expected_type = entry->second->Type();
      const auto actual_column = record_batch.column(col);
      if (actual_column->type_id() != expected_type->id()) {
        return Status::ExecutionError("Expect field ", column_name, " type is ",
                                      expected_type->ToString(), ", input field ",
                                      column_name, " type is ",
                                      actual_column->type()->ToString());
      }
    }
  }
  return Status::OK();
}

EvalBatchPtr Annotator::PrepareEvalBatch(const arrow::RecordBatch& record_batch,
const ArrayDataVector& out_vector) const {
EvalBatchPtr eval_batch = std::make_shared<EvalBatch>(
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/gandiva/annotator.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class GANDIVA_EXPORT Annotator {
EvalBatchPtr PrepareEvalBatch(const arrow::RecordBatch& record_batch,
const ArrayDataVector& out_vector) const;

const Status CheckEvalBatchFieldType(const arrow::RecordBatch& record_batch) const;

// Returns the current value of buffer_count_.
int buffer_count() const { return buffer_count_; }

private:
Expand Down
14 changes: 0 additions & 14 deletions cpp/src/gandiva/expr_validator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,20 +72,6 @@ Status ExprValidator::Visit(const FieldNode& node) {
Status::ExpressionValidationError("Field ", node.field()->name(),
" has unsupported data type ",
node.return_type()->name()));

// Ensure that field is found in schema
auto field_in_schema_entry = field_map_.find(node.field()->name());
ARROW_RETURN_IF(field_in_schema_entry == field_map_.end(),
Status::ExpressionValidationError("Field ", node.field()->name(),
" not in schema."));

// Ensure that the found field matches.
FieldPtr field_in_schema = field_in_schema_entry->second;
ARROW_RETURN_IF(!field_in_schema->Equals(node.field()),
Status::ExpressionValidationError(
"Field definition in schema ", field_in_schema->ToString(),
" different from field in expression ", node.field()->ToString()));

return Status::OK();
}

Expand Down
2 changes: 2 additions & 0 deletions cpp/src/gandiva/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,6 @@ namespace gandiva {

// Renders this expression as text by delegating to ToString() of the root node.
std::string Expression::ToString() {
  const auto& root_node = root();
  return root_node->ToString();
}

// Produces the cache-key text of this expression by delegating to
// ToCacheKeyString() of the root node.
std::string Expression::ToCacheKeyString() {
  const auto& root_node = root();
  return root_node->ToCacheKeyString();
}

} // namespace gandiva
2 changes: 2 additions & 0 deletions cpp/src/gandiva/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class GANDIVA_EXPORT Expression {

std::string ToString();

std::string ToCacheKeyString();

private:
const NodePtr root_;
const FieldPtr result_;
Expand Down
46 changes: 18 additions & 28 deletions cpp/src/gandiva/expression_cache_key.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,45 +34,40 @@ class ExpressionCacheKey {
public:
ExpressionCacheKey(SchemaPtr schema, std::shared_ptr<Configuration> configuration,
ExpressionVector expression_vector, SelectionVector::Mode mode)
: schema_(schema), mode_(mode), uniquifier_(0), configuration_(configuration) {
: mode_(mode), uniqifier_(0), configuration_(configuration) {
static const int kSeedValue = 4;
size_t result = kSeedValue;
for (auto& expr : expression_vector) {
std::string expr_as_string = expr->ToString();
expressions_as_strings_.push_back(expr_as_string);
arrow::internal::hash_combine(result, expr_as_string);
UpdateUniquifier(expr_as_string);
std::string expr_cache_key_string = expr->ToCacheKeyString();
expressions_as_cache_key_strings_.push_back(expr_cache_key_string);
arrow::internal::hash_combine(result, expr_cache_key_string);
UpdateUniqifier(expr_cache_key_string);
}
arrow::internal::hash_combine(result, static_cast<size_t>(mode));
arrow::internal::hash_combine(result, configuration->Hash());
arrow::internal::hash_combine(result, schema_->ToString());
arrow::internal::hash_combine(result, uniquifier_);
arrow::internal::hash_combine(result, uniqifier_);
hash_code_ = result;
}

ExpressionCacheKey(SchemaPtr schema, std::shared_ptr<Configuration> configuration,
Expression& expression)
: schema_(schema),
mode_(SelectionVector::MODE_NONE),
uniquifier_(0),
configuration_(configuration) {
: mode_(SelectionVector::MODE_NONE), uniqifier_(0), configuration_(configuration) {
static const int kSeedValue = 4;
size_t result = kSeedValue;
expressions_as_strings_.push_back(expression.ToString());
UpdateUniquifier(expression.ToString());

expressions_as_cache_key_strings_.push_back(expression.ToCacheKeyString());
UpdateUniqifier(expression.ToCacheKeyString());
arrow::internal::hash_combine(result, expression.ToCacheKeyString());
arrow::internal::hash_combine(result, configuration->Hash());
arrow::internal::hash_combine(result, schema_->ToString());
arrow::internal::hash_combine(result, uniquifier_);
arrow::internal::hash_combine(result, uniqifier_);
hash_code_ = result;
}

void UpdateUniquifier(const std::string& expr) {
if (uniquifier_ == 0) {
void UpdateUniqifier(const std::string& expr) {
if (uniqifier_ == 0) {
// caching of expressions with re2 patterns causes lock contention. So, use
// multiple instances to reduce contention.
if (expr.find(" like(") != std::string::npos) {
uniquifier_ = std::hash<std::thread::id>()(std::this_thread::get_id()) % 16;
uniqifier_ = std::hash<std::thread::id>()(std::this_thread::get_id()) % 16;
}
}
}
Expand All @@ -84,10 +79,6 @@ class ExpressionCacheKey {
return false;
}

if (!(schema_->Equals(*other.schema_, true))) {
return false;
}

if (configuration_ != other.configuration_) {
return false;
}
Expand All @@ -96,11 +87,11 @@ class ExpressionCacheKey {
return false;
}

if (expressions_as_strings_ != other.expressions_as_strings_) {
if (expressions_as_cache_key_strings_ != other.expressions_as_cache_key_strings_) {
return false;
}

if (uniquifier_ != other.uniquifier_) {
if (uniqifier_ != other.uniqifier_) {
return false;
}

Expand All @@ -111,10 +102,9 @@ class ExpressionCacheKey {

private:
size_t hash_code_;
SchemaPtr schema_;
std::vector<std::string> expressions_as_strings_;
std::vector<std::string> expressions_as_cache_key_strings_;
SelectionVector::Mode mode_;
uint32_t uniquifier_;
uint32_t uniqifier_;
std::shared_ptr<Configuration> configuration_;
};

Expand Down
2 changes: 0 additions & 2 deletions cpp/src/gandiva/filter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,6 @@ Status Filter::Make(SchemaPtr schema, ConditionPtr condition,
Status Filter::Evaluate(const arrow::RecordBatch& batch,
std::shared_ptr<SelectionVector> out_selection) {
const auto num_rows = batch.num_rows();
ARROW_RETURN_IF(!batch.schema()->Equals(*schema_),
Status::Invalid("RecordBatch schema must expected filter schema"));
ARROW_RETURN_IF(num_rows == 0, Status::Invalid("RecordBatch must be non-empty."));
ARROW_RETURN_IF(out_selection == nullptr,
Status::Invalid("out_selection must be non-null."));
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/gandiva/llvm_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ Status LLVMGenerator::Execute(const arrow::RecordBatch& record_batch,
const ArrayDataVector& output_vector) const {
DCHECK_GT(record_batch.num_rows(), 0);

auto status = annotator_.CheckEvalBatchFieldType(record_batch);

ARROW_RETURN_IF(!status.ok(), status);

auto eval_batch = annotator_.PrepareEvalBatch(record_batch, output_vector);
DCHECK_GT(eval_batch->GetNumBuffers(), 0);

Expand Down
83 changes: 83 additions & 0 deletions cpp/src/gandiva/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class GANDIVA_EXPORT Node {

virtual std::string ToString() const = 0;

virtual std::string ToCacheKeyString() const = 0;

protected:
DataTypePtr return_type_;
};
Expand Down Expand Up @@ -99,6 +101,8 @@ class GANDIVA_EXPORT LiteralNode : public Node {
return ss.str();
}

std::string ToCacheKeyString() const override { return ToString(); }

private:
LiteralHolder holder_;
bool is_null_;
Expand All @@ -117,6 +121,10 @@ class GANDIVA_EXPORT FieldNode : public Node {
return "(" + field()->type()->ToString() + ") " + field()->name();
}

// Cache-key form of a field reference: "(<type>) " (note the trailing space).
// Unlike ToString(), the field name is not appended.
// NOTE(review): since the name is left out, two different fields of the same
// type produce the same string here — confirm that users of the key can
// tolerate that.
std::string ToCacheKeyString() const override {
  std::string key("(");
  key += field()->type()->ToString();
  key += ") ";
  return key;
}

private:
FieldPtr field_;
};
Expand Down Expand Up @@ -149,6 +157,24 @@ class GANDIVA_EXPORT FunctionNode : public Node {
return ss.str();
}

// Cache-key form of a function call:
//   "<return type> <function name>(<child key>, <child key>, ...)"
// where "untyped" stands in for the return type when return_type() is null,
// and each child contributes its own ToCacheKeyString().
std::string ToCacheKeyString() const override {
  std::stringstream out;
  if (return_type() == NULLPTR) {
    out << "untyped";
  } else {
    out << descriptor()->return_type()->ToString();
  }
  out << " " << descriptor()->name() << "(";

  // Empty before the first child, ", " before every later one.
  const char* separator = "";
  for (auto& child : children()) {
    out << separator << child->ToCacheKeyString();
    separator = ", ";
  }
  out << ")";
  return out.str();
}

private:
FuncDescriptorPtr descriptor_;
NodeVector children_;
Expand Down Expand Up @@ -188,6 +214,14 @@ class GANDIVA_EXPORT IfNode : public Node {
return ss.str();
}

std::string ToCacheKeyString() const override {
std::stringstream ss;
ss << "if (" << condition()->ToCacheKeyString() << ") { ";
ss << then_node()->ToCacheKeyString() << " } else { ";
ss << else_node()->ToCacheKeyString() << " }";
return ss.str();
}

private:
NodePtr condition_;
NodePtr then_node_;
Expand Down Expand Up @@ -225,6 +259,23 @@ class GANDIVA_EXPORT BooleanNode : public Node {
return ss.str();
}

// Cache-key form of a boolean node: the cache-key strings of all children
// joined by " && " when the node is an AND, and by " || " otherwise.
// Returns an empty string when there are no children.
std::string ToCacheKeyString() const override {
  const char* joiner = (expr_type() == BooleanNode::AND) ? " && " : " || ";
  std::stringstream out;

  // Nothing is written before the first child; the joiner precedes the rest.
  const char* pending = "";
  for (const auto& child : children_) {
    out << pending << child->ToCacheKeyString();
    pending = joiner;
  }
  return out.str();
}

private:
ExprType expr_type_;
NodeVector children_;
Expand Down Expand Up @@ -265,6 +316,22 @@ class InExpressionNode : public Node {
return ss.str();
}

// Cache-key form of an IN expression:
//   "<eval_expr key> IN (<v1>, <v2>, ...)"
// Each value is streamed with operator<<; no type prefix is written.
// NOTE(review): values_ is an unordered_set, so the values appear in the
// set's iteration order — confirm that equal sets always yield the same key.
std::string ToCacheKeyString() const override {
  std::stringstream out;
  out << eval_expr_->ToCacheKeyString() << " IN (";

  // Empty before the first value, ", " before every later one.
  const char* separator = "";
  for (const auto& value : values_) {
    out << separator << value;
    separator = ", ";
  }
  out << ")";
  return out.str();
}

private:
NodePtr eval_expr_;
std::unordered_set<Type> values_;
Expand Down Expand Up @@ -309,6 +376,22 @@ class InExpressionNode<gandiva::DecimalScalar128> : public Node {
return ss.str();
}

// Cache-key form of a decimal IN expression:
//   "<eval_expr key> IN (<v1>, <v2>, ...)"
// Each DecimalScalar128 value is streamed with operator<<; no type prefix is
// written.
// NOTE(review): values_ is an unordered_set, so the values appear in the
// set's iteration order — confirm that equal sets always yield the same key.
std::string ToCacheKeyString() const override {
  std::stringstream out;
  out << eval_expr_->ToCacheKeyString() << " IN (";

  // Empty before the first value, ", " before every later one.
  const char* separator = "";
  for (const auto& value : values_) {
    out << separator << value;
    separator = ", ";
  }
  out << ")";
  return out.str();
}

private:
NodePtr eval_expr_;
std::unordered_set<gandiva::DecimalScalar128> values_;
Expand Down
2 changes: 0 additions & 2 deletions cpp/src/gandiva/projector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,6 @@ Status Projector::AllocArrayData(const DataTypePtr& type, int64_t num_records,
}

Status Projector::ValidateEvaluateArgsCommon(const arrow::RecordBatch& batch) const {
ARROW_RETURN_IF(!batch.schema()->Equals(*schema_),
Status::Invalid("Schema in RecordBatch must match schema in Make()"));
ARROW_RETURN_IF(batch.num_rows() == 0,
Status::Invalid("RecordBatch must be non-empty."));

Expand Down
14 changes: 7 additions & 7 deletions cpp/src/gandiva/tests/filter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@ class TestFilter : public ::testing::Test {

TEST_F(TestFilter, TestFilterCache) {
// schema for input fields
auto field0 = field("f0_filter_cache", int32());
auto field1 = field("f1_filter_cache", int32());
auto field0 = field("f0_filter_cache", int64());
auto field1 = field("f1_filter_cache", int64());
auto schema = arrow::schema({field0, field1});

// Build condition f0 + f1 < 10
auto node_f0 = TreeExprBuilder::MakeField(field0);
auto node_f1 = TreeExprBuilder::MakeField(field1);
auto sum_func =
TreeExprBuilder::MakeFunction("add", {node_f0, node_f1}, arrow::int32());
auto literal_10 = TreeExprBuilder::MakeLiteral((int32_t)10);
TreeExprBuilder::MakeFunction("add", {node_f0, node_f1}, arrow::int64());
auto literal_10 = TreeExprBuilder::MakeLiteral((int64_t)10);
auto less_than_10 = TreeExprBuilder::MakeFunction("less_than", {sum_func, literal_10},
arrow::boolean());
auto condition = TreeExprBuilder::MakeCondition(less_than_10);
Expand All @@ -69,13 +69,13 @@ TEST_F(TestFilter, TestFilterCache) {
EXPECT_TRUE(cached_filter->GetBuiltFromCache());

// schema is different but the condition is unchanged, so the cached filter is reused.
auto field2 = field("f2_filter_cache", int32());
auto field2 = field("f2_filter_cache", int64());
auto different_schema = arrow::schema({field0, field1, field2});
std::shared_ptr<Filter> should_be_new_filter;
status =
Filter::Make(different_schema, condition, configuration, &should_be_new_filter);
EXPECT_TRUE(status.ok());
EXPECT_FALSE(should_be_new_filter->GetBuiltFromCache());
EXPECT_TRUE(should_be_new_filter->GetBuiltFromCache());

// condition is different, should return a new filter.
auto greater_than_10 = TreeExprBuilder::MakeFunction(
Expand All @@ -84,7 +84,7 @@ TEST_F(TestFilter, TestFilterCache) {
std::shared_ptr<Filter> should_be_new_filter1;
status = Filter::Make(schema, new_condition, configuration, &should_be_new_filter1);
EXPECT_TRUE(status.ok());
EXPECT_FALSE(should_be_new_filter->GetBuiltFromCache());
EXPECT_FALSE(should_be_new_filter1->GetBuiltFromCache());
}

TEST_F(TestFilter, TestFilterCacheNullTreatment) {
Expand Down
Loading
Loading