You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by ar...@apache.org on 2024/03/08 00:44:12 UTC
(superset) 02/09: Table with Time Comparison:
This is an automated email from the ASF dual-hosted git repository.
arivero pushed a commit to branch table-time-comparison
in repository https://gitbox.apache.org/repos/asf/superset.git
commit c3d04b3fa870cd8338edbcd4888bf491e5fac57d
Author: Antonio Rivero <an...@gmail.com>
AuthorDate: Fri Mar 1 14:41:35 2024 +0100
Table with Time Comparison:
- Use a single query, with new properties added to QueryObject, to generate the comparison data instead of running two queries
- Use joins when generating the comparison query
- Add time comparison control to Table chart
- Render Time comparison metrics in Table chart
- Render a header with the column name above each group of 4 metric columns
- Modify useSticky to consider multiple header rows when computing the column widths
- Add tests for new query building function
---
.../superset-ui-core/src/query/types/Query.ts | 10 +
.../src/query/types/QueryResponse.ts | 1 +
.../plugin-chart-table/src/DataTable/DataTable.tsx | 45 ++-
.../src/DataTable/hooks/useSticky.tsx | 4 +-
.../plugins/plugin-chart-table/src/TableChart.tsx | 42 +++
.../plugins/plugin-chart-table/src/buildQuery.ts | 58 +++-
.../plugin-chart-table/src/controlPanel.tsx | 72 ++++
.../plugin-chart-table/src/transformProps.ts | 165 ++++++++-
.../plugins/plugin-chart-table/src/types.ts | 2 +
.../plugin-chart-table/src/utils/isEqualColumns.ts | 3 +-
.../plugins/plugin-chart-table/test/testData.ts | 1 +
superset/charts/schemas.py | 20 +-
superset/common/query_context_processor.py | 3 +
superset/common/query_object.py | 4 +
superset/connectors/sqla/models.py | 126 ++++++-
superset/constants.py | 1 +
tests/unit_tests/connectors/__init__.py | 16 +
tests/unit_tests/connectors/test_models.py | 383 +++++++++++++++++++++
tests/unit_tests/queries/query_object_test.py | 1 +
19 files changed, 942 insertions(+), 15 deletions(-)
diff --git a/superset-frontend/packages/superset-ui-core/src/query/types/Query.ts b/superset-frontend/packages/superset-ui-core/src/query/types/Query.ts
index 718f10514c..db3a090dd6 100644
--- a/superset-frontend/packages/superset-ui-core/src/query/types/Query.ts
+++ b/superset-frontend/packages/superset-ui-core/src/query/types/Query.ts
@@ -77,6 +77,13 @@ export type ResidualQueryObjectData = {
[key: string]: unknown;
};
+export type QueryObjectInstantTimeComparisonInfo = {
+ /** The range to use as comparison range */
+ range: string;
+ /** The custom filter value to use if range is Custom */
+ filter?: QueryObjectFilterClause;
+};
+
/**
* Query object directly compatible with the new chart data API.
* A stricter version of query form data.
@@ -149,6 +156,9 @@ export interface QueryObject
series_columns?: QueryFormColumn[];
series_limit?: number;
series_limit_metric?: Maybe<QueryFormMetric>;
+
+ /** Instant Time Comparison */
+ instant_time_comparison_info?: QueryObjectInstantTimeComparisonInfo;
}
export interface QueryContext {
diff --git a/superset-frontend/packages/superset-ui-core/src/query/types/QueryResponse.ts b/superset-frontend/packages/superset-ui-core/src/query/types/QueryResponse.ts
index 1705814df1..d910e9a778 100644
--- a/superset-frontend/packages/superset-ui-core/src/query/types/QueryResponse.ts
+++ b/superset-frontend/packages/superset-ui-core/src/query/types/QueryResponse.ts
@@ -78,6 +78,7 @@ export interface ChartDataResponseResult {
| 'timed_out';
from_dttm: number | null;
to_dttm: number | null;
+ instant_time_comparison_range: string | null;
}
export interface TimeseriesChartDataResponseResult
diff --git a/superset-frontend/plugins/plugin-chart-table/src/DataTable/DataTable.tsx b/superset-frontend/plugins/plugin-chart-table/src/DataTable/DataTable.tsx
index 6c5123806f..a0af54eb6a 100644
--- a/superset-frontend/plugins/plugin-chart-table/src/DataTable/DataTable.tsx
+++ b/superset-frontend/plugins/plugin-chart-table/src/DataTable/DataTable.tsx
@@ -67,6 +67,7 @@ export interface DataTableProps<D extends object> extends TableOptions<D> {
rowCount: number;
wrapperRef?: MutableRefObject<HTMLDivElement>;
onColumnOrderChange: () => void;
+ groupHeaderColumns?: Record<string, number[]>;
}
export interface RenderHTMLCellProps extends HTMLProps<HTMLTableCellElement> {
@@ -99,6 +100,7 @@ export default typedMemo(function DataTable<D extends object>({
serverPagination,
wrapperRef: userWrapperRef,
onColumnOrderChange,
+ groupHeaderColumns,
...moreUseTableOptions
}: DataTableProps<D>): JSX.Element {
const tableHooks: PluginHook<D>[] = [
@@ -248,14 +250,55 @@ export default typedMemo(function DataTable<D extends object>({
e.preventDefault();
};
+ const renderDynamicHeaders = () => {
+ // TODO: Make use of ColumnGroup to render the aditional headers
+ const headers: any = [];
+ let currentColumnIndex = 0;
+
+ Object.entries(groupHeaderColumns || {}).forEach(([key, value], index) => {
+ // Calculate the number of placeholder columns needed before the current header
+ const startPosition = value[0];
+ const colSpan = value.length;
+
+ // Add placeholder <th> for columns before this header
+ for (let i = currentColumnIndex; i < startPosition; i += 1) {
+ headers.push(
+ <th
+ key={`placeholder-${i}`}
+ style={{ borderBottom: 0 }}
+ aria-label={`Header-${i}`}
+ />,
+ );
+ }
+
+ // Add the current header <th>
+ headers.push(
+ <th key={`header-${key}`} colSpan={colSpan} style={{ borderBottom: 0 }}>
+ {key}
+ </th>,
+ );
+
+ // Update the current column index
+ currentColumnIndex = startPosition + colSpan;
+ });
+
+ return headers;
+ };
+
const renderTable = () => (
<table {...getTableProps({ className: tableClassName })}>
<thead>
+ {/* Render dynamic headers based on resultMap */}
+ {groupHeaderColumns ? <tr>{renderDynamicHeaders()}</tr> : null}
{headerGroups.map(headerGroup => {
const { key: headerGroupKey, ...headerGroupProps } =
headerGroup.getHeaderGroupProps();
return (
- <tr key={headerGroupKey || headerGroup.id} {...headerGroupProps}>
+ <tr
+ key={headerGroupKey || headerGroup.id}
+ {...headerGroupProps}
+ style={{ borderTop: 0 }}
+ >
{headerGroup.headers.map(column =>
column.render('Header', {
key: column.id,
diff --git a/superset-frontend/plugins/plugin-chart-table/src/DataTable/hooks/useSticky.tsx b/superset-frontend/plugins/plugin-chart-table/src/DataTable/hooks/useSticky.tsx
index ba3466bb40..1e56987486 100644
--- a/superset-frontend/plugins/plugin-chart-table/src/DataTable/hooks/useSticky.tsx
+++ b/superset-frontend/plugins/plugin-chart-table/src/DataTable/hooks/useSticky.tsx
@@ -181,7 +181,9 @@ function StickyWrap({
}
const fullTableHeight = (bodyThead.parentNode as HTMLTableElement)
.clientHeight;
- const ths = bodyThead.childNodes[0]
+ // instead of always using the first tr, we use the last one to support
+ // multi-level headers assuming the last one is the more detailed one
+ const ths = bodyThead.childNodes?.[bodyThead.childNodes?.length - 1 || 0]
.childNodes as NodeListOf<HTMLTableHeaderCellElement>;
const widths = Array.from(ths).map(
th => th.getBoundingClientRect()?.width || th.clientWidth,
diff --git a/superset-frontend/plugins/plugin-chart-table/src/TableChart.tsx b/superset-frontend/plugins/plugin-chart-table/src/TableChart.tsx
index 840020cad8..d4d5de970a 100644
--- a/superset-frontend/plugins/plugin-chart-table/src/TableChart.tsx
+++ b/superset-frontend/plugins/plugin-chart-table/src/TableChart.tsx
@@ -50,6 +50,7 @@ import {
tn,
} from '@superset-ui/core';
+import { isEmpty } from 'lodash';
import { DataColumnMeta, TableChartTransformedProps } from './types';
import DataTable, {
DataTableProps,
@@ -238,6 +239,7 @@ export default function TableChart<D extends DataRecord = DataRecord>(
allowRearrangeColumns = false,
onContextMenu,
emitCrossFilters,
+ enableTimeComparison,
} = props;
const timestampFormatter = useCallback(
value => getTimeFormatterForGranularity(timeGrain)(value),
@@ -413,6 +415,37 @@ export default function TableChart<D extends DataRecord = DataRecord>(
}
: undefined;
+ const comparisonLabels = [t('Main'), '#', '△', '%'];
+
+ const getHeaderColumns = (
+ columnsMeta: DataColumnMeta[],
+ enableTimeComparison?: boolean,
+ ) => {
+ // Maps each compared metric key to the indexes of its Main/#/△/% columns
+ const resultMap: Record<string, number[]> = {};
+ if (!enableTimeComparison) {
+ return resultMap;
+ }
+
+ columnsMeta.forEach((element, index) => {
+ // Only the expanded comparison columns belong to a header group
+ if (comparisonLabels.includes(element.label)) {
+ // Keys look like "<label> <metric key>" and the metric key may itself
+ // contain spaces, so keep everything after the first space
+ const separatorIndex = element.key.indexOf(' ');
+ const keyPortion = element.key.slice(separatorIndex + 1);
+
+ if (!resultMap[keyPortion]) {
+ resultMap[keyPortion] = [index];
+ } else {
+ resultMap[keyPortion].push(index);
+ }
+ }
+ });
+
+ return resultMap;
+ };
+
const getColumnConfigs = useCallback(
(column: DataColumnMeta, i: number): ColumnWithLooseAccessor<D> => {
const {
@@ -596,6 +629,7 @@ export default function TableChart<D extends DataRecord = DataRecord>(
style={{
...sharedStyle,
...style,
+ borderTop: 0,
}}
tabIndex={0}
onKeyDown={(e: React.KeyboardEvent<HTMLElement>) => {
@@ -670,6 +704,11 @@ export default function TableChart<D extends DataRecord = DataRecord>(
[columnsMeta, getColumnConfigs],
);
+ const groupHeaderColumns = useMemo(
+ () => getHeaderColumns(columnsMeta, enableTimeComparison),
+ [columnsMeta, enableTimeComparison],
+ );
+
const handleServerPaginationChange = useCallback(
(pageNumber: number, pageSize: number) => {
updateExternalFormData(setDataMask, pageNumber, pageSize);
@@ -734,6 +773,9 @@ export default function TableChart<D extends DataRecord = DataRecord>(
selectPageSize={pageSize !== null && SelectPageSize}
// not in use in Superset, but needed for unit tests
sticky={sticky}
+ groupHeaderColumns={
+ !isEmpty(groupHeaderColumns) ? groupHeaderColumns : undefined
+ }
/>
</Styles>
);
diff --git a/superset-frontend/plugins/plugin-chart-table/src/buildQuery.ts b/superset-frontend/plugins/plugin-chart-table/src/buildQuery.ts
index 69631a5f35..9e93c268a4 100644
--- a/superset-frontend/plugins/plugin-chart-table/src/buildQuery.ts
+++ b/superset-frontend/plugins/plugin-chart-table/src/buildQuery.ts
@@ -19,11 +19,17 @@
import {
AdhocColumn,
buildQueryContext,
+ buildQueryObject,
+ ComparisonTimeRangeType,
ensureIsArray,
+ FeatureFlag,
+ getComparisonInfo,
getMetricLabel,
+ isFeatureEnabled,
isPhysicalColumn,
QueryMode,
QueryObject,
+ QueryObjectFilterClause,
removeDuplicates,
} from '@superset-ui/core';
import { PostProcessingRule } from '@superset-ui/core/src/query/types/PostProcessing';
@@ -55,7 +61,12 @@ const buildQuery: BuildQuery<TableChartFormData> = (
percent_metrics: percentMetrics,
order_desc: orderDesc = false,
extra_form_data,
+ time_comparison: timeComparison,
+ enable_time_comparison,
} = formData;
+ const canUseTimeComparison =
+ enable_time_comparison &&
+ isFeatureEnabled(FeatureFlag.ChartPluginsExperimental);
const queryMode = getQueryMode(formData);
const sortByMetric = ensureIsArray(formData.timeseries_limit_metric)[0];
const time_grain_sqla =
@@ -69,6 +80,34 @@ const buildQuery: BuildQuery<TableChartFormData> = (
};
}
+ const addComparisonPercentMetrics = (metrics: string[]) =>
+ metrics.reduce((acc, metric) => {
+ const prevMetric = `prev_${metric}`;
+ return acc.concat([metric, prevMetric]);
+ }, [] as string[]);
+
+ const comparisonFormData = getComparisonInfo(
+ formDataCopy,
+ timeComparison,
+ extra_form_data,
+ );
+
+ const getFirstTemporalFilter = (
+ queryObject?: QueryObject,
+ ): QueryObjectFilterClause | undefined => {
+ const { filters = [] } = queryObject || {};
+ const timeFilterIndex: number =
+ filters?.findIndex(
+ filter => 'op' in filter && filter.op === 'TEMPORAL_RANGE',
+ ) ?? -1;
+
+ const timeFilter: QueryObjectFilterClause | undefined =
+ timeFilterIndex !== -1 && filters ? filters[timeFilterIndex] : undefined;
+ return timeFilter;
+ };
+ const comparisonQueryObject = buildQueryObject(comparisonFormData);
+ const firstTemporalFilter = getFirstTemporalFilter(comparisonQueryObject);
+
return buildQueryContext(formDataCopy, baseQueryObject => {
let { metrics, orderby = [], columns = [] } = baseQueryObject;
let postProcessing: PostProcessingRule[] = [];
@@ -85,8 +124,11 @@ const buildQuery: BuildQuery<TableChartFormData> = (
}
// add postprocessing for percent metrics only when in aggregation mode
if (percentMetrics && percentMetrics.length > 0) {
+ const percentMetricsLabelsWithTimeComparison = canUseTimeComparison
+ ? addComparisonPercentMetrics(percentMetrics.map(getMetricLabel))
+ : percentMetrics.map(getMetricLabel);
const percentMetricLabels = removeDuplicates(
- percentMetrics.map(getMetricLabel),
+ percentMetricsLabelsWithTimeComparison,
);
metrics = removeDuplicates(
metrics.concat(percentMetrics),
@@ -139,6 +181,20 @@ const buildQuery: BuildQuery<TableChartFormData> = (
...moreProps,
};
+ // Customize the query for time comparison
+ if (canUseTimeComparison) {
+ queryObject = {
+ ...queryObject,
+ instant_time_comparison_info: {
+ range: timeComparison,
+ filter:
+ timeComparison === ComparisonTimeRangeType.Custom
+ ? firstTemporalFilter
+ : undefined,
+ },
+ };
+ }
+
if (
formData.server_pagination &&
options?.extras?.cachedChanges?.[formData.slice_id] &&
diff --git a/superset-frontend/plugins/plugin-chart-table/src/controlPanel.tsx b/superset-frontend/plugins/plugin-chart-table/src/controlPanel.tsx
index ad39b504cb..c7710c81df 100644
--- a/superset-frontend/plugins/plugin-chart-table/src/controlPanel.tsx
+++ b/superset-frontend/plugins/plugin-chart-table/src/controlPanel.tsx
@@ -20,14 +20,18 @@
import React from 'react';
import {
ChartDataResponseResult,
+ ComparisonTimeRangeType,
ensureIsArray,
+ FeatureFlag,
GenericDataType,
isAdhocColumn,
+ isFeatureEnabled,
isPhysicalColumn,
QueryFormColumn,
QueryMode,
smartDateFormatter,
t,
+ validateTimeComparisonRangeValues,
} from '@superset-ui/core';
import {
ColumnOption,
@@ -257,6 +261,74 @@ const config: ControlPanelConfig = {
},
],
['adhoc_filters'],
+ [
+ {
+ name: 'enable_time_comparison',
+ config: {
+ type: 'CheckboxControl',
+ label: t('Enable Time Comparison'),
+ description: t('Enable time comparison (experimental feature)'),
+ default: false,
+ visibility: () =>
+ isFeatureEnabled(FeatureFlag.ChartPluginsExperimental),
+ },
+ },
+ ],
+ [
+ {
+ name: 'time_comparison',
+ config: {
+ type: 'SelectControl',
+ label: t('Range for Comparison'),
+ default: 'r',
+ choices: [
+ ['r', t('Inherit range from time filters')],
+ ['y', t('Year')],
+ ['m', t('Month')],
+ ['w', t('Week')],
+ ['c', t('Custom')],
+ ],
+ rerender: ['adhoc_custom'],
+ description: t(
+ 'Set the time range that will be used for the comparison metrics. ' +
+ 'For example, "Year" will compare to the same dates one year earlier. ' +
+ 'Use "Inherit range from time filters" to shift the comparison time range ' +
+ 'by the same length as your time range and use "Custom" to set a custom comparison range.',
+ ),
+ visibility: ({ controls }) =>
+ Boolean(controls?.enable_time_comparison?.value) &&
+ isFeatureEnabled(FeatureFlag.ChartPluginsExperimental),
+ },
+ },
+ ],
+ [
+ {
+ name: 'adhoc_custom',
+ config: {
+ ...sharedControls.adhoc_filters,
+ label: t('Filters for Comparison'),
+ description:
+ t('This only applies when selecting the Range for Comparison Type: Custom'),
+ visibility: ({ controls }) =>
+ Boolean(controls?.enable_time_comparison?.value) &&
+ controls?.time_comparison?.value === ComparisonTimeRangeType.Custom &&
+ isFeatureEnabled(FeatureFlag.ChartPluginsExperimental),
+ mapStateToProps: (
+ state: ControlPanelState,
+ controlState: ControlState,
+ ) => ({
+ ...(sharedControls.adhoc_filters.mapStateToProps?.(
+ state,
+ controlState,
+ ) || {}),
+ externalValidationErrors: validateTimeComparisonRangeValues(
+ state.controls?.time_comparison?.value,
+ controlState.value,
+ ),
+ }),
+ },
+ },
+ ],
[
{
name: 'timeseries_limit_metric',
diff --git a/superset-frontend/plugins/plugin-chart-table/src/transformProps.ts b/superset-frontend/plugins/plugin-chart-table/src/transformProps.ts
index 0a2a3449c6..e36684baff 100644
--- a/superset-frontend/plugins/plugin-chart-table/src/transformProps.ts
+++ b/superset-frontend/plugins/plugin-chart-table/src/transformProps.ts
@@ -21,14 +21,17 @@ import {
CurrencyFormatter,
DataRecord,
extractTimegrain,
+ FeatureFlag,
GenericDataType,
getMetricLabel,
getNumberFormatter,
getTimeFormatter,
getTimeFormatterForGranularity,
+ isFeatureEnabled,
NumberFormats,
QueryMode,
smartDateFormatter,
+ t,
TimeFormats,
TimeFormatter,
} from '@superset-ui/core';
@@ -48,6 +51,8 @@ import {
const { PERCENT_3_POINT } = NumberFormats;
const { DATABASE_DATETIME } = TimeFormats;
+const COMPARISON_PREFIX = 'prev_';
+
function isNumeric(key: string, data: DataRecord[] = []) {
return data.every(
x => x[key] === null || x[key] === undefined || typeof x[key] === 'number',
@@ -81,6 +86,88 @@ const processDataRecords = memoizeOne(function processDataRecords(
return data;
});
+const calculateDifferences = (
+ originalValue: number,
+ comparisonValue: number,
+) => {
+ const valueDifference = originalValue - comparisonValue;
+ let percentDifferenceNum;
+ if (!originalValue && !comparisonValue) {
+ percentDifferenceNum = 0;
+ } else if (!originalValue || !comparisonValue) {
+ percentDifferenceNum = originalValue ? 1 : -1;
+ } else {
+ percentDifferenceNum =
+ (originalValue - comparisonValue) / Math.abs(comparisonValue);
+ }
+ return { valueDifference, percentDifferenceNum };
+};
+
+const processComparisonTotals = (totals: DataRecord | undefined) => {
+ if (!totals) {
+ return totals;
+ }
+ const transformedTotals: DataRecord = {};
+ Object.keys(totals).forEach(key => {
+ if (totals[key] !== undefined && !key.includes(COMPARISON_PREFIX)) {
+ transformedTotals[`Main ${key}`] = totals[key];
+ transformedTotals[`# ${key}`] = totals[`${COMPARISON_PREFIX}${key}`];
+ const { valueDifference, percentDifferenceNum } = calculateDifferences(
+ totals[key] as number,
+ totals[`${COMPARISON_PREFIX}${key}`] as number,
+ );
+ transformedTotals[`△ ${key}`] = valueDifference;
+ transformedTotals[`% ${key}`] = percentDifferenceNum;
+ }
+ });
+ return transformedTotals;
+};
+
+const processComparisonDataRecords = memoizeOne(
+ function processComparisonDataRecords(
+ originalData: DataRecord[] | undefined,
+ originalColumns: DataColumnMeta[],
+ ) {
+ // Transform data
+ return originalData?.map(originalItem => {
+ const transformedItem: DataRecord = {};
+ originalColumns.forEach(origCol => {
+ if (
+ (origCol.isMetric || origCol.isPercentMetric) &&
+ !origCol.key.includes(COMPARISON_PREFIX) &&
+ origCol.isNumeric
+ ) {
+ const originalValue = originalItem[origCol.key] || 0;
+ const comparisonValue = origCol.isMetric
+ ? originalItem?.[`${COMPARISON_PREFIX}${origCol.key}`] || 0
+ : originalItem[`%${COMPARISON_PREFIX}${origCol.key.slice(1)}`] || 0;
+ const { valueDifference, percentDifferenceNum } =
+ calculateDifferences(
+ originalValue as number,
+ comparisonValue as number,
+ );
+
+ transformedItem[`Main ${origCol.key}`] = originalValue;
+ transformedItem[`# ${origCol.key}`] = comparisonValue;
+ transformedItem[`△ ${origCol.key}`] = valueDifference;
+ transformedItem[`% ${origCol.key}`] = percentDifferenceNum;
+ }
+ });
+
+ Object.keys(originalItem).forEach(key => {
+ const isMetricOrPercentMetric = originalColumns.some(
+ col => col.key === key && (col.isMetric || col.isPercentMetric),
+ );
+ if (!isMetricOrPercentMetric) {
+ transformedItem[key] = originalItem[key];
+ }
+ });
+
+ return transformedItem;
+ });
+ },
+);
+
const processColumns = memoizeOne(function processColumns(
props: TableChartProps,
) {
@@ -186,6 +273,55 @@ const processColumns = memoizeOne(function processColumns(
];
}, isEqualColumns);
+/**
+ * Expands each numeric metric / percent-metric column into the four
+ * comparison columns (Main, #, △, %) and drops the raw `prev_` columns.
+ * Non-metric columns (numeric or not) are kept as-is, matching the
+ * condition processComparisonDataRecords uses to build the row keys.
+ */
+const processComparisonColumns = (
+ columns: DataColumnMeta[],
+ props: TableChartProps,
+) => {
+ // Loop-invariant: read the formatting config once instead of per column
+ const {
+ datasource: { columnFormats },
+ rawFormData: { column_config: columnConfig = {} },
+ } = props;
+
+ return columns.flatMap(col => {
+ const isComparisonColumn = col.key.includes(COMPARISON_PREFIX);
+ const isMetricColumn = col.isMetric || col.isPercentMetric;
+
+ if (isMetricColumn && col.isNumeric && !isComparisonColumn) {
+ const config = columnConfig[col.key] || {};
+ const savedFormat = columnFormats?.[col.key];
+ const numberFormat = config.d3NumberFormat || savedFormat;
+ // Keys stay untranslated so they match the keys of the transformed
+ // data records and totals; only the visible label is translated.
+ return [
+ { ...col, label: t('Main'), key: `Main ${col.key}` },
+ { ...col, label: `#`, key: `# ${col.key}` },
+ { ...col, label: `△`, key: `△ ${col.key}` },
+ {
+ ...col,
+ formatter: getNumberFormatter(numberFormat || PERCENT_3_POINT),
+ label: `%`,
+ key: `% ${col.key}`,
+ },
+ ];
+ }
+
+ // Non-metric columns without the prefix pass through. Raw `prev_` columns
+ // are dropped (their values live in the comparison columns), as are
+ // non-numeric metrics, which the transformed records carry no value for.
+ if (!isMetricColumn && !isComparisonColumn) {
+ return [col];
+ }
+ return [];
+ });
+};
+
/**
* Automatically set page size based on number of cells.
*/
@@ -238,23 +374,35 @@ const transformProps = (
show_totals: showTotals,
conditional_formatting: conditionalFormatting,
allow_rearrange_columns: allowRearrangeColumns,
+ enable_time_comparison: enableTimeComparison = false,
} = formData;
+ const canUseTimeComparison =
+ enableTimeComparison &&
+ isFeatureEnabled(FeatureFlag.ChartPluginsExperimental);
const timeGrain = extractTimegrain(formData);
const [metrics, percentMetrics, columns] = processColumns(chartProps);
+ let comparisonColumns: DataColumnMeta[] = [];
+ if (canUseTimeComparison) {
+ comparisonColumns = processComparisonColumns(columns, chartProps);
+ }
let baseQuery;
let countQuery;
let totalQuery;
let rowCount;
+ const queriesDataWithoutComparisonQueries = queriesData.filter(
+ ({ instant_time_comparison_range }) => !instant_time_comparison_range,
+ );
if (serverPagination) {
- [baseQuery, countQuery, totalQuery] = queriesData;
+ [baseQuery, countQuery, totalQuery] = queriesDataWithoutComparisonQueries;
rowCount = (countQuery?.data?.[0]?.rowcount as number) ?? 0;
} else {
- [baseQuery, totalQuery] = queriesData;
+ [baseQuery, totalQuery] = queriesDataWithoutComparisonQueries;
rowCount = baseQuery?.rowcount ?? 0;
}
const data = processDataRecords(baseQuery?.data, columns);
+ const comparisonData = processComparisonDataRecords(baseQuery?.data, columns);
const totals =
showTotals && queryMode === QueryMode.Aggregate
? totalQuery?.data[0]
@@ -262,13 +410,19 @@ const transformProps = (
const columnColorFormatters =
getColorFormatters(conditionalFormatting, data) ?? defaultColorFormatters;
+ const comparisonTotals = processComparisonTotals(totals);
+
+ const passedData = canUseTimeComparison ? comparisonData || [] : data;
+ const passedTotals = canUseTimeComparison ? comparisonTotals : totals;
+ const passedColumns = canUseTimeComparison ? comparisonColumns : columns;
+
return {
height,
width,
isRawRecords: queryMode === QueryMode.Raw,
- data,
- totals,
- columns,
+ data: passedData,
+ totals: passedTotals,
+ columns: passedColumns,
serverPagination,
metrics,
percentMetrics,
@@ -292,6 +446,7 @@ const transformProps = (
timeGrain,
allowRearrangeColumns,
onContextMenu,
+ enableTimeComparison: canUseTimeComparison,
};
};
diff --git a/superset-frontend/plugins/plugin-chart-table/src/types.ts b/superset-frontend/plugins/plugin-chart-table/src/types.ts
index 02bae809fe..1806eddb1a 100644
--- a/superset-frontend/plugins/plugin-chart-table/src/types.ts
+++ b/superset-frontend/plugins/plugin-chart-table/src/types.ts
@@ -91,6 +91,7 @@ export type TableChartFormData = QueryFormData & {
time_grain_sqla?: TimeGranularity;
column_config?: Record<string, TableColumnConfig>;
allow_rearrange_columns?: boolean;
+ enable_time_comparison?: boolean;
};
export interface TableChartProps extends ChartProps {
@@ -135,6 +136,7 @@ export interface TableChartTransformedProps<D extends DataRecord = DataRecord> {
clientY: number,
filters?: ContextMenuFilters,
) => void;
+ enableTimeComparison?: boolean;
}
export default {};
diff --git a/superset-frontend/plugins/plugin-chart-table/src/utils/isEqualColumns.ts b/superset-frontend/plugins/plugin-chart-table/src/utils/isEqualColumns.ts
index 28731c73c2..8153ea856a 100644
--- a/superset-frontend/plugins/plugin-chart-table/src/utils/isEqualColumns.ts
+++ b/superset-frontend/plugins/plugin-chart-table/src/utils/isEqualColumns.ts
@@ -41,6 +41,7 @@ export default function isEqualColumns(
JSON.stringify(a.formData.extraFormData || null) ===
JSON.stringify(b.formData.extraFormData || null) &&
JSON.stringify(a.rawFormData.column_config || null) ===
- JSON.stringify(b.rawFormData.column_config || null)
+ JSON.stringify(b.rawFormData.column_config || null) &&
+ a.formData.enableTimeComparison === b.formData.enableTimeComparison
);
}
diff --git a/superset-frontend/plugins/plugin-chart-table/test/testData.ts b/superset-frontend/plugins/plugin-chart-table/test/testData.ts
index 24abc3381e..af2fbe5a65 100644
--- a/superset-frontend/plugins/plugin-chart-table/test/testData.ts
+++ b/superset-frontend/plugins/plugin-chart-table/test/testData.ts
@@ -84,6 +84,7 @@ const basicQueryResult: ChartDataResponseResult = {
status: 'success',
from_dttm: null,
to_dttm: null,
+ instant_time_comparison_range: null,
};
/**
diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index 611f7af597..34731af571 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -26,6 +26,7 @@ from marshmallow.validate import Length, Range
from superset import app
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
+from superset.constants import InstantTimeComparison
from superset.db_engine_specs.base import builtin_time_grains
from superset.tags.models import TagType
from superset.utils import pandas_postprocessing, schema as utils
@@ -948,6 +949,14 @@ class ChartDataFilterSchema(Schema):
)
+class InstantTimeComparisonInfoSchema(Schema):
+ range = fields.String(
+ metadata={"description": "Type of time comparison to be used"},
+ validate=validate.OneOf(choices=[ran.value for ran in InstantTimeComparison]),
+ )
+ filter = fields.Nested(ChartDataFilterSchema, allow_none=True)
+
+
class ChartDataExtrasSchema(Schema):
relative_start = fields.String(
metadata={
@@ -994,7 +1003,8 @@ class ChartDataExtrasSchema(Schema):
metadata={
"description": "This is only set using the new time comparison controls "
"that is made available in some plugins behind the experimental "
- "feature flag."
+ "feature flag. If passed as extra, the time range will be changed inside this"
+ " query object."
},
allow_none=True,
)
@@ -1350,6 +1360,14 @@ class ChartDataQueryObjectSchema(Schema):
fields.String(),
allow_none=True,
)
+ instant_time_comparison_info = fields.Nested(
+ InstantTimeComparisonInfoSchema,
+ metadata={
+ "description": "Extra parameters to use instant time comparison"
+ " with JOINs using a single query"
+ },
+ allow_none=True,
+ )
class ChartDataQueryContextSchema(Schema):
diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py
index d8b5bea4bb..77f84989b1 100644
--- a/superset/common/query_context_processor.py
+++ b/superset/common/query_context_processor.py
@@ -197,6 +197,9 @@ class QueryContextProcessor:
"from_dttm": query_obj.from_dttm,
"to_dttm": query_obj.to_dttm,
"label_map": label_map,
+ "instant_time_comparison_range": query_obj.extras.get(
+ "instant_time_comparison_range"
+ ),
}
def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> str | None:
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index 5109c465e0..77f3a08ce8 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -107,6 +107,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
time_shift: str | None
time_range: str | None
to_dttm: datetime | None
+ instant_time_comparison_info: dict[str, Any] | None
def __init__( # pylint: disable=too-many-locals
self,
@@ -132,6 +133,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
series_limit_metric: Metric | None = None,
time_range: str | None = None,
time_shift: str | None = None,
+ instant_time_comparison_info: dict[str, Any] | None = None,
**kwargs: Any,
):
self._set_annotation_layers(annotation_layers)
@@ -161,6 +163,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
self.time_offsets = kwargs.get("time_offsets", [])
self.inner_from_dttm = kwargs.get("inner_from_dttm")
self.inner_to_dttm = kwargs.get("inner_to_dttm")
+ self.instant_time_comparison_info = instant_time_comparison_info
self._rename_deprecated_fields(kwargs)
self._move_deprecated_extra_fields(kwargs)
@@ -335,6 +338,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
"series_limit_metric": self.series_limit_metric,
"to_dttm": self.to_dttm,
"time_shift": self.time_shift,
+ "instant_time_comparison_info": self.instant_time_comparison_info,
}
return query_object_dict
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 089b9c2f28..b6e14ce62e 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -18,6 +18,7 @@
from __future__ import annotations
import builtins
+import copy
import dataclasses
import json
import logging
@@ -81,7 +82,7 @@ from superset.connectors.sqla.utils import (
get_physical_table_metadata,
get_virtual_table_metadata,
)
-from superset.constants import EMPTY_STRING, NULL_STRING
+from superset.constants import EMPTY_STRING, InstantTimeComparison, NULL_STRING
from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression
from superset.exceptions import (
ColumnNotFoundException,
@@ -105,6 +106,7 @@ from superset.models.helpers import (
ImportExportMixin,
QueryResult,
QueryStringExtended,
+ SqlaQuery,
validate_adhoc_subquery,
)
from superset.models.slice import Slice
@@ -120,7 +122,7 @@ from superset.superset_typing import (
)
from superset.utils import core as utils
from superset.utils.backports import StrEnum
-from superset.utils.core import GenericDataType, MediumText
+from superset.utils.core import FilterOperator, GenericDataType, MediumText
config = app.config
metadata = Model.metadata # pylint: disable=no-member
@@ -1413,24 +1415,138 @@ class SqlaTable(
def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
return get_template_processor(table=self, database=self.database, **kwargs)
+    def extract_column_names(self, final_selected_columns: Any) -> list[str]:
+        """
+        Return the output names of the given selectable column elements.
+
+        For each element its ``name`` is used when set (labelled elements keep
+        their alias there); otherwise its ``key``. Elements exposing neither
+        (or only empty values) are skipped.
+
+        :param final_selected_columns: iterable of SQLAlchemy column elements,
+            e.g. the labelled columns of the time-comparison query
+        :returns: the column names, in input order
+        """
+        column_names = []
+        for selected_col in final_selected_columns:
+            # Prefer ``name``; fall back to ``key`` when ``name`` is missing or
+            # empty instead of discarding the element.
+            column_name = getattr(selected_col, "name", None) or getattr(
+                selected_col, "key", None
+            )
+            if column_name:
+                column_names.append(column_name)
+        return column_names
+
+    def process_time_compare_join(  # pylint: disable=too-many-locals
+        self,
+        query_obj: QueryObjectDict,
+        sqlaq: SqlaQuery,
+        mutate: bool,
+        instant_time_comparison_info: dict[str, Any],
+    ) -> tuple[str, list[str]]:
+        """
+        Build the SQL joining the main query with its time-comparison query.
+
+        The already-built main query (``sqlaq``) becomes the ``query_a_results``
+        CTE. A second query is built from a copy of ``query_obj`` without row
+        limit/offset and with a shifted time range: for
+        ``InstantTimeComparison.CUSTOM`` the first ``TEMPORAL_RANGE`` filter is
+        replaced by ``instant_time_comparison_info["filter"]``; for any other
+        range the range code is handed to ``get_sqla_query`` through
+        ``extras["instant_time_comparison_range"]``. Every column of the second
+        query that is not a dimension column is selected again with a ``prev_``
+        prefix, and the two queries are joined on the dimension columns (or on
+        a constant-true condition when there are none, or for row-count queries).
+
+        :param query_obj: query object the main query was built from
+        :param sqlaq: the main (current period) query
+        :param mutate: whether to pass the final SQL through
+            ``mutate_query_from_config``
+        :param instant_time_comparison_info: dict with a ``range`` key and, for
+            custom ranges, a ``filter`` key holding the comparison filter
+        :returns: the final SQL and the labels of its selected columns
+        """
+        query_obj_clone = copy.copy(query_obj)
+        # The comparison query is not paginated, so every row of the (paginated)
+        # main query can find its counterpart in the join.
+        query_obj_clone["row_limit"] = None
+        query_obj_clone["row_offset"] = None
+        instant_time_comparison_range = instant_time_comparison_info.get("range")
+        if instant_time_comparison_range == InstantTimeComparison.CUSTOM:
+            custom_filter = instant_time_comparison_info.get("filter", {})
+            filters = query_obj_clone.get("filter") or []
+            # Keep the filter dicts themselves (not just their column names) so
+            # non-temporal filters are carried over unchanged.
+            temporal_filters = [
+                flt
+                for flt in filters
+                if flt.get("op") == FilterOperator.TEMPORAL_RANGE
+            ]
+            non_temporal_filters = [
+                flt
+                for flt in filters
+                if flt.get("op") != FilterOperator.TEMPORAL_RANGE
+            ]
+            if temporal_filters:
+                # Swap the first temporal filter for the custom comparison range
+                temporal_filters[0] = custom_filter
+            query_obj_clone["filter"] = temporal_filters + non_temporal_filters
+        else:
+            query_obj_clone["extras"] = {
+                **(query_obj_clone.get("extras") or {}),
+                "instant_time_comparison_range": instant_time_comparison_range,
+            }
+        sqlaq_2 = self.get_sqla_query(**query_obj_clone)
+        # Dimension column names; adhoc columns are dicts, so resolve them to
+        # their labels before hashing them or looking them up in ``.c``.
+        join_columns = [
+            utils.get_column_name(col) for col in query_obj_clone.get("columns") or []
+        ]
+        sqla_query_a = sqlaq.sqla_query
+        sqla_query_b = sqlaq_2.sqla_query
+        sqla_query_b_subquery = sqla_query_b.subquery()
+        query_a_cte = sqla_query_a.cte("query_a_results")
+        column_names_a = [column.key for column in sqla_query_a.c]
+        exclude_columns_b = set(join_columns)
+        selected_columns_a = [query_a_cte.c[col].label(col) for col in column_names_a]
+        # Comparison-period values are exposed with a "prev_" prefix; dimension
+        # columns are only taken from query A.
+        selected_columns_b = [
+            sqla_query_b_subquery.c[col].label(f"prev_{col}")
+            for col in sqla_query_b_subquery.c.keys()
+            if col not in exclude_columns_b
+        ]
+        final_selected_columns = selected_columns_a + selected_columns_b
+        if join_columns and not query_obj_clone.get("is_rowcount"):
+            # NOTE(review): this is an inner join, so main-query rows without a
+            # comparison-period match are dropped; confirm this is intended.
+            join_conditions = [
+                sqla_query_b_subquery.c[col] == query_a_cte.c[col]
+                for col in join_columns
+                if col in sqla_query_b_subquery.c and col in query_a_cte.c
+            ]
+            final_query = sa.select(*final_selected_columns).select_from(
+                sqla_query_b_subquery.join(query_a_cte, sa.and_(*join_conditions))
+            )
+        else:
+            # No dimensions (or a row-count query): each side is expected to
+            # yield a single row, so a constant-true condition pairs them up.
+            final_query = sa.select(*final_selected_columns).select_from(
+                sqla_query_b_subquery.join(
+                    query_a_cte, sa.literal(True) == sa.literal(True)
+                )
+            )
+        final_query_sql = self.database.compile_sqla_query(final_query)
+        final_query_sql = self._apply_cte(final_query_sql, sqlaq.cte)
+        final_query_sql = sqlparse.format(final_query_sql, reindent=True)
+        if mutate:
+            final_query_sql = self.mutate_query_from_config(final_query_sql)
+
+        labels_expected = self.extract_column_names(final_selected_columns)
+
+        return final_query_sql, labels_expected
+
def get_query_str_extended(
self,
query_obj: QueryObjectDict,
mutate: bool = True,
) -> QueryStringExtended:
+ """
+ Compile ``query_obj`` into SQL plus the metadata needed to run it.
+
+ If ``query_obj`` carries ``instant_time_comparison_info`` and the
+ CHART_PLUGINS_EXPERIMENTAL feature flag is enabled, the returned SQL is the
+ time-comparison join built by ``process_time_compare_join`` and
+ ``labels_expected`` lists that join's columns; otherwise the plain query
+ and its own expected labels are returned.
+ """
- sqlaq = self.get_sqla_query(**query_obj)
+ # Shallow-copy so popping the comparison info below does not mutate the
+ # caller's query_obj; the key is removed before **-expanding the clone
+ # into get_sqla_query() (presumably not an accepted kwarg there).
+ query_obj_clone = copy.copy(query_obj)
+ instant_time_comparison_info = query_obj.get("instant_time_comparison_info")
+ query_obj_clone.pop("instant_time_comparison_info", None)
+ sqlaq = self.get_sqla_query(**query_obj_clone)
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
sql = self._apply_cte(sql, sqlaq.cte)
sql = sqlparse.format(sql, reindent=True)
+
+ # NOTE(review): when the comparison join is used below, this mutated base
+ # SQL is discarded; the join SQL is mutated separately inside
+ # process_time_compare_join.
if mutate:
sql = self.mutate_query_from_config(sql)
+
+ # The comparison join is only built behind the experimental flag.
+ if (
+ is_feature_enabled("CHART_PLUGINS_EXPERIMENTAL")
+ and instant_time_comparison_info
+ ):
+ (
+ final_query_sql,
+ labels_expected,
+ ) = self.process_time_compare_join(
+ query_obj_clone, sqlaq, mutate, instant_time_comparison_info
+ )
+ else:
+ final_query_sql = sql
+ labels_expected = sqlaq.labels_expected
+
return QueryStringExtended(
applied_template_filters=sqlaq.applied_template_filters,
applied_filter_columns=sqlaq.applied_filter_columns,
rejected_filter_columns=sqlaq.rejected_filter_columns,
- labels_expected=sqlaq.labels_expected,
+ labels_expected=labels_expected,
prequeries=sqlaq.prequeries,
- sql=sql,
+ # final_query_sql is assigned in both branches above; the fallback to
+ # the base sql only applies if it comes back empty.
+ sql=final_query_sql if final_query_sql else sql,
)
def get_query_str(self, query_obj: QueryObjectDict) -> str:
diff --git a/superset/constants.py b/superset/constants.py
index bf4e7717d5..9af8870e2d 100644
--- a/superset/constants.py
+++ b/superset/constants.py
@@ -48,6 +48,7 @@ class InstantTimeComparison(StrEnum):
YEAR = "y"
MONTH = "m"
WEEK = "w"
+ CUSTOM = "c"
class RouteMethod: # pylint: disable=too-few-public-methods
diff --git a/tests/unit_tests/connectors/__init__.py b/tests/unit_tests/connectors/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/unit_tests/connectors/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/unit_tests/connectors/test_models.py b/tests/unit_tests/connectors/test_models.py
new file mode 100644
index 0000000000..cf179c9dfa
--- /dev/null
+++ b/tests/unit_tests/connectors/test_models.py
@@ -0,0 +1,383 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import datetime
+
+from sqlalchemy.orm.session import Session
+
+from superset import db
+from tests.unit_tests.conftest import with_feature_flags
+
+
+class TestInstantTimeComparisonQueryGeneration:
+    """
+    Tests for the SQL produced by ``SqlaTable.get_query_str_extended`` when the
+    query object carries ``instant_time_comparison_info``.
+
+    Test methods take ``self`` explicitly so that pytest resolves ``session`` as
+    a fixture instead of binding the test-class instance to that parameter.
+    """
+
+    @staticmethod
+    def base_setup(session: Session):
+        """
+        Create and flush ``my_schema.my_table`` (columns ds/gender/name/state
+        and two metrics) backed by a ``sqlite://`` Database, and return it.
+
+        ``session`` is accepted for symmetry; the table is stored through
+        ``db.session``.
+        """
+        from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
+        from superset.models.core import Database
+
+        engine = db.session.get_bind()
+        SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member
+
+        table = SqlaTable(
+            table_name="my_table",
+            schema="my_schema",
+            database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
+        )
+
+        # Common columns
+        columns = [
+            {"column_name": "ds", "type": "DATETIME"},
+            {"column_name": "gender", "type": "VARCHAR(255)"},
+            {"column_name": "name", "type": "VARCHAR(255)"},
+            {"column_name": "state", "type": "VARCHAR(255)"},
+        ]
+
+        # Add columns to the table
+        for col in columns:
+            TableColumn(column_name=col["column_name"], type=col["type"], table=table)
+
+        # Common metrics
+        metrics = [
+            {"metric_name": "count", "expression": "count(*)"},
+            {"metric_name": "sum_sum", "expression": "SUM"},
+        ]
+
+        # Add metrics to the table
+        for metric in metrics:
+            SqlMetric(
+                metric_name=metric["metric_name"],
+                expression=metric["expression"],
+                table=table,
+            )
+
+        db.session.add(table)
+        db.session.flush()
+
+        return table
+
+    @staticmethod
+    def generate_base_query_obj():
+        """
+        Return a fresh query object: two SUM metrics grouped by ``name`` over
+        1984-01-01 : 2024-02-14, limited to 10 rows, compared one year back.
+        """
+        return {
+            "apply_fetch_values_predicate": False,
+            "columns": ["name"],
+            "extras": {"having": "", "where": ""},
+            "filter": [
+                {"op": "TEMPORAL_RANGE", "val": "1984-01-01 : 2024-02-14", "col": "ds"}
+            ],
+            "from_dttm": datetime.datetime(1984, 1, 1, 0, 0),
+            "granularity": None,
+            "inner_from_dttm": None,
+            "inner_to_dttm": None,
+            "is_rowcount": False,
+            "is_timeseries": False,
+            "order_desc": True,
+            "orderby": [("SUM(num_boys)", False)],
+            "row_limit": 10,
+            "row_offset": 0,
+            "series_columns": [],
+            "series_limit": 0,
+            "series_limit_metric": None,
+            "to_dttm": datetime.datetime(2024, 2, 14, 0, 0),
+            "time_shift": None,
+            "metrics": [
+                {
+                    "aggregate": "SUM",
+                    "column": {
+                        "column_name": "num_boys",
+                        "type": "BIGINT",
+                        "filterable": True,
+                        "groupby": True,
+                        "id": 334,
+                        "is_certified": False,
+                        "is_dttm": False,
+                        "type_generic": 0,
+                    },
+                    "datasourceWarning": False,
+                    "expressionType": "SIMPLE",
+                    "hasCustomLabel": False,
+                    "label": "SUM(num_boys)",
+                    "optionName": "metric_gzp6eq9g1lc_d8o0mj0mhq4",
+                    "sqlExpression": None,
+                },
+                {
+                    "aggregate": "SUM",
+                    "column": {
+                        "column_name": "num_girls",
+                        "type": "BIGINT",
+                        "filterable": True,
+                        "groupby": True,  # set to False by tests without columns
+                        "id": 335,
+                        "is_certified": False,
+                        "is_dttm": False,
+                        "type_generic": 0,
+                    },
+                    "datasourceWarning": False,
+                    "expressionType": "SIMPLE",
+                    "hasCustomLabel": False,
+                    "label": "SUM(num_girls)",
+                    "optionName": "metric_5gyhtmyfw1t_d42py86jpco",
+                    "sqlExpression": None,
+                },
+            ],
+            "instant_time_comparison_info": {
+                "range": "y",
+            },
+        }
+
+    @staticmethod
+    def assert_same_sql(actual: str, expected: str) -> None:
+        """Assert two SQL strings are equal ignoring whitespace and letter case."""
+        assert " ".join(actual.split()).lower() == " ".join(expected.split()).lower()
+
+    @with_feature_flags(CHART_PLUGINS_EXPERIMENTAL=True)
+    def test_creates_time_comparison_query(self, session: Session):
+        table = self.base_setup(session)
+        query_obj = self.generate_base_query_obj()
+        result = table.get_query_str_extended(query_obj)
+        expected_str = """
+        WITH query_a_results AS
+          (SELECT name AS name,
+                  sum(num_boys) AS "SUM(num_boys)",
+                  sum(num_girls) AS "SUM(num_girls)"
+           FROM my_schema.my_table
+           WHERE ds >= '1984-01-01 00:00:00'
+             AND ds < '2024-02-14 00:00:00'
+           GROUP BY name
+           ORDER BY "SUM(num_boys)" DESC
+           LIMIT 10
+           OFFSET 0)
+        SELECT query_a_results.name AS name,
+               query_a_results."SUM(num_boys)" AS "SUM(num_boys)",
+               query_a_results."SUM(num_girls)" AS "SUM(num_girls)",
+               anon_1."SUM(num_boys)" AS "prev_SUM(num_boys)",
+               anon_1."SUM(num_girls)" AS "prev_SUM(num_girls)"
+        FROM
+          (SELECT name AS name,
+                  sum(num_boys) AS "SUM(num_boys)",
+                  sum(num_girls) AS "SUM(num_girls)"
+           FROM my_schema.my_table
+           WHERE ds >= '1983-01-01 00:00:00'
+             AND ds < '2023-02-14 00:00:00'
+           GROUP BY name
+           ORDER BY "SUM(num_boys)" DESC) AS anon_1
+        JOIN query_a_results ON anon_1.name = query_a_results.name
+        """
+        assert table.id == 1
+        self.assert_same_sql(result.sql, expected_str)
+
+    @with_feature_flags(CHART_PLUGINS_EXPERIMENTAL=True)
+    def test_creates_time_comparison_query_no_columns(self, session: Session):
+        table = self.base_setup(session)
+        query_obj = self.generate_base_query_obj()
+        query_obj["columns"] = []
+        query_obj["metrics"][0]["column"]["groupby"] = False
+        query_obj["metrics"][1]["column"]["groupby"] = False
+
+        result = table.get_query_str_extended(query_obj)
+        expected_str = """
+        WITH query_a_results AS
+          (SELECT sum(num_boys) AS "SUM(num_boys)",
+                  sum(num_girls) AS "SUM(num_girls)"
+           FROM my_schema.my_table
+           WHERE ds >= '1984-01-01 00:00:00'
+             AND ds < '2024-02-14 00:00:00'
+           ORDER BY "SUM(num_boys)" DESC
+           LIMIT 10
+           OFFSET 0)
+        SELECT query_a_results."SUM(num_boys)" AS "SUM(num_boys)",
+               query_a_results."SUM(num_girls)" AS "SUM(num_girls)",
+               anon_1."SUM(num_boys)" AS "prev_SUM(num_boys)",
+               anon_1."SUM(num_girls)" AS "prev_SUM(num_girls)"
+        FROM
+          (SELECT sum(num_boys) AS "SUM(num_boys)",
+                  sum(num_girls) AS "SUM(num_girls)"
+           FROM my_schema.my_table
+           WHERE ds >= '1983-01-01 00:00:00'
+             AND ds < '2023-02-14 00:00:00'
+           ORDER BY "SUM(num_boys)" DESC) AS anon_1
+        JOIN query_a_results ON 1 = 1
+        """
+        assert table.id == 1
+        self.assert_same_sql(result.sql, expected_str)
+
+    @with_feature_flags(CHART_PLUGINS_EXPERIMENTAL=True)
+    def test_creates_time_comparison_rowcount_query(self, session: Session):
+        table = self.base_setup(session)
+        query_obj = self.generate_base_query_obj()
+        query_obj["is_rowcount"] = True
+        result = table.get_query_str_extended(query_obj)
+        expected_str = """
+        WITH query_a_results AS
+          (SELECT COUNT(*) AS rowcount
+           FROM
+             (SELECT name AS name,
+                     sum(num_boys) AS "SUM(num_boys)",
+                     sum(num_girls) AS "SUM(num_girls)"
+              FROM my_schema.my_table
+              WHERE ds >= '1984-01-01 00:00:00'
+                AND ds < '2024-02-14 00:00:00'
+              GROUP BY name
+              ORDER BY "SUM(num_boys)" DESC
+              LIMIT 10
+              OFFSET 0) AS rowcount_qry)
+        SELECT query_a_results.rowcount AS rowcount,
+               anon_1.rowcount AS prev_rowcount
+        FROM
+          (SELECT COUNT(*) AS rowcount
+           FROM
+             (SELECT name AS name,
+                     sum(num_boys) AS "SUM(num_boys)",
+                     sum(num_girls) AS "SUM(num_girls)"
+              FROM my_schema.my_table
+              WHERE ds >= '1983-01-01 00:00:00'
+                AND ds < '2023-02-14 00:00:00'
+              GROUP BY name
+              ORDER BY "SUM(num_boys)" DESC) AS rowcount_qry) AS anon_1
+        JOIN query_a_results ON 1 = 1
+        """
+        assert table.id == 1
+        self.assert_same_sql(result.sql, expected_str)
+
+    @with_feature_flags(CHART_PLUGINS_EXPERIMENTAL=True)
+    def test_creates_query_without_time_comparison(self, session: Session):
+        table = self.base_setup(session)
+        query_obj = self.generate_base_query_obj()
+        query_obj["instant_time_comparison_info"] = None
+        result = table.get_query_str_extended(query_obj)
+        expected_str = """
+        SELECT name AS name,
+               sum(num_boys) AS "SUM(num_boys)",
+               sum(num_girls) AS "SUM(num_girls)"
+        FROM my_schema.my_table
+        WHERE ds >= '1984-01-01 00:00:00'
+          AND ds < '2024-02-14 00:00:00'
+        GROUP BY name
+        ORDER BY "SUM(num_boys)" DESC
+        LIMIT 10
+        OFFSET 0
+        """
+        assert table.id == 1
+        self.assert_same_sql(result.sql, expected_str)
+
+    @with_feature_flags(CHART_PLUGINS_EXPERIMENTAL=True)
+    def test_creates_time_comparison_query_custom_filters(self, session: Session):
+        table = self.base_setup(session)
+        query_obj = self.generate_base_query_obj()
+        query_obj["instant_time_comparison_info"] = {
+            "range": "c",
+            "filter": {
+                "op": "TEMPORAL_RANGE",
+                "val": "1900-01-01 : 1950-02-14",
+                "col": "ds",
+            },
+        }
+        result = table.get_query_str_extended(query_obj)
+        expected_str = """
+        WITH query_a_results AS
+          (SELECT name AS name,
+                  sum(num_boys) AS "SUM(num_boys)",
+                  sum(num_girls) AS "SUM(num_girls)"
+           FROM my_schema.my_table
+           WHERE ds >= '1984-01-01 00:00:00'
+             AND ds < '2024-02-14 00:00:00'
+           GROUP BY name
+           ORDER BY "SUM(num_boys)" DESC
+           LIMIT 10
+           OFFSET 0)
+        SELECT query_a_results.name AS name,
+               query_a_results."SUM(num_boys)" AS "SUM(num_boys)",
+               query_a_results."SUM(num_girls)" AS "SUM(num_girls)",
+               anon_1."SUM(num_boys)" AS "prev_SUM(num_boys)",
+               anon_1."SUM(num_girls)" AS "prev_SUM(num_girls)"
+        FROM
+          (SELECT name AS name,
+                  sum(num_boys) AS "SUM(num_boys)",
+                  sum(num_girls) AS "SUM(num_girls)"
+           FROM my_schema.my_table
+           WHERE ds >= '1900-01-01 00:00:00'
+             AND ds < '1950-02-14 00:00:00'
+           GROUP BY name
+           ORDER BY "SUM(num_boys)" DESC) AS anon_1
+        JOIN query_a_results ON anon_1.name = query_a_results.name
+        """
+        assert table.id == 1
+        self.assert_same_sql(result.sql, expected_str)
+
+    @with_feature_flags(CHART_PLUGINS_EXPERIMENTAL=True)
+    def test_creates_time_comparison_query_paginated(self, session: Session):
+        table = self.base_setup(session)
+        query_obj = self.generate_base_query_obj()
+        query_obj["row_offset"] = 20
+        result = table.get_query_str_extended(query_obj)
+        expected_str = """
+        WITH query_a_results AS
+          (SELECT name AS name,
+                  sum(num_boys) AS "SUM(num_boys)",
+                  sum(num_girls) AS "SUM(num_girls)"
+           FROM my_schema.my_table
+           WHERE ds >= '1984-01-01 00:00:00'
+             AND ds < '2024-02-14 00:00:00'
+           GROUP BY name
+           ORDER BY "SUM(num_boys)" DESC
+           LIMIT 10
+           OFFSET 20)
+        SELECT query_a_results.name AS name,
+               query_a_results."SUM(num_boys)" AS "SUM(num_boys)",
+               query_a_results."SUM(num_girls)" AS "SUM(num_girls)",
+               anon_1."SUM(num_boys)" AS "prev_SUM(num_boys)",
+               anon_1."SUM(num_girls)" AS "prev_SUM(num_girls)"
+        FROM
+          (SELECT name AS name,
+                  sum(num_boys) AS "SUM(num_boys)",
+                  sum(num_girls) AS "SUM(num_girls)"
+           FROM my_schema.my_table
+           WHERE ds >= '1983-01-01 00:00:00'
+             AND ds < '2023-02-14 00:00:00'
+           GROUP BY name
+           ORDER BY "SUM(num_boys)" DESC) AS anon_1
+        JOIN query_a_results ON anon_1.name = query_a_results.name
+        """
+        assert table.id == 1
+        self.assert_same_sql(result.sql, expected_str)
+
+    @with_feature_flags(CHART_PLUGINS_EXPERIMENTAL=False)
+    def test_ignore_if_ff_off(self, session: Session):
+        table = self.base_setup(session)
+        query_obj = self.generate_base_query_obj()
+        result = table.get_query_str_extended(query_obj)
+        expected_str = """
+        SELECT name AS name,
+               sum(num_boys) AS "SUM(num_boys)",
+               sum(num_girls) AS "SUM(num_girls)"
+        FROM my_schema.my_table
+        WHERE ds >= '1984-01-01 00:00:00'
+          AND ds < '2024-02-14 00:00:00'
+        GROUP BY name
+        ORDER BY "SUM(num_boys)" DESC
+        LIMIT 10
+        OFFSET 0
+        """
+        assert table.id == 1
+        self.assert_same_sql(result.sql, expected_str)
diff --git a/tests/unit_tests/queries/query_object_test.py b/tests/unit_tests/queries/query_object_test.py
index 81a654653f..f90ab8255d 100644
--- a/tests/unit_tests/queries/query_object_test.py
+++ b/tests/unit_tests/queries/query_object_test.py
@@ -47,6 +47,7 @@ def test_default_query_object_to_dict():
"granularity": None,
"inner_from_dttm": None,
"inner_to_dttm": None,
+ "instant_time_comparison_info": None,
"is_rowcount": False,
"is_timeseries": False,
"metrics": None,