add pyarrow support in find_column_type for pandas dataframes #313

Merged · 3 commits · Nov 19, 2024

Changes from all commits
locopy/utility.py (53 changes: 36 additions & 17 deletions)

@@ -30,6 +30,7 @@

import pandas as pd
import polars as pl
+import pyarrow as pa
import yaml

from locopy.errors import (
@@ -317,6 +318,20 @@ def validate_float_object(column):
        except (ValueError, TypeError):
            return None

+    def check_column_type_pyarrow(pa_dtype):
+        if pa.types.is_temporal(pa_dtype):
+            return "timestamp"
+        elif pa.types.is_boolean(pa_dtype):
+            return "boolean"
+        elif pa.types.is_integer(pa_dtype):
+            return "int"
+        elif pa.types.is_floating(pa_dtype):
+            return "float"
+        elif pa.types.is_string(pa_dtype):
+            return "varchar"
+        else:
+            return "varchar"

    if warehouse_type.lower() not in ["snowflake", "redshift"]:
        raise ValueError(
            'warehouse_type argument must be either "snowflake" or "redshift"'

@@ -328,24 +343,28 @@ def validate_float_object(column):
        data = dataframe[column].dropna().reset_index(drop=True)
        if data.size == 0:
            column_type.append("varchar")
-        elif (data.dtype in ["datetime64[ns]", "M8[ns]"]) or (
-            re.match(r"(datetime64\[ns\,\W)([a-zA-Z]+)(\])", str(data.dtype))
-        ):
-            column_type.append("timestamp")
-        elif str(data.dtype).lower().startswith("bool"):
-            column_type.append("boolean")
-        elif str(data.dtype).startswith("object"):
-            data_type = validate_float_object(data) or validate_date_object(data)
-            if not data_type:
-                column_type.append("varchar")
-            else:
-                column_type.append(data_type)
-        elif str(data.dtype).lower().startswith("int"):
-            column_type.append("int")
-        elif str(data.dtype).lower().startswith("float"):
-            column_type.append("float")
+        elif isinstance(data.dtype, pd.ArrowDtype):
+            datatype = check_column_type_pyarrow(data.dtype.pyarrow_dtype)
+            column_type.append(datatype)
        else:
-            column_type.append("varchar")
+            if (data.dtype in ["datetime64[ns]", "M8[ns]"]) or (
+                re.match(r"(datetime64\[ns\,\W)([a-zA-Z]+)(\])", str(data.dtype))
+            ):
+                column_type.append("timestamp")
+            elif str(data.dtype).lower().startswith("bool"):
+                column_type.append("boolean")
+            elif str(data.dtype).startswith("object"):
+                data_type = validate_float_object(data) or validate_date_object(data)
+                if not data_type:
+                    column_type.append("varchar")
+                else:
+                    column_type.append(data_type)
+            elif str(data.dtype).lower().startswith("int"):
+                column_type.append("int")
+            elif str(data.dtype).lower().startswith("float"):
+                column_type.append("float")
+            else:
+                column_type.append("varchar")
        logger.info("Parsing column %s to %s", column, column_type[-1])
    return OrderedDict(zip(list(dataframe.columns), column_type))

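For context on what the new branch keys off (an illustrative sketch, not part of the diff): a pyarrow-backed pandas column carries a pd.ArrowDtype, and its .pyarrow_dtype attribute exposes the underlying pyarrow type that check_column_type_pyarrow classifies with the pa.types.is_* predicates. A minimal standalone example of those calls:

import pandas as pd
import pyarrow as pa

# A pyarrow-backed Series carries a pandas ArrowDtype wrapper around the Arrow type.
s = pd.Series([1, 2, None], dtype="int64[pyarrow]")
print(isinstance(s.dtype, pd.ArrowDtype))                  # True -> routed to check_column_type_pyarrow

# .pyarrow_dtype unwraps the ArrowDtype into the raw pyarrow DataType...
print(s.dtype.pyarrow_dtype)                               # int64

# ...which the pa.types predicates can classify.
print(pa.types.is_integer(s.dtype.pyarrow_dtype))          # True -> "int"
print(pa.types.is_temporal(pa.timestamp("ns", tz="UTC")))  # True -> "timestamp"
print(pa.types.is_string(pa.string()))                     # True -> "varchar"
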
pyproject.toml (2 changes: 1 addition & 1 deletion)

@@ -6,7 +6,7 @@ authors = [
{ name="Faisal Dosani", email="faisal.dosani@capitalone.com" },
]
license = {text = "Apache Software License"}
-dependencies = ["boto3<=1.35.53,>=1.9.92", "PyYAML<=6.0.1,>=5.1", "pandas<=2.2.3,>=0.25.2", "numpy<=2.0.2,>=1.22.0", "polars>=0.20.0"]
+dependencies = ["boto3<=1.35.53,>=1.9.92", "PyYAML<=6.0.1,>=5.1", "pandas<=2.2.3,>=1.5.0", "numpy<=2.0.2,>=1.22.0", "polars>=0.20.0", "pyarrow>=10.0.1"]

requires-python = ">=3.9.0"
classifiers = [
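On the dependency floors: pandas moves from >=0.25.2 to >=1.5.0, presumably because pd.ArrowDtype and the "[pyarrow]" dtype strings used in the new test first shipped in pandas 1.5, and pyarrow>=10.0.1 becomes a hard dependency. A hedged sketch of an environment guard mirroring those floors (not part of the PR; the packaging library is assumed to be available):

import pandas as pd
import pyarrow as pa
from packaging.version import Version  # assumed available; used only for the version comparison

# Hypothetical guard mirroring the new floors in pyproject.toml.
assert Version(pd.__version__) >= Version("1.5.0"), "pd.ArrowDtype needs pandas >= 1.5.0"
assert Version(pa.__version__) >= Version("10.0.1"), "locopy now expects pyarrow >= 10.0.1"
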
tests/test_utility.py (42 changes: 42 additions & 0 deletions)

@@ -29,6 +29,7 @@
from unittest import mock

import locopy.utility as util
+import pyarrow as pa
import pytest
from locopy.errors import (
    CompressionError,
@@ -388,7 +389,48 @@ def test_find_column_type_new():
        "d": "varchar",
        "e": "boolean",
    }
    assert find_column_type(input_text, "snowflake") == output_text_snowflake
    assert find_column_type(input_text, "redshift") == output_text_redshift


+def test_find_column_type_pyarrow():
+    import pandas as pd
+
+    input_text = pd.DataFrame.from_dict(
+        {
+            "a": [1],
+            "b": [pd.Timestamp("2017-01-01T12+0")],
+            "c": [1.2],
+            "d": ["a"],
+            "e": [True],
+        }
+    )
+
+    input_text = input_text.astype(
+        dtype={
+            "a": "int64[pyarrow]",
+            "b": "timestamp[ns, tz=UTC][pyarrow]",
+            "c": "float64[pyarrow]",
+            "d": pd.ArrowDtype(pa.string()),
+            "e": "bool[pyarrow]",
+        }
+    )
+
+    output_text_snowflake = {
+        "a": "int",
+        "b": "timestamp",
+        "c": "float",
+        "d": "varchar",
+        "e": "boolean",
+    }
+
+    output_text_redshift = {
+        "a": "int",
+        "b": "timestamp",
+        "c": "float",
+        "d": "varchar",
+        "e": "boolean",
+    }
+    assert find_column_type(input_text, "snowflake") == output_text_snowflake
+    assert find_column_type(input_text, "redshift") == output_text_redshift
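Outside of pytest, the same kind of pyarrow-backed frame run through find_column_type yields an OrderedDict of column-to-type mappings that callers can fold into DDL. A hypothetical usage sketch (the staging_table name and the CREATE TABLE assembly are illustrative, not locopy API):

import pandas as pd
import pyarrow as pa
from locopy.utility import find_column_type

# Build a small pyarrow-backed frame like the one in the new test.
df = pd.DataFrame({"a": [1], "d": ["a"], "e": [True]}).astype(
    {"a": "int64[pyarrow]", "d": pd.ArrowDtype(pa.string()), "e": "bool[pyarrow]"}
)

column_types = find_column_type(df, "snowflake")
# OrderedDict([('a', 'int'), ('d', 'varchar'), ('e', 'boolean')])

# Hypothetical downstream use: fold the mapping into a CREATE TABLE statement.
ddl = "CREATE TABLE staging_table ({})".format(
    ", ".join(f"{col} {dtype}" for col, dtype in column_types.items())
)
print(ddl)  # CREATE TABLE staging_table (a int, d varchar, e boolean)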
