diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java index ef3138d677c6..afb0f434aa41 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java @@ -20,6 +20,9 @@ import java.io.IOException; import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.Stream; import org.apache.iceberg.BaseTable; @@ -47,6 +50,7 @@ import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; import org.apache.iceberg.spark.Spark3Util; import org.apache.iceberg.spark.SparkAggregates; import org.apache.iceberg.spark.SparkFilters; @@ -362,15 +366,54 @@ public void pruneColumns(StructType requestedSchema) { private Schema schemaWithMetadataColumns() { // metadata columns - List fields = + List metadataFields = metaColumns.stream() .distinct() .map(name -> MetadataColumns.metadataColumn(table, name)) .collect(Collectors.toList()); - Schema meta = new Schema(fields); + Schema metadataSchema = calculateMetadataSchema(metadataFields); // schema or rows returned by readers - return TypeUtil.join(schema, meta); + return TypeUtil.join(schema, metadataSchema); + } + + private Schema calculateMetadataSchema(List metaColumnFields) { + Optional partitionField = + metaColumnFields.stream() + .filter(f -> MetadataColumns.PARTITION_COLUMN_ID == f.fieldId()) + .findFirst(); + + // only calculate potential column id collision if partition metadata column was requested + if (!partitionField.isPresent()) { + return new Schema(metaColumnFields); + } + + Set idsToReassign = + TypeUtil.indexById(partitionField.get().type().asStructType()).keySet(); + + // Calculate used ids by union metadata columns with all base table schemas + Set currentlyUsedIds = + metaColumnFields.stream().map(Types.NestedField::fieldId).collect(Collectors.toSet()); + Set allUsedIds = + table.schemas().values().stream() + .map(currSchema -> TypeUtil.indexById(currSchema.asStruct()).keySet()) + .reduce(currentlyUsedIds, Sets::union); + + // Reassign selected ids to deduplicate with used ids. + AtomicInteger nextId = new AtomicInteger(); + return new Schema( + metaColumnFields, + table.schema().identifierFieldIds(), + oldId -> { + if (!idsToReassign.contains(oldId)) { + return oldId; + } + int candidate = nextId.incrementAndGet(); + while (allUsedIds.contains(candidate)) { + candidate = nextId.incrementAndGet(); + } + return candidate; + }); } @Override diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMetadataColumns.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMetadataColumns.java index e24e74383bc8..127b0eb66fe9 100644 --- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMetadataColumns.java +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMetadataColumns.java @@ -24,6 +24,7 @@ import static org.apache.iceberg.TableProperties.PARQUET_BATCH_SIZE; import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES; import static org.apache.iceberg.TableProperties.PARQUET_VECTORIZATION_ENABLED; +import static org.apache.spark.sql.functions.expr; import static org.apache.spark.sql.functions.lit; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -31,6 +32,7 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import java.util.stream.IntStream; import org.apache.iceberg.FileFormat; import org.apache.iceberg.HasTableOperations; import org.apache.iceberg.MetadataColumns; @@ -53,6 +55,7 @@ import org.apache.spark.sql.Encoders; import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.types.StructType; import org.junit.After; import org.junit.Assert; import org.junit.Assume; @@ -169,6 +172,52 @@ public void testSpecAndPartitionMetadataColumns() { sql("SELECT _spec_id, _partition FROM %s ORDER BY _spec_id", TABLE_NAME)); } + @Test + public void testPartitionMetadataColumnWithManyColumns() { + // TODO: support metadata structs in vectorized ORC reads + Assume.assumeFalse(fileFormat == FileFormat.ORC && vectorized); + List fields = + Lists.newArrayList(Types.NestedField.required(0, "id", Types.LongType.get())); + List additionalCols = + IntStream.range(1, 1010) + .mapToObj(i -> Types.NestedField.optional(i, "c" + i, Types.StringType.get())) + .collect(Collectors.toList()); + fields.addAll(additionalCols); + Schema manyColumnsSchema = new Schema(fields); + PartitionSpec spec = PartitionSpec.builderFor(manyColumnsSchema).identity("id").build(); + + TableOperations ops = ((HasTableOperations) table).operations(); + TableMetadata base = ops.current(); + ops.commit( + base, + base.updateSchema(manyColumnsSchema, manyColumnsSchema.highestFieldId()) + .updatePartitionSpec(spec)); + + Dataset df = + spark + .range(2) + .withColumns( + IntStream.range(1, 1010) + .boxed() + .collect(Collectors.toMap(i -> "c" + i, i -> expr("CAST(id as STRING)")))); + StructType sparkSchema = spark.table(TABLE_NAME).schema(); + spark + .createDataFrame(df.rdd(), sparkSchema) + .coalesce(1) + .write() + .format("iceberg") + .mode("append") + .save(TABLE_NAME); + + Assert.assertEquals(2, spark.table(TABLE_NAME).select("*", "_partition").count()); + List expected = + ImmutableList.of(row(row(0L), 0L, "0", "0", "0"), row(row(1L), 1L, "1", "1", "1")); + assertEquals( + "Rows must match", + expected, + sql("SELECT _partition, id, c999, c1000, c1001 FROM %s ORDER BY id", TABLE_NAME)); + } + @Test public void testPositionMetadataColumnWithMultipleRowGroups() throws NoSuchTableException { Assume.assumeTrue(fileFormat == FileFormat.PARQUET); diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java index 6b97e48133fd..9dc214a755d3 100644 --- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java @@ -20,6 +20,9 @@ import java.io.IOException; import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.Stream; import org.apache.iceberg.BaseTable; @@ -48,6 +51,7 @@ import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; import org.apache.iceberg.spark.Spark3Util; import org.apache.iceberg.spark.SparkAggregates; import org.apache.iceberg.spark.SparkReadConf; @@ -342,15 +346,54 @@ public void pruneColumns(StructType requestedSchema) { private Schema schemaWithMetadataColumns() { // metadata columns - List fields = + List metadataFields = metaColumns.stream() .distinct() .map(name -> MetadataColumns.metadataColumn(table, name)) .collect(Collectors.toList()); - Schema meta = new Schema(fields); + Schema metadataSchema = calculateMetadataSchema(metadataFields); // schema or rows returned by readers - return TypeUtil.join(schema, meta); + return TypeUtil.join(schema, metadataSchema); + } + + private Schema calculateMetadataSchema(List metaColumnFields) { + Optional partitionField = + metaColumnFields.stream() + .filter(f -> MetadataColumns.PARTITION_COLUMN_ID == f.fieldId()) + .findFirst(); + + // only calculate potential column id collision if partition metadata column was requested + if (!partitionField.isPresent()) { + return new Schema(metaColumnFields); + } + + Set idsToReassign = + TypeUtil.indexById(partitionField.get().type().asStructType()).keySet(); + + // Calculate used ids by union metadata columns with all base table schemas + Set currentlyUsedIds = + metaColumnFields.stream().map(Types.NestedField::fieldId).collect(Collectors.toSet()); + Set allUsedIds = + table.schemas().values().stream() + .map(currSchema -> TypeUtil.indexById(currSchema.asStruct()).keySet()) + .reduce(currentlyUsedIds, Sets::union); + + // Reassign selected ids to deduplicate with used ids. + AtomicInteger nextId = new AtomicInteger(); + return new Schema( + metaColumnFields, + table.schema().identifierFieldIds(), + oldId -> { + if (!idsToReassign.contains(oldId)) { + return oldId; + } + int candidate = nextId.incrementAndGet(); + while (allUsedIds.contains(candidate)) { + candidate = nextId.incrementAndGet(); + } + return candidate; + }); } @Override diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMetadataColumns.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMetadataColumns.java index 778c46bba6b6..0ba34a638a63 100644 --- a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMetadataColumns.java +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMetadataColumns.java @@ -24,6 +24,7 @@ import static org.apache.iceberg.TableProperties.PARQUET_BATCH_SIZE; import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES; import static org.apache.iceberg.TableProperties.PARQUET_VECTORIZATION_ENABLED; +import static org.apache.spark.sql.functions.expr; import static org.apache.spark.sql.functions.lit; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -31,6 +32,7 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import java.util.stream.IntStream; import org.apache.iceberg.FileFormat; import org.apache.iceberg.HasTableOperations; import org.apache.iceberg.MetadataColumns; @@ -53,6 +55,7 @@ import org.apache.spark.sql.Encoders; import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.types.StructType; import org.junit.After; import org.junit.Assert; import org.junit.Assume; @@ -169,6 +172,50 @@ public void testSpecAndPartitionMetadataColumns() { sql("SELECT _spec_id, _partition FROM %s ORDER BY _spec_id", TABLE_NAME)); } + @Test + public void testPartitionMetadataColumnWithManyColumns() { + List fields = + Lists.newArrayList(Types.NestedField.required(0, "id", Types.LongType.get())); + List additionalCols = + IntStream.range(1, 1010) + .mapToObj(i -> Types.NestedField.optional(i, "c" + i, Types.StringType.get())) + .collect(Collectors.toList()); + fields.addAll(additionalCols); + Schema manyColumnsSchema = new Schema(fields); + PartitionSpec spec = PartitionSpec.builderFor(manyColumnsSchema).identity("id").build(); + + TableOperations ops = ((HasTableOperations) table).operations(); + TableMetadata base = ops.current(); + ops.commit( + base, + base.updateSchema(manyColumnsSchema, manyColumnsSchema.highestFieldId()) + .updatePartitionSpec(spec)); + + Dataset df = + spark + .range(2) + .withColumns( + IntStream.range(1, 1010) + .boxed() + .collect(Collectors.toMap(i -> "c" + i, i -> expr("CAST(id as STRING)")))); + StructType sparkSchema = spark.table(TABLE_NAME).schema(); + spark + .createDataFrame(df.rdd(), sparkSchema) + .coalesce(1) + .write() + .format("iceberg") + .mode("append") + .save(TABLE_NAME); + + Assert.assertEquals(2, spark.table(TABLE_NAME).select("*", "_partition").count()); + List expected = + ImmutableList.of(row(row(0L), 0L, "0", "0", "0"), row(row(1L), 1L, "1", "1", "1")); + assertEquals( + "Rows must match", + expected, + sql("SELECT _partition, id, c999, c1000, c1001 FROM %s ORDER BY id", TABLE_NAME)); + } + @Test public void testPositionMetadataColumnWithMultipleRowGroups() throws NoSuchTableException { Assume.assumeTrue(fileFormat == FileFormat.PARQUET);