Skip to content

Commit

Permalink
[Kernel] Support data skipping on timestamp/timestampNtz columns (#3481)
Browse files Browse the repository at this point in the history
Adds a `TIMEADD` scalar expression to the data skipping logic. This
addresses issues arising from TIMESTAMP being truncated to millisecond
precision when serialized to JSON. For example, a file containing only
`01:02:03.456789` will be written with `min == max == 01:02:03.456`, so
we must consider it to contain the range from `01:02:03.456` to
`01:02:03.457`.

Resolves #2462. 

## How was this patch tested?
Unit tests.
  • Loading branch information
raveeram-db authored Aug 9, 2024
1 parent 78cdeb0 commit 3cebe54
Show file tree
Hide file tree
Showing 12 changed files with 358 additions and 119 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@
* argument. If all arguments are null returns null
* <li>Since version: 3.1.0
* </ul>
* <li>Name: <code>TIMEADD</code>
* <ul>
* <li>Semantic: <code>TIMEADD(colExpr, milliseconds)</code>. Add the specified number of
* milliseconds to the timestamp represented by <i>colExpr</i>. The adjustment does not
* alter the original value but returns a new timestamp increased by the given
* milliseconds. Ex: <code>TIMEADD(timestampColumn, 1000)</code> returns a timestamp
* 1 second later.
* <li>Since version: 3.3.0
* </ul>
* </ol>
*
* @since 3.0.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
import io.delta.kernel.data.FilteredColumnarBatch;
import io.delta.kernel.engine.Engine;
import io.delta.kernel.expressions.*;
import io.delta.kernel.types.*;
import io.delta.kernel.internal.util.Tuple2;
import io.delta.kernel.types.StructField;
import io.delta.kernel.types.StructType;
import java.util.*;

public class DataSkippingUtils {
Expand Down Expand Up @@ -313,7 +315,6 @@ private static DataSkippingPredicate constructComparatorDataSkippingFilters(
case ">=":
return constructBinaryDataSkippingPredicate(
">=", schemaHelper.getMaxColumn(leftCol), rightLit);

default:
throw new IllegalArgumentException(
String.format("Unsupported comparator expression %s", comparator));
Expand All @@ -322,18 +323,15 @@ private static DataSkippingPredicate constructComparatorDataSkippingFilters(

/**
* Constructs a {@link DataSkippingPredicate} for a binary predicate expression whose left side is
* a column, optionally replaced by a column adjustment expression, and whose right side is a
* {@link Literal}.
*
* <p>NOTE(review): this span comes from a diff view that appears to interleave pre-change and
* post-change lines; the inline notes mark the lines that look superseded.
*
* @param exprName name of the binary predicate to build, for example {@code ">="}.
* @param colExpr tuple of the column and an optional adjustment expression; when the adjustment
*     is absent the column itself is used as the left operand.
* @param lit the literal used as the right operand.
* @return a predicate built from {@code exprName}, the operands (left operand, {@code lit}) and a
*     single-element set holding the unadjusted column.
*/
private static DataSkippingPredicate constructBinaryDataSkippingPredicate(
// NOTE(review): the next line looks like the pre-change parameter list; the line after it
// appears to be the current one.
String exprName, Column col, Literal lit) {
String exprName, Tuple2<Column, Optional<Expression>> colExpr, Literal lit) {
Column column = colExpr._1;
// Use the adjustment expression when one is supplied, otherwise fall back to the column itself.
Expression adjColExpr = colExpr._2.isPresent() ? colExpr._2.get() : column;
return new DataSkippingPredicate(
// NOTE(review): the next seven lines look like the pre-change constructor arguments kept by the
// diff view; the final argument line before the closing brace appears to be the current one.
exprName,
Arrays.asList(col, lit),
new HashSet<Column>() {
{
add(col);
}
});
exprName, Arrays.asList(adjColExpr, lit), Collections.singleton(column));
}

private static final Map<String, String> REVERSE_COMPARATORS =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import static io.delta.kernel.internal.util.Preconditions.checkArgument;

import io.delta.kernel.expressions.Column;
import io.delta.kernel.expressions.Expression;
import io.delta.kernel.expressions.Literal;
import io.delta.kernel.expressions.ScalarExpression;
import io.delta.kernel.internal.util.Tuple2;
import io.delta.kernel.types.*;
import java.util.*;
Expand Down Expand Up @@ -112,26 +114,46 @@ public StatsSchemaHelper(StructType dataSchema) {

/**
* Given a logical column in the data schema provided when creating {@code this}, return the
* corresponding MIN column in the statistic schema, paired with an optional column adjustment
* expression. The adjustment returned here is always empty.
*
* <p>NOTE(review): this span comes from a diff view that appears to interleave pre-change and
* post-change lines; the inline notes mark the lines that look superseded.
*
* @param column the logical column name; it must satisfy {@code isSkippingEligibleMinMaxColumn},
*     otherwise the argument check fails.
* @return a tuple of the MIN column and an optional adjustment expression.
*/
// NOTE(review): the next line looks like the pre-change signature; the line after it appears to
// be the current one.
public Column getMinColumn(Column column) {
public Tuple2<Column, Optional<Expression>> getMinColumn(Column column) {
// Reject columns that are not eligible for min/max based skipping.
checkArgument(
isSkippingEligibleMinMaxColumn(column),
String.format("%s is not a valid min column for data schema %s", column, dataSchema));
// NOTE(review): the next line looks like the pre-change return; the line after it appears to be
// the current one.
return getStatsColumn(column, MIN);
return new Tuple2<>(getStatsColumn(column, MIN), Optional.empty());
}

/**
* Given a logical column in the data schema provided when creating {@code this}, return the
* corresponding MAX column in the statistic schema, paired with an optional column adjustment
* expression. For Timestamp and TimestampNTZ columns the adjustment is {@code TIMEADD(max, 1)},
* i.e. the MAX statistic plus one millisecond; for every other type it is empty.
*
* <p>NOTE(review): this span comes from a diff view that appears to interleave pre-change and
* post-change lines; the inline notes mark the lines that look superseded.
*
* @param column the logical column name; it must satisfy {@code isSkippingEligibleMinMaxColumn},
*     otherwise the argument check fails.
* @return a tuple of the MAX column and an optional adjustment expression.
*/
// NOTE(review): the next line looks like the pre-change signature; the line after it appears to
// be the current one.
public Column getMaxColumn(Column column) {
public Tuple2<Column, Optional<Expression>> getMaxColumn(Column column) {
// Reject columns that are not eligible for min/max based skipping. The message names the MAX
// column (it previously said "min column", copied from getMinColumn).
checkArgument(
isSkippingEligibleMinMaxColumn(column),
String.format("%s is not a valid max column for data schema %s", column, dataSchema));
// NOTE(review): the next line looks like the pre-change return; the statements after it appear
// to be the current implementation.
return getStatsColumn(column, MAX);
DataType dataType = logicalToDataType.get(column);
Column maxColumn = getStatsColumn(column, MAX);

// If this is a column of type Timestamp or TimestampNTZ
// compensate for the truncation from microseconds to milliseconds
// by adding 1 millisecond. For example, a file containing only
// 01:02:03.456789 will be written with min == max == 01:02:03.456, so we must consider it
// to contain the range from 01:02:03.456 to 01:02:03.457.
if (dataType instanceof TimestampType || dataType instanceof TimestampNTZType) {
return new Tuple2<>(
maxColumn,
Optional.of(
new ScalarExpression("TIMEADD", Arrays.asList(maxColumn, Literal.ofLong(1)))));
}
return new Tuple2<>(maxColumn, Optional.empty());
}

/**
Expand All @@ -158,13 +180,7 @@ public Column getNumRecordsColumn() {
*/
public boolean isSkippingEligibleMinMaxColumn(Column column) {
// A column qualifies when it is a known logical column and its data type is eligible for
// skipping.
// NOTE(review): this span comes from a diff view that appears to interleave pre-change and
// post-change lines. The second condition through the "instanceof TimestampType" line looks
// like the pre-change expression (which excluded TimestampType columns); the final "&&" line
// appears to be its replacement.
return logicalToDataType.containsKey(column)
&& isSkippingEligibleDataType(logicalToDataType.get(column))
&&
// TODO (delta-io/delta#2462) for now we block using min/max columns of timestamps.
// JSON serialization truncates to milliseconds. To safely use timestamp min/max stats
// we need to add a millisecond to max statistics which requires time addition
// expression
!(logicalToDataType.get(column) instanceof TimestampType);
&& isSkippingEligibleDataType(logicalToDataType.get(column));
}

/**
Expand Down Expand Up @@ -196,6 +212,7 @@ public boolean isSkippingEligibleNullCountColumn(Column column) {
add("double");
add("date");
add("timestamp");
add("timestamp_ntz");
add("string");
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
*/
package io.delta.kernel.test

import java.lang.{Boolean => BooleanJ, Double => DoubleJ, Float => FloatJ}
import io.delta.kernel.data.{ColumnVector, MapValue}
import io.delta.kernel.internal.util.VectorUtils
import io.delta.kernel.types._

import java.lang.{Boolean => BooleanJ, Double => DoubleJ, Float => FloatJ}
import scala.collection.JavaConverters._

trait VectorTestUtils {
Expand All @@ -37,6 +38,21 @@ trait VectorTestUtils {
}
}

/**
 * Builds a test [[ColumnVector]] of type `TimestampType.TIMESTAMP` backed by the given longs.
 * An element equal to `-1` is reported as null by `isNullAt`.
 *
 * NOTE(review): the previous comment described the values as milliseconds since epoch, but the
 * TIMEADD evaluator converts its millisecond argument to microseconds before adding it to a
 * timestamp long - confirm the intended unit.
 */
protected def timestampVector(values: Seq[Long]): ColumnVector = {
  new ColumnVector {
    override def close(): Unit = ()

    override def getSize: Int = values.size

    override def getDataType: DataType = TimestampType.TIMESTAMP

    // -1 acts as the null marker for this test vector.
    override def isNullAt(rowId: Int): Boolean = values(rowId) == -1L

    // Returns the stored long as-is; no null check is applied here.
    override def getLong(rowId: Int): Long = values(rowId)
  }
}

protected def stringVector(values: Seq[String]): ColumnVector = {
new ColumnVector {
override def getDataType: DataType = StringType.STRING
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,25 @@
import io.delta.kernel.internal.util.Tuple2;
import io.delta.kernel.types.DataType;
import io.delta.kernel.types.StructType;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;
import java.time.format.DateTimeFormatterBuilder;
import java.time.temporal.ChronoField;
import java.time.temporal.ChronoUnit;
import java.util.concurrent.TimeUnit;

public class DefaultKernelUtils {
private static final LocalDate EPOCH = LocalDate.ofEpochDay(0);
// Formatter for TimestampNTZ strings: "yyyy-MM-dd'T'HH:mm:ss" followed by an optional fractional
// part introduced by a decimal point, carrying zero to six digits (up to microsecond precision).
private static final DateTimeFormatter DEFAULT_JSON_TIMESTAMPNTZ_FORMATTER =
new DateTimeFormatterBuilder()
.appendPattern("yyyy-MM-dd'T'HH:mm:ss")
.optionalStart()
.appendFraction(ChronoField.MICRO_OF_SECOND, 0, 6, true)
.optionalEnd()
.toFormatter();

private DefaultKernelUtils() {}

Expand Down Expand Up @@ -59,6 +73,19 @@ public static long millisToMicros(long millis) {
return Math.multiplyExact(millis, DateTimeConstants.MICROS_PER_MILLIS);
}

/**
 * Converts a TimestampNTZ string to microseconds since the Unix epoch. The text is parsed as a
 * local date-time (with an optional fractional-second part) and interpreted at the UTC offset.
 *
 * @param timestampString the timestamp string to parse.
 * @return the number of microseconds between the epoch and the parsed instant.
 */
public static long parseTimestampNTZ(String timestampString) {
  LocalDateTime localDateTime =
      LocalDateTime.parse(timestampString, DEFAULT_JSON_TIMESTAMPNTZ_FORMATTER);
  // No zone is carried by the value itself; anchor it at UTC to obtain an absolute instant.
  Instant atUtc = localDateTime.toInstant(ZoneOffset.UTC);
  return ChronoUnit.MICROS.between(Instant.EPOCH, atUtc);
}

/**
* Search for the data type of the given column in the schema.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.delta.kernel.data.ColumnVector;
import io.delta.kernel.data.MapValue;
import io.delta.kernel.data.Row;
import io.delta.kernel.defaults.internal.DefaultKernelUtils;
import io.delta.kernel.defaults.internal.data.vector.DefaultGenericVector;
import io.delta.kernel.internal.util.InternalUtils;
import io.delta.kernel.types.*;
Expand Down Expand Up @@ -257,6 +258,11 @@ private static Object decodeElement(JsonNode jsonValue, DataType dataType) {
return ChronoUnit.MICROS.between(Instant.EPOCH, time);
}

if (dataType instanceof TimestampNTZType) {
throwIfTypeMismatch("timestamp_ntz", jsonValue.isTextual(), jsonValue);
return DefaultKernelUtils.parseTimestampNTZ(jsonValue.textValue());
}

if (dataType instanceof StructType) {
throwIfTypeMismatch("object", jsonValue.isObject(), jsonValue);
return new DefaultJsonRow((ObjectNode) jsonValue, (StructType) dataType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,8 @@

import static io.delta.kernel.defaults.internal.DefaultEngineErrors.unsupportedExpressionException;
import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.*;
import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.booleanWrapperVector;
import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.childAt;
import static io.delta.kernel.defaults.internal.expressions.ImplicitCastExpression.canCastTo;
import static io.delta.kernel.internal.util.ExpressionUtils.getLeft;
import static io.delta.kernel.internal.util.ExpressionUtils.getRight;
import static io.delta.kernel.internal.util.ExpressionUtils.getUnaryChild;
import static io.delta.kernel.internal.util.ExpressionUtils.*;
import static io.delta.kernel.internal.util.Preconditions.checkArgument;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
Expand All @@ -35,6 +31,7 @@
import io.delta.kernel.engine.ExpressionHandler;
import io.delta.kernel.expressions.*;
import io.delta.kernel.types.*;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -260,6 +257,36 @@ ExpressionTransformResult visitCoalesce(ScalarExpression coalesce) {
children.get(0).outputType);
}

/**
 * Validates and rewrites a {@code TIMEADD} expression. Both children are transformed first; the
 * expression must then have exactly two children, the first of Timestamp or TimestampNTZ type
 * and the second of Long type. The result carries the output type of the first child.
 */
@Override
ExpressionTransformResult visitTimeAdd(ScalarExpression timeAdd) {
  List<ExpressionTransformResult> transformedArgs =
      timeAdd.getChildren().stream().map(this::visit).collect(Collectors.toList());

  if (transformedArgs.size() != 2) {
    throw unsupportedExpressionException(
        timeAdd, "TIMEADD requires exactly two arguments: timestamp column and milliseconds");
  }

  ExpressionTransformResult timestampArg = transformedArgs.get(0);
  ExpressionTransformResult millisArg = transformedArgs.get(1);
  DataType resultType = timestampArg.outputType;

  boolean firstIsTimestamp =
      resultType instanceof TimestampType || resultType instanceof TimestampNTZType;
  boolean secondIsLong = millisArg.outputType instanceof LongType;
  if (!firstIsTimestamp || !secondIsLong) {
    throw new IllegalArgumentException(
        "TIMEADD requires a timestamp and a Long (milliseconds) to add to it");
  }

  // Rebuild the expression from the transformed children; the result is also a timestamp.
  ScalarExpression rewritten =
      new ScalarExpression(
          "TIMEADD", Arrays.asList(timestampArg.expression, millisArg.expression));
  return new ExpressionTransformResult(rewritten, resultType);
}

@Override
ExpressionTransformResult visitLike(final Predicate like) {
List<ExpressionTransformResult> children =
Expand Down Expand Up @@ -534,6 +561,44 @@ ColumnVector visitCoalesce(ScalarExpression coalesce) {
});
}

/**
 * Evaluates {@code TIMEADD(timestamp, duration)} lazily: the returned vector reads both child
 * vectors on access. A row is null when either input is null at that row. The duration is
 * multiplied by 1000 before being added, i.e. it is treated as milliseconds added to a
 * microsecond-based timestamp long.
 *
 * <p>The arithmetic uses exact operations, so an overflow raises {@link ArithmeticException}
 * instead of silently wrapping to an incorrect timestamp.
 */
@Override
ColumnVector visitTimeAdd(ScalarExpression timeAdd) {
  ColumnVector timestampColumn = visit(timeAdd.getChildren().get(0));
  ColumnVector durationVector = visit(timeAdd.getChildren().get(1));

  return new ColumnVector() {
    @Override
    public DataType getDataType() {
      // The result keeps the type of the timestamp input.
      return timestampColumn.getDataType();
    }

    @Override
    public int getSize() {
      return timestampColumn.getSize();
    }

    @Override
    public void close() {
      // This vector is only a view; release both inputs with it.
      timestampColumn.close();
      durationVector.close();
    }

    @Override
    public boolean isNullAt(int rowId) {
      return timestampColumn.isNullAt(rowId) || durationVector.isNullAt(rowId);
    }

    @Override
    public long getLong(int rowId) {
      if (isNullAt(rowId)) {
        // Placeholder for null rows; callers are expected to consult isNullAt first.
        return 0;
      }
      // Exact arithmetic: a wrapped value would be a wrong timestamp, so fail loudly instead.
      long durationMicros = Math.multiplyExact(durationVector.getLong(rowId), 1000L);
      return Math.addExact(timestampColumn.getLong(rowId), durationMicros);
    }
  };
}

@Override
ColumnVector visitLike(final Predicate like) {
List<Expression> children = like.getChildren();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ abstract class ExpressionVisitor<R> {

abstract R visitCoalesce(ScalarExpression ifNull);

abstract R visitTimeAdd(ScalarExpression timeAdd);

abstract R visitLike(Predicate predicate);

final R visit(Expression expression) {
Expand Down Expand Up @@ -106,6 +108,8 @@ private R visitScalarExpression(ScalarExpression expression) {
return visitIsNull(new Predicate(name, children));
case "COALESCE":
return visitCoalesce(expression);
case "TIMEADD":
return visitTimeAdd(expression);
case "LIKE":
return visitLike(new Predicate(name, children));
default:
Expand Down
Loading

0 comments on commit 3cebe54

Please sign in to comment.