Skip to content

Commit

Permalink
Refactor: clean up pivot lineage (#4534)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Dec 18, 2024
1 parent 84ec478 commit cd6e00f
Showing 1 changed file with 57 additions and 64 deletions.
121 changes: 57 additions & 64 deletions sqlglot/lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,23 +254,39 @@ def to_node(
if dt.comments and dt.comments[0].startswith("source: ")
}

pivots = scope.pivots
pivot = pivots[0] if len(pivots) == 1 and not pivots[0].unpivot else None
if pivot:
# For each aggregation function, the pivot creates a new column for each field in category
# combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a,
# b_value_sum, b. Because of this stepwise ordering, the aggfunc 'sum(value) as value_sum'
# belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs
# to the column indices 1, 3. Here, only the columns used in the aggregations are of interest
# in the lineage, so look up the pivot column name by index and map it to the columns used
# in the aggregation.
#
# Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b')
pivot_columns = pivot.args["columns"]
pivot_aggs_count = len(pivot.expressions)

pivot_column_mapping = {}
for i, agg in enumerate(pivot.expressions):
agg_cols = list(agg.find_all(exp.Column))
for col_index in range(i, len(pivot_columns), pivot_aggs_count):
pivot_column_mapping[pivot_columns[col_index].name] = agg_cols

for c in source_columns:
table = c.table
source = scope.sources.get(table)

# Look for a (non-UNPIVOT) pivot whose alias matches this column's table. It is needed by
# the second branch below; if no pivot is found, we fall through to the final else branch.
pivot = next(
(p for p in scope.pivots if p.alias_or_name == c.table and not p.unpivot), None
)

if isinstance(source, Scope):
reference_node_name = None
if source.scope_type == ScopeType.DERIVED_TABLE and table not in source_names:
reference_node_name = table
elif source.scope_type == ScopeType.CTE:
selected_node, _ = scope.selected_sources.get(table, (None, None))
reference_node_name = selected_node.name if selected_node else None

# The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
to_node(
c.name,
Expand All @@ -282,68 +298,45 @@ def to_node(
reference_node_name=reference_node_name,
trim_selects=trim_selects,
)
elif pivot and pivot.alias_or_name == c.table:
downstream_columns = []

elif pivot:
if isinstance(pivot.args.get("columns"), list):
# The source is a pivot operation, so we need to trace back to the aggregated column
pivot_column_mapping = {}
pivot_aggs_count = len(pivot.expressions)
columns = pivot.args["columns"]
columns_count = len(columns)
column_name = c.name
downstream_columns = []

if any(column_name == pivoted_column.name for pivoted_column in columns):
# The column is in the pivot, so we need to trace back to the aggregation that created it.
for i, agg in enumerate(pivot.expressions):
agg_cols = list(agg.find_all(exp.Column))
for col_index in range(i, columns_count, pivot_aggs_count):
# e.g. "pivot (sum(value) as value_sum, max(price)) for category in ('a' as cat_a, 'b')"
# For each aggregation function, the pivot creates a new column for each field in category combined with the aggfunc.
# So the columns parsed have this order: cat_a_value_sum, cat_a, b_value_sum, b.
# Because of this stepwise ordering, the aggfunc 'sum(value) as value_sum' belongs to the column indices 0, 2,
# and the aggfunc 'max(price)' without an alias belongs to the column indices 1, 3.
# Here, only the columns used in the aggregations are of interest in the lineage, so look up the pivot column name
# by index and map it to the columns used in the aggregation.
pivot_column_mapping[columns[col_index].name] = agg_cols

downstream_columns.extend(pivot_column_mapping[column_name])
column_name = c.name
if any(column_name == pivot_column.name for pivot_column in pivot_columns):
downstream_columns.extend(pivot_column_mapping[column_name])
else:
# The column is not in the pivot, so it must be an implicit column of the
# pivoted source -- adapt column to be from the implicit pivoted source.
downstream_columns.append(exp.column(c.this, table=pivot.parent.this))

for downstream_column in downstream_columns:
table = downstream_column.table
source = scope.sources.get(table)
if isinstance(source, Scope):
to_node(
downstream_column.name,
scope=source,
scope_name=table,
dialect=dialect,
upstream=node,
source_name=source_names.get(table) or source_name,
reference_node_name=reference_node_name,
trim_selects=trim_selects,
)
else:
# The column is not in the pivot, so it must be an implicit column of the pivoted source.
# Adapt the column so that it refers to the implicit pivoted source.
parent = t.cast(exp.Table, pivot.parent.copy())
parent.args["pivots"] = []
adapted_column = t.cast(exp.Column, c.copy())
adapted_column.args["table"] = parent.this
downstream_columns.append(adapted_column)

for downstream_column in downstream_columns:
table = downstream_column.table
source = scope.sources.get(table)
if isinstance(source, Scope):
to_node(
downstream_column.name,
scope=source,
scope_name=table,
dialect=dialect,
upstream=node,
source_name=source_names.get(table) or source_name,
reference_node_name=reference_node_name,
trim_selects=trim_selects,
)
else:
source = source or exp.Placeholder()
node.downstream.append(
Node(
name=downstream_column.sql(comments=False),
source=source,
expression=source,
)
source = source or exp.Placeholder()
node.downstream.append(
Node(
name=downstream_column.sql(comments=False),
source=source,
expression=source,
)
)
else:
# The source is not a scope and the column is not in any pivot - we've reached the end of the line. At this point, if a source is not found
# it means this column's lineage is unknown. This can happen if the definition of a source used in a query
# is not passed into the `sources` map.
# The source is not a scope and the column is not in any pivot - we've reached the end
# of the line. At this point, if a source is not found it means this column's lineage
# is unknown. This can happen if the definition of a source used in a query is not
# passed into the `sources` map.
source = source or exp.Placeholder()
node.downstream.append(
Node(name=c.sql(comments=False), source=source, expression=source)
Expand Down

0 comments on commit cd6e00f

Please sign in to comment.