@register()
def split_groups(
    df: AnyDataFrame,
    groupers: Annotated[
        AllGrouper | UserDefinedGroupers,
        Field(description="Index(es) and/or column(s) to group by"),
    ],
) -> Annotated[
    KeyedIterableOfDataFrames,
    Field(
        description="""\
List of 2-tuples of key:value pairs. Each key:value pair consists of a composite
filter (the key) and the corresponding subset of the input dataframe (the value).
"""
    ),
]:
    """Split ``df`` into one sub-dataframe per combination of grouper values.

    If ``groupers`` is an ``AllGrouper``, ``df`` is passed through untouched as
    a single group. Otherwise, for every ``ValueGrouper`` in ``groupers`` the
    index level or column named by its ``index_name`` has its null values
    replaced with the string ``"None"`` and is then required to have a string
    dtype. Finally ``df`` is grouped by the ``index_name`` of every grouper.

    The null replacement is applied to a copy: the dataframe object passed in
    by the caller keeps its original index and column values.

    Args:
        df: The dataframe to split.
        groupers: Either the special ``AllGrouper`` or an iterable of groupers,
            each exposing an ``index_name`` that names an index level or a
            column of ``df``.

    Returns:
        A list of ``(composite_filter, group)`` 2-tuples, where the filter is
        built by ``_groupkey_to_composite_filter`` from the grouper names and
        the group key, and ``group`` is the matching subset of ``df``.

    Raises:
        ValueError: If a value grouper's ``index_name`` is neither an index
            level nor a column of ``df``, or if the referenced index level or
            column is not of string dtype after null replacement.
    """
    if isinstance(groupers, AllGrouper):
        # if the special-cased AllGrouper is given, then just passthrough the `df`, wrapped
        # in the container expected by downstream tasks. we don't actually have to apply the
        # "all" index to the dataframe, because that is not used downstream.
        return [(((groupers.index_name, "=", "True"),), df)]

    value_groupers = [vg for vg in groupers if isinstance(vg, ValueGrouper)]
    if value_groupers:
        # The null coercion below rebinds the index and replaces column values.
        # Do that on a shallow copy so the caller's dataframe is not modified;
        # `deep=False` avoids duplicating the underlying column data up front.
        df = df.copy(deep=False)

    for vg in value_groupers:
        key = vg.index_name
        if key in df.index.names:
            idx_level = df.index.names.index(key)
            # Coerce any None/NA values into string "None"
            if isinstance(df.index, pd.MultiIndex):
                index_frame = df.index.to_frame().fillna({key: "None"})
                df.index = pd.MultiIndex.from_frame(index_frame)
            else:
                df.index = df.index.get_level_values(idx_level).fillna("None")
            level_values = df.index.get_level_values(idx_level)
            if not pd.api.types.is_string_dtype(level_values):
                # Report the dtype of the offending index level: `df.dtypes`
                # only lists column dtypes and says nothing about the index.
                raise ValueError(
                    "All indexes used as categorical value groupers must contain "
                    f"only string data (with no null values); got {level_values.dtype}."
                )
        elif key in df.columns:
            # Coerce any None/NA values into string "None". Not done in place,
            # because the shallow copy still shares column data with the input.
            df = df.fillna({key: "None"})
            if not pd.api.types.is_string_dtype(df[key]):
                raise ValueError(
                    "All columns used as categorical value groupers must contain "
                    f"only string data (with no null values); got {df.dtypes}."
                )
        else:
            raise ValueError(f"Value grouper '{key}' is neither a column nor an index in the DataFrame")

    # TODO: configurable cardinality constraint with a default?
    grouper_index_names = [g.index_name for g in groupers]
    grouped = df.groupby(grouper_index_names)
    return [
        (_groupkey_to_composite_filter(grouper_index_names, index_value), group)  # type: ignore[misc,arg-type]
        for index_value, group in grouped
    ]