feat(pyspark): add partitionBy argument to create_table
Adds the partition_by argument to the create_table method in the PySpark backend, enabling partitioned table creation.

fixes #8900
jakepenzak authored and cpcloud committed Feb 19, 2025
1 parent fbe8c8b commit c99cc23
Showing 4 changed files with 243 additions and 8 deletions.
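
For reference, a minimal usage sketch of the new partition_by argument on the PySpark backend (the SparkSession setup, connection, and table/column names below are illustrative and not part of this change):

    from pyspark.sql import SparkSession

    import ibis

    # Illustrative local session; any existing SparkSession works.
    spark = SparkSession.builder.master("local[1]").getOrCreate()
    con = ibis.pyspark.connect(session=spark)

    t = ibis.memtable(
        {
            "epoch": [1712848119, 1712848121, 1712848155],
            "category1": ["A", "B", "A"],
            "category2": ["G", "J", "G"],
        }
    )

    # Partition by a single column
    con.create_table("events", obj=t, overwrite=True, partition_by="category1")

    # Partition by multiple columns
    con.create_table(
        "events_2d", obj=t, overwrite=True, partition_by=["category1", "category2"]
    )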
46 changes: 45 additions & 1 deletion ibis/backends/pyspark/__init__.py
@@ -600,6 +600,7 @@ def create_table(
temp: bool | None = None,
overwrite: bool = False,
format: str = "parquet",
partition_by: str | list[str] | None = None,
) -> ir.Table:
"""Create a new table in Spark.
@@ -623,6 +624,8 @@ def create_table(
If `True`, overwrite existing data
format
Format of the table on disk
        partition_by
            Name(s) of partitioning column(s)

        Returns
        -------
@@ -651,7 +654,9 @@ def create_table(
with self._active_catalog_database(catalog, db):
self._run_pre_execute_hooks(table)
df = self._session.sql(query)
            df.write.saveAsTable(
                name, format=format, mode=mode, partitionBy=partition_by
            )
elif schema is not None:
schema = ibis.schema(schema)
schema = PySparkSchema.from_ibis(schema)
@@ -953,6 +958,45 @@ def to_delta(
df = self._session.sql(self.compile(expr, params=params, limit=limit))
df.write.format("delta").save(os.fspath(path), **kwargs)

@util.experimental
def to_parquet(
self,
expr: ir.Table,
/,
path: str | Path,
*,
params: Mapping[ir.Scalar, Any] | None = None,
limit: int | str | None = None,
**kwargs: Any,
) -> None:
"""Write the results of executing the given expression to a Parquet file.
This method is eager and will execute the associated expression
immediately.
Parameters
----------
expr
The ibis expression to execute and persist to a Parquet file.
path
The data source. A string or Path to the Parquet file.
params
Mapping of scalar parameter expressions to value.
limit
An integer to effect a specific row limit. A value of `None` means
"no limit". The default is in `ibis/config.py`.
**kwargs
Additional keyword arguments passed to
[pyspark.sql.DataFrameWriter](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameWriter.html).
"""
if self.mode == "streaming":
raise NotImplementedError(
"Writing to a Parquet file in streaming mode is not supported."
)
self._run_pre_execute_hooks(expr)
df = self._session.sql(self.compile(expr, params=params, limit=limit))
df.write.format("parquet").save(os.fspath(path), **kwargs)

def to_pyarrow(
self,
expr: ir.Expr,
128 changes: 128 additions & 0 deletions ibis/backends/pyspark/tests/test_client.py
@@ -54,3 +54,131 @@ def test_create_table_no_catalog(con):

assert "t2" not in con.list_tables(database="default")
assert con.current_database != "default"


@pytest.mark.xfail_version(pyspark=["pyspark<3.4"], reason="no catalog support")
def test_create_table_with_partition_and_catalog(con):
# Create a sample table with a partition column
data = {
"epoch": [1712848119, 1712848121, 1712848155, 1712848169],
"category1": ["A", "B", "A", "C"],
"category2": ["G", "J", "G", "H"],
}

t = ibis.memtable(data)

# 1D partition
table_name = "pt1"

con.create_table(
table_name,
database=("spark_catalog", "default"),
obj=t,
overwrite=True,
partition_by="category1",
)
assert table_name in con.list_tables(database="spark_catalog.default")

partitions = (
con.raw_sql(f"SHOW PARTITIONS spark_catalog.default.{table_name}")
.toPandas()
.to_dict()
)
expected_partitions = {
"partition": {0: "category1=A", 1: "category1=B", 2: "category1=C"}
}
assert partitions == expected_partitions

# Cleanup
con.drop_table(table_name, database="spark_catalog.default")
assert table_name not in con.list_tables(database="spark_catalog.default")

# 2D partition
table_name = "pt2"

con.create_table(
table_name,
database=("spark_catalog", "default"),
obj=t,
overwrite=True,
partition_by=["category1", "category2"],
)
assert table_name in con.list_tables(database="spark_catalog.default")

partitions = (
con.raw_sql(f"SHOW PARTITIONS spark_catalog.default.{table_name}")
.toPandas()
.to_dict()
)
expected_partitions = {
"partition": {
0: "category1=A/category2=G",
1: "category1=B/category2=J",
2: "category1=C/category2=H",
}
}
assert partitions == expected_partitions

# Cleanup
con.drop_table(table_name, database="spark_catalog.default")
assert table_name not in con.list_tables(database="spark_catalog.default")


def test_create_table_with_partition_no_catalog(con):
data = {
"epoch": [1712848119, 1712848121, 1712848155, 1712848169],
"category1": ["A", "B", "A", "C"],
"category2": ["G", "J", "G", "H"],
}

t = ibis.memtable(data)

# 1D partition
table_name = "pt1"

con.create_table(
table_name,
obj=t,
overwrite=True,
partition_by="category1",
)
assert table_name in con.list_tables()

partitions = (
con.raw_sql(f"SHOW PARTITIONS ibis_testing.{table_name}").toPandas().to_dict()
)
expected_partitions = {
"partition": {0: "category1=A", 1: "category1=B", 2: "category1=C"}
}
assert partitions == expected_partitions

# Cleanup
con.drop_table(table_name)
assert table_name not in con.list_tables()

# 2D partition
table_name = "pt2"

con.create_table(
table_name,
obj=t,
overwrite=True,
partition_by=["category1", "category2"],
)
assert table_name in con.list_tables()

partitions = (
con.raw_sql(f"SHOW PARTITIONS ibis_testing.{table_name}").toPandas().to_dict()
)
expected_partitions = {
"partition": {
0: "category1=A/category2=G",
1: "category1=B/category2=J",
2: "category1=C/category2=H",
}
}
assert partitions == expected_partitions

# Cleanup
con.drop_table(table_name)
assert table_name not in con.list_tables()
39 changes: 39 additions & 0 deletions ibis/backends/pyspark/tests/test_import_export.py
@@ -5,8 +5,10 @@

import pandas as pd
import pytest
from pandas.testing import assert_frame_equal

from ibis.backends.pyspark.datatypes import PySparkSchema
from ibis.conftest import IS_SPARK_REMOTE


@pytest.mark.parametrize(
@@ -73,3 +75,40 @@ def test_to_parquet_dir(con_streaming, tmp_path):
sleep(2)
df = pd.concat([pd.read_parquet(f) for f in path.glob("*.parquet")])
assert len(df) == 5


@pytest.mark.skipif(
IS_SPARK_REMOTE, reason="Spark remote does not support assertions about local paths"
)
def test_to_parquet_read_parquet(con, tmp_path):
# No Partitions
t_out = con.table("awards_players")

t_out.to_parquet(tmp_path / "out_np")

t_in = con.read_parquet(tmp_path / "out_np")

cols = list(t_out.columns)
expected = t_out.to_pandas()[cols].sort_values(cols).reset_index(drop=True)
result = t_in.to_pandas()[cols].sort_values(cols).reset_index(drop=True)

assert_frame_equal(expected, result)

# Partitions
t_out = con.table("awards_players")

t_out.to_parquet(tmp_path / "out_p", partitionBy=["playerID"])

# Check partition paths
distinct_playerids = t_out.select("playerID").distinct().to_pandas()

for pid in distinct_playerids["playerID"]:
assert (tmp_path / "out_p" / f"playerID={pid}").exists()

t_in = con.read_parquet(tmp_path / "out_p")

cols = list(t_out.columns)
expected = t_out.to_pandas()[cols].sort_values(cols).reset_index(drop=True)
result = t_in.to_pandas()[cols].sort_values(cols).reset_index(drop=True)

assert_frame_equal(expected, result)
38 changes: 31 additions & 7 deletions ibis/backends/tests/test_export.py
@@ -207,6 +207,8 @@ def test_to_pyarrow_batches_memtable(con):


def test_table_to_parquet(tmp_path, backend, awards_players):
if backend.name() == "pyspark" and IS_SPARK_REMOTE:
pytest.skip("writes to remote output directory")
outparquet = tmp_path / "out.parquet"
awards_players.to_parquet(outparquet)

@@ -257,15 +259,32 @@ def test_table_to_parquet_writer_kwargs(version, tmp_path, backend, awards_players):
outparquet = tmp_path / "out.parquet"
awards_players.to_parquet(outparquet, version=version)

    if backend.name() == "pyspark":
        if IS_SPARK_REMOTE:
            pytest.skip("writes to remote output directory")
        # Pyspark will write more than one parquet file under outparquet as directory
        parquet_files = sorted(outparquet.glob("*.parquet"))
        df = (
            pd.concat(map(pd.read_parquet, parquet_files))
            .sort_values(list(awards_players.columns))
            .reset_index(drop=True)
        )
        result = (
            awards_players.to_pandas()
            .sort_values(list(awards_players.columns))
            .reset_index(drop=True)
        )
        backend.assert_frame_equal(result, df)
    else:
        df = pd.read_parquet(outparquet)

        backend.assert_frame_equal(
            awards_players.to_pandas().fillna(pd.NA), df.fillna(pd.NA)
        )

        md = pa.parquet.read_metadata(outparquet)

        assert md.format_version == version


@pytest.mark.notimpl(
@@ -333,7 +352,12 @@ def test_memtable_to_file(tmp_path, con, ftype, monkeypatch):

getattr(con, f"to_{ftype}")(memtable, outfile)

    if con.name == "pyspark" and ftype == "parquet":
        if IS_SPARK_REMOTE:
            pytest.skip("writes to remote output directory")
        assert outfile.is_dir()
    else:
        assert outfile.is_file()


def test_table_to_csv(tmp_path, backend, awards_players):
