Index: data/histograms/histogram_tpcds.parquet ================================================================== --- data/histograms/histogram_tpcds.parquet +++ data/histograms/histogram_tpcds.parquet cannot compute difference between binary files ADDED params_config/search_params/tpcds.toml Index: params_config/search_params/tpcds.toml ================================================================== --- /dev/null +++ params_config/search_params/tpcds.toml @@ -0,0 +1,16 @@ +dataset = "TPCDS" +dev = false +max_hops = [1,2,4] +extra_predicates = [1,3,5] +row_retention_probability = [0.2, 0.3, 0.4, 0.6, 0.8, 0.85, 0.9, 1.0] +unique_joins = true +max_queries_per_fact_table = 10 +max_queries_per_signature = 2 +keep_edge_probability = 0.2 +equality_lower_bound_probability = [0,0.1] +extra_values_for_in = 3 + +[operator_weights] +operator_in = 1 +operator_range = 3 +operator_equal = 3 ADDED params_config/search_params/tpcds_dev.toml Index: params_config/search_params/tpcds_dev.toml ================================================================== --- /dev/null +++ params_config/search_params/tpcds_dev.toml @@ -0,0 +1,16 @@ +dataset = "TPCDS" +dev = true +max_hops = [1] +extra_predicates = [5] +row_retention_probability = [0.2, 0.9] +unique_joins = true +max_queries_per_fact_table = 1 +max_queries_per_signature = 2 +keep_edge_probability = 0.2 +equality_lower_bound_probability = [0,0.1] +extra_values_for_in = 3 + +[operator_weights] +operator_in = 1 +operator_range = 3 +operator_equal = 3 ADDED params_config/snowflake/tpcds.toml Index: params_config/snowflake/tpcds.toml ================================================================== --- /dev/null +++ params_config/snowflake/tpcds.toml @@ -0,0 +1,18 @@ +dataset = "TPCDS" +max_hops = 3 +max_queries_per_fact_table = 100 +max_queries_per_signature = 1 +keep_edge_probability = 0.2 + + + +[predicate_parameters] +row_retention_probability = 0.2 +extra_predicates = 3 +equality_lower_bound_probability = 0.00 +extra_values_for_in = 3 + +[predicate_parameters.operator_weights] +operator_in = 1 +operator_range = 3 +operator_equal = 3 Index: pyproject.toml ================================================================== --- pyproject.toml +++ pyproject.toml @@ -28,10 +28,11 @@ pypika = ">=0.48.9,<0.49" numpy = ">=2.2.5,<3" duckdb = ">=1.2.2,<2" polars = ">=1.27.1,<2" tqdm = "*" +cattrs = ">=24.1.2,<25" [tool.pixi.feature.test.dependencies] pytest = ">=8.3.5,<9" Index: src/query_generator/database_schemas/schemas.py ================================================================== --- src/query_generator/database_schemas/schemas.py +++ src/query_generator/database_schemas/schemas.py @@ -3,11 +3,11 @@ from query_generator.database_schemas.tpcds import get_tpcds_table_info from query_generator.database_schemas.tpch import get_tpch_table_info from query_generator.utils.definitions import Dataset from query_generator.utils.exceptions import ( PartiallySupportedDatasetError, - UnkwonDatasetError, + UnknownDatasetError, ) def get_schema(dataset: Dataset) -> tuple[dict[str, dict[str, Any]], list[str]]: """Get the schema of the database based on the dataset.
@@ -24,6 +24,6 @@ return get_tpcds_table_info() if dataset == Dataset.TPCH: return get_tpch_table_info() if dataset == Dataset.JOB: raise PartiallySupportedDatasetError(dataset.value) - raise UnkwonDatasetError(dataset) + raise UnknownDatasetError(dataset) Index: src/query_generator/database_schemas/tpcds.py ================================================================== --- src/query_generator/database_schemas/tpcds.py +++ src/query_generator/database_schemas/tpcds.py @@ -438,11 +438,10 @@ "s_market_id": {"max": 10, "min": 1}, "s_number_employees": {"max": 300, "min": 200}, "s_rec_end_date": {"max": "2001-03-12", "min": "1999-03-13"}, "s_rec_start_date": {"max": "2001-03-13", "min": "1997-03-13"}, "s_store_sk": {"max": 402, "min": 1}, - "s_tax_precentage": {"max": 0.11, "min": 0.0}, }, "foreign_keys": [], }, "store_returns": { "alias": "sr", Index: src/query_generator/duckdb_connection/binning.py ================================================================== --- src/query_generator/duckdb_connection/binning.py +++ src/query_generator/duckdb_connection/binning.py @@ -11,25 +11,22 @@ from query_generator.join_based_query_generator.utils.query_writer import ( Writer, ) from query_generator.utils.definitions import ( BatchGeneratedQueryFeatures, - Dataset, Extension, + PredicateParameters, QueryGenerationParameters, ) +from query_generator.utils.params import SearchParametersEndpoint @dataclass class SearchParameters: - dataset: Dataset + user_input: SearchParametersEndpoint scale_factor: int | float con: duckdb.DuckDBPyConnection - max_hops: list[int] - extra_predicates: list[int] - row_retention_probability: list[float] - unique_joins: bool def get_result_from_duckdb(query: str, con: duckdb.DuckDBPyConnection) -> int: try: result = int(con.sql(query).fetchall()[0][0]) @@ -37,11 +34,11 @@ print(f"Invalid query, exception: {e},\n{query}") return -1 return result -def get_total_iterations(search_params: SearchParameters) -> int: +def get_total_iterations(search_params: SearchParametersEndpoint) -> int: """Get the total number of iterations for the Snowflake binning process. Args: search_params (SearchParameters): The parameters for the Snowflake binning process. @@ -52,10 +49,11 @@ """ return ( len(search_params.max_hops) * len(search_params.extra_predicates) * len(search_params.row_retention_probability) + * len(search_params.equality_lower_bound_probability) ) def run_snowflake_param_seach( search_params: SearchParameters, @@ -66,37 +64,48 @@ parameters (BinningSnowflakeParameters): The parameters for the Snowflake binning process.
""" query_writer = Writer( - search_params.dataset, + search_params.user_input.dataset, Extension.SNOWFLAKE_SEARCH_PARAMS, ) rows: list[dict[str, str | int | float]] = [] - total_iterations = get_total_iterations(search_params) + total_iterations = get_total_iterations(search_params.user_input) batch_number = 0 seen_subgraphs: dict[int, bool] = {} - for max_hops, extra_predicates, row_retention_probability in tqdm( + for ( + max_hops, + extra_predicates, + row_retention_probability, + equality_lower_bound_probability, + ) in tqdm( product( - search_params.max_hops, - search_params.extra_predicates, - search_params.row_retention_probability, + search_params.user_input.max_hops, + search_params.user_input.extra_predicates, + search_params.user_input.row_retention_probability, + search_params.user_input.equality_lower_bound_probability, ), total=total_iterations, desc="Progress", ): batch_number += 1 query_generator = QueryGenerator( QueryGenerationParameters( - dataset=search_params.dataset, + dataset=search_params.user_input.dataset, max_hops=max_hops, - max_queries_per_fact_table=10, - max_queries_per_signature=2, - keep_edge_prob=0.2, - extra_predicates=extra_predicates, - row_retention_probability=float(row_retention_probability), + max_queries_per_fact_table=search_params.user_input.max_queries_per_fact_table, + max_queries_per_signature=search_params.user_input.max_queries_per_signature, + keep_edge_probability=search_params.user_input.keep_edge_probability, seen_subgraphs=seen_subgraphs, + predicate_parameters=PredicateParameters( + extra_predicates=extra_predicates, + row_retention_probability=row_retention_probability, + operator_weights=search_params.user_input.operator_weights, + equality_lower_bound_probability=equality_lower_bound_probability, + extra_values_for_in=search_params.user_input.extra_values_for_in, + ), ) ) for query in query_generator.generate_queries(): selected_rows = get_result_from_duckdb(query.query, search_params.con) if selected_rows == -1: @@ -126,9 +135,9 @@ "max_hops": max_hops, "row_retention_probability": row_retention_probability, }, ) # Update the seen subgraphs with the new ones - if search_params.unique_joins: + if search_params.user_input.unique_joins: seen_subgraphs = query_generator.subgraph_generator.seen_subgraphs df_queries = pl.DataFrame(rows) query_writer.write_dataframe(df_queries) Index: src/query_generator/duckdb_connection/setup.py ================================================================== --- src/query_generator/duckdb_connection/setup.py +++ src/query_generator/duckdb_connection/setup.py @@ -4,11 +4,11 @@ from query_generator.utils.definitions import Dataset from query_generator.utils.exceptions import ( MissingScaleFactorError, PartiallySupportedDatasetError, - UnkwonDatasetError, + UnkownDatasetError, ) def load_and_install_libraries() -> None: duckdb.install_extension("TPCDS") @@ -27,11 +27,11 @@ elif dataset == Dataset.TPCH: con.execute(f"CALL dbgen(sf = {scale_factor})") elif dataset == Dataset.JOB: raise PartiallySupportedDatasetError(dataset.value) else: - raise UnkwonDatasetError(dataset) + raise UnkownDatasetError(dataset) def get_path( dataset: Dataset, scale_factor: float | int | None, @@ -38,11 +38,11 @@ ) -> str: if dataset in [Dataset.TPCDS, Dataset.TPCH]: return f"data/duckdb/{dataset.value}/{scale_factor}.db" if dataset == Dataset.JOB: return f"data/duckdb/{dataset.value}/job.db" - raise UnkwonDatasetError(dataset.value) + raise UnkownDatasetError(dataset.value) def setup_duckdb( dataset: Dataset, scale_factor: 
int | float | None = None, Index: src/query_generator/join_based_query_generator/snowflake.py ================================================================== --- src/query_generator/join_based_query_generator/snowflake.py +++ src/query_generator/join_based_query_generator/snowflake.py @@ -18,16 +18,21 @@ from query_generator.join_based_query_generator.utils.query_writer import ( Writer, ) from query_generator.predicate_generator.predicate_generator import ( HistogramDataType, + PredicateEquality, PredicateGenerator, + PredicateIn, + PredicateRange, + SupportedHistogramType, ) from query_generator.utils.definitions import ( Dataset, Extension, GeneratedQueryFeatures, + PredicateParameters, QueryGenerationParameters, ) from query_generator.utils.exceptions import InvalidHistogramTypeError from query_generator.utils.utils import set_seed @@ -37,16 +42,17 @@ self, subgraph_generator: SubGraphGenerator, # TODO(Gabriel): http://localhost:8080/tktview/b9400c203a38f3aef46ec250d98563638ba7988b tables_schema: Any, dataset: Dataset, + predicate_params: PredicateParameters, ) -> None: self.sub_graph_gen = subgraph_generator self.table_to_pypika_table = { i: Table(i, alias=tables_schema[i]["alias"]) for i in tables_schema } - self.predicate_gen = PredicateGenerator(dataset) + self.predicate_gen = PredicateGenerator(dataset, predicate_params) self.tables_schema = tables_schema def get_subgraph_tables( self, subgraph: list[ForeignKeyGraph.Edge], @@ -84,64 +90,56 @@ def add_predicates( self, subgraph: list[ForeignKeyGraph.Edge], query: OracleQuery, - extra_predicates: int, - row_retention_probability: float, ) -> OracleQuery: subgraph_tables = self.get_subgraph_tables(subgraph) for predicate in self.predicate_gen.get_random_predicates( subgraph_tables, - extra_predicates, - row_retention_probability, - ): - query = self._add_range(query, predicate) - return query - - def _add_range( - self, query: OracleQuery, predicate: PredicateGenerator.Predicate - ) -> OracleQuery: - if predicate.dtype in [HistogramDataType.INT, HistogramDataType.FLOAT]: - return self._add_range_number(query, predicate) - if predicate.dtype in [HistogramDataType.DATE]: - return self._add_range_date(query, predicate) - if predicate.dtype in [HistogramDataType.STRING]: - return self._add_range_string(query, predicate) - raise InvalidHistogramTypeError(str(predicate.dtype)) - - def _add_range_number( - self, query: OracleQuery, predicate: PredicateGenerator.Predicate - ) -> OracleQuery: - return query.where( - self.table_to_pypika_table[predicate.table][predicate.column] - >= predicate.min_value, - ).where( - self.table_to_pypika_table[predicate.table][predicate.column] - <= predicate.max_value, - ) - - def _add_range_date( - self, query: OracleQuery, predicate: PredicateGenerator.Predicate - ) -> OracleQuery: - return query.where( - self.table_to_pypika_table[predicate.table][predicate.column] - >= fn.Cast(predicate.min_value, "date"), - ).where( - self.table_to_pypika_table[predicate.table][predicate.column] - <= fn.Cast(predicate.max_value, "date"), - ) - - def _add_range_string( - self, query: OracleQuery, predicate: PredicateGenerator.Predicate - ) -> OracleQuery: - return query.where( - self.table_to_pypika_table[predicate.table][predicate.column] - >= predicate.min_value, - ).where( - self.table_to_pypika_table[predicate.table][predicate.column] - <= predicate.max_value + ): + if isinstance(predicate, PredicateRange): + query = self._add_range(query, predicate) + elif isinstance(predicate, PredicateEquality): + query =
self._add_equality(query, predicate) + elif isinstance(predicate, PredicateIn): + query = self._add_in(query, predicate) + else: + raise InvalidHistogramTypeError(str(predicate.dtype)) + return query + + def _cast_if_needed( + self, value: SupportedHistogramType, dtype: HistogramDataType + ) -> Any: + """Cast the value to the appropriate type if needed.""" + if dtype == HistogramDataType.DATE: + return fn.Cast(value, "date") + return value + + def _add_range( + self, query: OracleQuery, predicate: PredicateRange + ) -> OracleQuery: + return query.where( + self.table_to_pypika_table[predicate.table][predicate.column] + >= self._cast_if_needed(predicate.min_value, predicate.dtype), + ).where( + self.table_to_pypika_table[predicate.table][predicate.column] + <= self._cast_if_needed(predicate.max_value, predicate.dtype) + ) + + def _add_equality( + self, query: OracleQuery, predicate: PredicateEquality + ) -> OracleQuery: + return query.where( + self.table_to_pypika_table[predicate.table][predicate.column] + == predicate.equality_value + ) + + def _add_in(self, query: OracleQuery, predicate: PredicateIn) -> OracleQuery: + return query.where( + self.table_to_pypika_table[predicate.table][predicate.column].isin( + [self._cast_if_needed(i, predicate.dtype) for i in predicate.in_values] + ) ) class QueryGenerator: def __init__(self, params: QueryGenerationParameters) -> None: @@ -149,18 +147,19 @@ self.params = params self.tables_schema, self.fact_tables = get_schema(params.dataset) self.foreign_key_graph = ForeignKeyGraph(self.tables_schema) self.subgraph_generator = SubGraphGenerator( self.foreign_key_graph, - params.keep_edge_prob, + params.keep_edge_probability, params.max_hops, params.seen_subgraphs, ) self.query_builder = QueryBuilder( self.subgraph_generator, self.tables_schema, params.dataset, + params.predicate_parameters, ) def generate_queries(self) -> Iterator[GeneratedQueryFeatures]: for fact_table in self.fact_tables: for cnt, subgraph in enumerate( @@ -172,12 +171,10 @@ query = self.query_builder.generate_query_from_subgraph(subgraph) for idx in range(1, self.params.max_queries_per_signature + 1): query = self.query_builder.add_predicates( subgraph, query, - self.params.extra_predicates, - self.params.row_retention_probability, ) yield GeneratedQueryFeatures( query=query.get_sql(), template_number=cnt, Index: src/query_generator/join_based_query_generator/utils/subgraph_generator.py ================================================================== --- src/query_generator/join_based_query_generator/utils/subgraph_generator.py +++ src/query_generator/join_based_query_generator/utils/subgraph_generator.py @@ -11,22 +11,23 @@ class SubGraphGenerator: def __init__( self, graph: ForeignKeyGraph, - keep_edge_prob: float, + keep_edge_probability: float, max_hops: int, seen_subgraphs: dict[int, bool], ) -> None: self.hops = max_hops - self.keep_edge_prob = keep_edge_prob + self.keep_edge_probability = keep_edge_probability self.graph = graph self.seen_subgraphs: dict[int, bool] = seen_subgraphs.copy() def get_random_subgraph(self, fact_table: str) -> list[ForeignKeyGraph.Edge]: """Starting from the fact table, for each edge of the current table we - decide based on the keep_edge_probability whether to keep the edge or not. + decide based on the keep_edge_probability whether to keep the + edge or not. We repeat this process up until the maximum number of hops.
""" @dataclass @@ -43,11 +44,11 @@ if current_node.depth >= self.hops: continue current_edges = self.graph.get_edges(current_node.table) for current_edge in current_edges: - if random.random() < self.keep_edge_prob: + if random.random() < self.keep_edge_probability: edges_subgraph.append(current_edge) queue.append( JoinDepthNode( current_edge.reference_table.name, current_node.depth + 1, Index: src/query_generator/main.py ================================================================== --- src/query_generator/main.py +++ src/query_generator/main.py @@ -28,170 +28,77 @@ ) from query_generator.utils.definitions import ( Dataset, Extension, QueryGenerationParameters, +) +from query_generator.utils.params import ( + SearchParametersEndpoint, + SnowflakeEndpoint, + read_and_parse_toml, ) from query_generator.utils.show_messages import show_dev_warning from query_generator.utils.utils import validate_file_path app = typer.Typer(name="Query Generation") @app.command() def snowflake( - dataset: Annotated[ - Dataset, - typer.Option("--dataset", "-d", help="The dataset used"), - ], - max_hops: Annotated[ - int, - typer.Option( - "--max-hops", - "-h", - help="The maximum number of hops", - min=1, - max=5, - ), - ] = 3, - max_queries_per_fact_table: Annotated[ - int, - typer.Option( - "--fact", - "-f", - help="The maximum number of queries per fact table", - min=1, - ), - ] = 100, - max_queries_per_signature: Annotated[ - int, - typer.Option( - "--signature", - "-s", - help="The maximum number of queries per signature/template", - min=1, - ), - ] = 1, - keep_edge_prob: Annotated[ - float, - typer.Option( - "--edge-prob", - "-p", - help="The probability of keeping an edge in the subgraph", - min=0.0, - max=1.0, - ), - ] = 0.2, - row_retention_probability: Annotated[ - float, - typer.Option( - "--row-retention", - "-r", - help="The probability of keeping a row in each predicate", - min=0.0, - max=1.0, - ), - ] = 0.2, - extra_predicates: Annotated[ - int, - typer.Option( - "--extra-predicates", - "-e", - help="The number of extra predicates to add to the query", - min=0, - ), - ] = 3, + config_path: Annotated[ + str, + typer.Option( + "-c", + "--config", + help="The path to the configuration file" + "They can be found in the params_config/query_generation/ folder", + ), + ], ) -> None: """Generate queries using a random subgraph.""" + params_endpoint = read_and_parse_toml(Path(config_path), SnowflakeEndpoint) params = QueryGenerationParameters( - dataset=dataset, - max_hops=max_hops, - max_queries_per_fact_table=max_queries_per_fact_table, - max_queries_per_signature=max_queries_per_signature, - keep_edge_prob=keep_edge_prob, - extra_predicates=extra_predicates, - row_retention_probability=row_retention_probability, + dataset=params_endpoint.dataset, + max_hops=params_endpoint.max_hops, + max_queries_per_fact_table=params_endpoint.max_queries_per_fact_table, + max_queries_per_signature=params_endpoint.max_queries_per_signature, + keep_edge_probability=params_endpoint.keep_edge_probability, seen_subgraphs={}, + predicate_parameters=params_endpoint.predicate_parameters, ) generate_and_write_queries(params) @app.command() def param_search( - dataset: Annotated[ - Dataset, - typer.Option("--dataset", "-d", help="The dataset used"), - ], - *, - dev: Annotated[ - bool, - typer.Option( - "--dev", - help="Development testing. 
If true then uses scale factor 0.1 to check.", ), ] = False, - unique_joins: Annotated[ - bool, - typer.Option( - "--unique-joins", - "-u", - help="If true all queries will have a unique join structure " - "(not recommended for TPC-H)", - ), - ] = False, - max_hops_range: Annotated[ - list[int] | None, - typer.Option( - "--max-hops-range", - "-h", - help="The range of hops to use for the query generation", - show_default="1, 2, 4", - ), - ] = None, - extra_predicates_range: Annotated[ - list[int] | None, - typer.Option( - "--extra-predicates-range", - "-e", - help="The range of extra predicates to use for the query generation", - show_default="1, 2, 3, 5", - ), - ] = None, - row_retention_probability_range: Annotated[ - list[float] | None, - typer.Option( - "--row-retention-probability-range", - "-r", - help="The range of row retention probabilities to use " - "for the query generation", - show_default="0.2, 0.3, 0.4, 0.6, 0.8, 0.85, 0.9, 1.0", - ), - ] = None, + config_path: Annotated[ + str, + typer.Option( + "-c", + "--config", + help="The path to the configuration file. " + "Configuration files can be found in the params_config/search_params/ folder.", + ), + ], ) -> None: """This is an extension of the Snowflake algorithm. It runs multiple batches with different configurations of the algorithm. This allows us to get multiple results. """ - if max_hops_range is None: - max_hops_range = [1, 2, 4] - if extra_predicates_range is None: - extra_predicates_range = [1, 2, 3, 5] - if row_retention_probability_range is None: - row_retention_probability_range = [0.2, 0.3, 0.4, 0.6, 0.8, 0.85, 0.9, 1.0] - show_dev_warning(dev=dev) - scale_factor = 0.1 if dev else 100 - con = setup_duckdb(dataset, scale_factor) + params = read_and_parse_toml( + Path(config_path), + SearchParametersEndpoint, + ) + show_dev_warning(dev=params.dev) + scale_factor = 0.1 if params.dev else 100 + con = setup_duckdb(params.dataset, scale_factor) run_snowflake_param_seach( SearchParameters( scale_factor=scale_factor, con=con, - dataset=dataset, - max_hops=max_hops_range, - extra_predicates=extra_predicates_range, - row_retention_probability=row_retention_probability_range, - unique_joins=unique_joins, + user_input=params, ), ) @app.command() Index: src/query_generator/predicate_generator/predicate_generator.py ================================================================== --- src/query_generator/predicate_generator/predicate_generator.py +++ src/query_generator/predicate_generator/predicate_generator.py @@ -1,46 +1,82 @@ import math import random +from abc import ABC from collections.abc import Iterator from dataclasses import dataclass from enum import Enum +import numpy as np import polars as pl -from query_generator.tools.histograms import HistogramColumns -from query_generator.utils.definitions import Dataset +from query_generator.tools.histograms import ( + HistogramColumns, + MostCommonValuesColumns, +) +from query_generator.utils.definitions import ( + Dataset, + PredicateOperatorProbability, + PredicateParameters, +) from query_generator.utils.exceptions import ( InvalidHistogramTypeError, - UnkwonDatasetError, + UnknownDatasetError, ) SupportedHistogramType = float | int | str SuportedHistogramArrayType = list[float] | list[int] | list[str] + +MAX_DISTINCT_COUNT_FOR_RANGE = 500 +PROBABILITY_TO_CHOOSE_EQUALITY = 0.8 +PREDICATE_IN_SIZE = 5 + + +class PredicateTypes(Enum): + IN = "in" + RANGE = "range" + EQUALITY = "equality" + class HistogramDataType(Enum): INT = "int" FLOAT = "float" DATE = "date" STRING = "string" +
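+# The dataclasses below form a small predicate hierarchy: PredicateRange carries a
+# [min_value, max_value] window, PredicateEquality a single value, and PredicateIn
+# a list of values. QueryBuilder.add_predicates in snowflake.py dispatches on the
+# concrete subclass when rendering each predicate into SQL.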
+@dataclass +class Predicate(ABC): + table: str + column: str + dtype: HistogramDataType + + +@dataclass +class PredicateRange(Predicate): + min_value: SupportedHistogramType + max_value: SupportedHistogramType + + +@dataclass +class PredicateEquality(Predicate): + equality_value: SupportedHistogramType + + +@dataclass +class PredicateIn(Predicate): + in_values: SuportedHistogramArrayType + class PredicateGenerator: - @dataclass - class Predicate: - table: str - column: str - min_value: SupportedHistogramType - max_value: SupportedHistogramType - dtype: HistogramDataType - - def __init__(self, dataset: Dataset): + def __init__(self, dataset: Dataset, predicate_params: PredicateParameters): self.dataset = dataset self.histogram: pl.DataFrame = self.read_histogram() + self.predicate_params = predicate_params - def _parse_bin( - self, hist_array: list[str], dtype: HistogramDataType + def _cast_array( + self, str_array: list[str], dtype: HistogramDataType ) -> SuportedHistogramArrayType: """Parse the bin string representation to a list of values. Args: bin_str (str): String representation of bins. @@ -49,17 +85,30 @@ Returns: list: List of parsed values. """ if dtype == HistogramDataType.INT: - return [int(float(x)) for x in hist_array] + return [int(float(x)) for x in str_array] + if dtype == HistogramDataType.FLOAT: + return [float(x) for x in str_array] + if dtype == HistogramDataType.DATE: + return str_array + if dtype == HistogramDataType.STRING: + return str_array + raise InvalidHistogramTypeError(dtype) + + def _cast_element( + self, value: str, dtype: HistogramDataType + ) -> SupportedHistogramType: + if dtype == HistogramDataType.INT: + return int(float(value)) if dtype == HistogramDataType.FLOAT: - return [float(x) for x in hist_array] + return float(value) if dtype == HistogramDataType.DATE: - return hist_array + return value if dtype == HistogramDataType.STRING: - return hist_array + return value raise InvalidHistogramTypeError(dtype) def read_histogram(self) -> pl.DataFrame: """Read the histogram data for the specified dataset. @@ -75,11 +124,11 @@ elif self.dataset == Dataset.TPCDS: path = "data/histograms/histogram_tpcds.parquet" elif self.dataset == Dataset.JOB: path = "data/histograms/histogram_job.parquet" else: - raise UnkwonDatasetError(self.dataset.value) + raise UnknownDatasetError(self.dataset.value) return pl.read_parquet(path).filter(pl.col("histogram") != []) def _get_histogram_type(self, dtype: str) -> HistogramDataType: if dtype in ["INTEGER", "BIGINT"]: return HistogramDataType.INT @@ -88,55 +137,153 @@ if dtype == "DATE": return HistogramDataType.DATE if dtype == "VARCHAR": return HistogramDataType.STRING raise InvalidHistogramTypeError(dtype) + + def _choose_predicate_type( + self, operator_weights: PredicateOperatorProbability + ) -> PredicateTypes: + weights = [ + operator_weights.operator_equal, + operator_weights.operator_in, + operator_weights.operator_range, + ] + return random.choices( + [ + PredicateTypes.EQUALITY, + PredicateTypes.IN, + PredicateTypes.RANGE, + ], + weights=weights, + )[0] def get_random_predicates( self, tables: list[str], - num_predicates: int, - row_retention_probability: float, - ) -> Iterator["PredicateGenerator.Predicate"]: + ) -> Iterator[Predicate]: """Generate random predicates based on the histogram data. Args: tables (str): List of tables to select predicates from. num_predicates (int): Number of predicates to generate. row_retention_probability (float): Probability of retaining rows.
Returns: - List[PredicateGenerator.Predicate]: List of generated predicates. + List[Predicate]: List of generated predicates. """ selected_tables_histogram = self.histogram.filter( pl.col(HistogramColumns.TABLE.value).is_in(tables) ) - for row in selected_tables_histogram.sample(n=num_predicates).iter_rows( - named=True - ): + for row in selected_tables_histogram.sample( + n=self.predicate_params.extra_predicates + ).iter_rows(named=True): table = row[HistogramColumns.TABLE.value] column = row[HistogramColumns.COLUMN.value] - bins = row[HistogramColumns.HISTOGRAM.value] - dtype = self._get_histogram_type(row[HistogramColumns.DTYPE.value]) - min_value, max_value = self._get_min_max_from_bins( - bins, row_retention_probability, dtype - ) - predicate = PredicateGenerator.Predicate( - table=table, - column=column, - min_value=min_value, - max_value=max_value, - dtype=dtype, - ) - yield predicate + dtype = self._get_histogram_type(row[HistogramColumns.DTYPE.value]) + predicate_type = self._choose_predicate_type( + self.predicate_params.operator_weights + ) + + if predicate_type == PredicateTypes.RANGE: + yield self._get_range_predicate( + table, column, row[HistogramColumns.HISTOGRAM.value], dtype + ) + elif predicate_type == PredicateTypes.IN: + array = self._get_in_array( + row[HistogramColumns.MOST_COMMON_VALUES.value], + row[HistogramColumns.TABLE_SIZE.value], + row[HistogramColumns.HISTOGRAM_MCV.value], + ) + if array is not None: + yield self._get_in_predicate(array, table, column, dtype) + else: + continue + elif predicate_type == PredicateTypes.EQUALITY: + value = self._get_equality_value( + row[HistogramColumns.MOST_COMMON_VALUES.value], + row[HistogramColumns.TABLE_SIZE.value], + ) + if value is not None: + yield self._get_equality_predicate(value, table, column, dtype) + else: + continue + + def _get_in_predicate( + self, array: list[str], table: str, column: str, dtype: HistogramDataType + ) -> PredicateIn: + cast_array = self._cast_array(array, dtype) + return PredicateIn(table, column, dtype, cast_array) + + def _get_in_array( + self, + most_common_values: list[dict[str, int | str]], + table_size: int, + histogram: list[str], + ) -> list[str] | None: + """ + Get the array of values for the IN operator. + """ + value = self._get_equality_value(most_common_values, table_size) + if value is None: + return None + noise_values = random.sample( + histogram, + k=min(self.predicate_params.extra_values_for_in, len(histogram)), + ) + return [value] + noise_values + + def _get_equality_predicate( + self, value: str, table: str, column: str, dtype: HistogramDataType + ) -> PredicateEquality: + cast_value = self._cast_element(value, dtype) + return PredicateEquality( + table=table, column=column, dtype=dtype, equality_value=cast_value + ) + + def _get_equality_value( + self, + most_common_values: list[dict[str, int | str]], + table_size: int, + ) -> str | None: + mcv_probabilities: list[float] = [ + float(v[MostCommonValuesColumns.COUNT.value]) / float(table_size) + for v in most_common_values + ] + mcv_probabilities_np = np.array(mcv_probabilities) + filtered_indices = np.where( + mcv_probabilities_np + > self.predicate_params.equality_lower_bound_probability + )[0] + if len(filtered_indices) == 0: + return None + idx = random.choice(filtered_indices) + value = most_common_values[idx][MostCommonValuesColumns.VALUE.value] + assert isinstance(value, str) + return value + + def _get_range_predicate( + self, + table: str, + column: str, + bins: list[str], + dtype: HistogramDataType, + ) -> PredicateRange:
+ min_value, max_value = self._get_min_max_from_bins(bins, dtype) + return PredicateRange( + table=table, + column=column, + min_value=min_value, + max_value=max_value, + dtype=dtype, + ) def _get_min_max_from_bins( self, bins: list[str], - row_retention_probability: float, dtype: HistogramDataType, ) -> tuple[SupportedHistogramType, SupportedHistogramType]: """Convert the bins string representation to a tuple of min and max values. Args: @@ -145,16 +292,16 @@ Returns: tuple: Tuple containing min and max values. """ - histogram_array: SuportedHistogramArrayType = self._parse_bin(bins, dtype) + histogram_array: SuportedHistogramArrayType = self._cast_array(bins, dtype) subrange_length = math.ceil( - row_retention_probability * len(histogram_array) + self.predicate_params.row_retention_probability * len(histogram_array) ) start_index = random.randint(0, len(histogram_array) - subrange_length) min_value = histogram_array[start_index] max_value = histogram_array[ min(start_index + subrange_length, len(histogram_array) - 1) ] return min_value, max_value Index: src/query_generator/tools/histograms.py ================================================================== --- src/query_generator/tools/histograms.py +++ src/query_generator/tools/histograms.py @@ -18,11 +18,14 @@ get_histogram_excluding_common_values, get_tables, ) from query_generator.utils.exceptions import InvalidHistogramTypeError -LIMIT_FOR_DISTINCT_VALUES = 1000 + +class MostCommonValuesColumns(Enum): + VALUE = "value" + COUNT = "count" class RedundantHistogramsDataType(Enum): """ This class was made for compatibility with old code that @@ -89,16 +92,12 @@ def get_most_common_values( con: duckdb.DuckDBPyConnection, table: str, column: str, common_value_size: int, - distinct_count: int, ) -> list[RawDuckDBMostCommonValues]: - result: list[RawDuckDBMostCommonValues] = [] - if distinct_count < LIMIT_FOR_DISTINCT_VALUES: - result = get_frequent_non_null_values(con, table, column, common_value_size) - return result + return get_frequent_non_null_values(con, table, column, common_value_size) def get_histogram_array(histogram_params: HistogramParams) -> list[str]: histogram_raw = get_equi_height_histogram( histogram_params.con, @@ -116,14 +115,11 @@ histogram_params: HistogramParams, common_values_size: int, distinct_count: int, ) -> list[str]: histogram_array: list[RawDuckDBHistograms] = [] - if ( - distinct_count < LIMIT_FOR_DISTINCT_VALUES - and distinct_count > common_values_size - ): + if distinct_count > common_values_size: histogram_array = get_histogram_excluding_common_values( histogram_params.con, histogram_params.table, histogram_params.column.column_name, histogram_params.histogram_size, @@ -182,11 +178,10 @@ most_common_values = get_most_common_values( con, table, column.column_name, common_values_size, - distinct_count, ) # Get histogram array excluding common values histogram_array_excluding_mcv = ( get_histogram_array_excluding_common_values( @@ -196,11 +191,14 @@ ) ) row_dict |= { HistogramColumns.MOST_COMMON_VALUES.value: [ - {"value": value.value, "count": value.count} + { + MostCommonValuesColumns.VALUE.value: value.value, + MostCommonValuesColumns.COUNT.value: value.count, + } for value in most_common_values ], HistogramColumns.HISTOGRAM_MCV.value: histogram_array_excluding_mcv, } Index: src/query_generator/utils/definitions.py ================================================================== --- src/query_generator/utils/definitions.py +++ src/query_generator/utils/definitions.py @@ -17,19 +17,40 @@ TPCH = "TPCH" 
JOB = "JOB" @dataclass +class PredicateOperatorProbability: + """Probability of using a specific predicate operator. + + They are based on choice with weights for each operator. + """ + + operator_in: float + operator_equal: float + operator_range: float + + +@dataclass +class PredicateParameters: + extra_predicates: int + row_retention_probability: float + operator_weights: PredicateOperatorProbability + equality_lower_bound_probability: float + extra_values_for_in: int + + +# TODO(Gabriel): http://localhost:8080/tktview/205e90a1fa +@dataclass class QueryGenerationParameters: + dataset: Dataset max_hops: int max_queries_per_signature: int max_queries_per_fact_table: int - keep_edge_prob: float - dataset: Dataset - extra_predicates: int - row_retention_probability: float + keep_edge_probability: float seen_subgraphs: dict[int, bool] + predicate_parameters: PredicateParameters @dataclass class GeneratedQueryFeatures: query: str Index: src/query_generator/utils/exceptions.py ================================================================== --- src/query_generator/utils/exceptions.py +++ src/query_generator/utils/exceptions.py @@ -25,11 +25,11 @@ class DuplicateEdgesError(Exception): def __init__(self, table: str) -> None: super().__init__(f"Duplicate edges found for table {table}.") -class UnkwonDatasetError(Exception): +class UnkownDatasetError(Exception): def __init__(self, dataset: str) -> None: super().__init__(f"Unknown dataset: {dataset}") class MissingScaleFactorError(Exception): ADDED src/query_generator/utils/params.py Index: src/query_generator/utils/params.py ================================================================== --- /dev/null +++ src/query_generator/utils/params.py @@ -0,0 +1,102 @@ +import tomllib +from dataclasses import dataclass +from pathlib import Path +from typing import TypeVar + +from cattrs import structure + +from query_generator.utils.definitions import ( + Dataset, + PredicateOperatorProbability, + PredicateParameters, +) + + +@dataclass +class SearchParametersEndpoint: + """ + Represents the parameters used for configuring search queries, including + query builder, subgraph, and predicate options. + + This class is designed to support both the `IN` and `=` statements in + query generation. + + Attributes: + dataset (Dataset): The dataset to be queried. + dev (bool): Flag indicating whether to use development settings. + max_queries_per_fact_table (int): Maximum number of queries per fact + table. + max_queries_per_signature (int): Maximum number of queries per + signature. + unique_joins (bool): Whether to enforce unique joins in the subgraph. + max_hops (list[int]): Maximum number of hops allowed in the subgraph. + keep_edge_probability (float): Probability of retaining an edge in the + subgraph. + extra_predicates (list[int]): Number of additional predicates to include + in the query. + row_retention_probability (list[float]): Probability of retaining a row + for range predicates + operator_weights (PredicateOperatorProbability): Probability + distribution for predicate operators. 
+ equality_lower_bound_probability (list[float]): Lower bound probability + when using the `=` and the `IN` operators. + """ + + # Query Builder + dataset: Dataset + dev: bool + max_queries_per_fact_table: int + max_queries_per_signature: int + # Subgraph + unique_joins: bool + max_hops: list[int] + keep_edge_probability: float + # Predicates + extra_predicates: list[int] + row_retention_probability: list[float] + operator_weights: PredicateOperatorProbability + equality_lower_bound_probability: list[float] + extra_values_for_in: int + + +@dataclass +class SnowflakeEndpoint: + """ + Represents the parameters used for configuring query generation, + including query builder, subgraph, and predicate options. + + Attributes: + dataset (Dataset): The dataset to be used for query generation. + max_queries_per_signature (int): Maximum number of queries to generate + per signature. + max_queries_per_fact_table (int): Maximum number of queries to generate + per fact table. + max_hops (int): Maximum number of hops allowed in the subgraph. + keep_edge_probability (float): Probability of retaining an edge in the + subgraph. + predicate_parameters (PredicateParameters): Nested predicate settings: + extra predicates, row retention probability, operator weights, + equality lower bound probability, and extra values for `IN`. + """ + + # Query builder + dataset: Dataset + max_queries_per_signature: int + max_queries_per_fact_table: int + # Subgraph + max_hops: int + keep_edge_probability: float + # Predicates + predicate_parameters: PredicateParameters + + +T = TypeVar("T") + + +def read_and_parse_toml(path: Path, cls: type[T]) -> T: + toml_dict = tomllib.loads(path.read_text()) + return structure(toml_dict, cls) Index: tests/duckdb/test_binning.py ================================================================== --- tests/duckdb/test_binning.py +++ tests/duckdb/test_binning.py @@ -1,16 +1,18 @@ +import tomllib from unittest import mock +from cattrs import structure import polars as pl import pytest from query_generator.duckdb_connection.binning import ( SearchParameters, run_snowflake_param_seach, ) from query_generator.tools.cherry_pick_binning import make_bins_in_csv -from query_generator.utils.definitions import Dataset +from query_generator.utils.params import SearchParametersEndpoint @pytest.mark.parametrize( "count_star, upper_bound, total_bins, expected_bin", [ @@ -42,15 +44,15 @@ @pytest.mark.parametrize( "extra_predicates, expected_call_count, unique_joins", [ - ([1], 120 * 1 + 14, False), - ([1], 120 * 1 + 14, True), + ("[1]", 120 * 1 + 14, "false"), + ("[1]", 120 * 1 + 14, "true"), # Inventory is small and prooduces 14 queries total - ([1, 2], 120 * 2 + 14, True), - ([1, 2], 120 * 2 + 14 * 2, False), + ("[1, 2]", 120 * 2 + 14, "true"), + ("[1, 2]", 120 * 2 + 14 * 2, "false"), ], ) def test_binning_calls(extra_predicates, expected_call_count, unique_joins): with mock.patch( "query_generator.duckdb_connection.binning.Writer.write_query_to_batch", @@ -57,20 +59,35 @@ ) as mock_writer: with mock.patch( "query_generator.duckdb_connection.binning.get_result_from_duckdb", ) as mock_connect: mock_connect.return_value = 0 + data_toml = f""" + dataset = "TPCDS" + dev = true + max_hops = [1] + extra_predicates = {extra_predicates} + row_retention_probability = [0.2] + unique_joins = {unique_joins} + max_queries_per_fact_table = 10 +
max_queries_per_signature = 2 + keep_edge_probability = 0.2 + equality_lower_bound_probability = [0] + extra_values_for_in = 3 + + [operator_weights] + operator_in = 1 + operator_range = 3 + operator_equal = 3 + """ + user_input = structure(tomllib.loads(data_toml), SearchParametersEndpoint) run_snowflake_param_seach( search_params=SearchParameters( scale_factor=0, - dataset=Dataset.TPCDS, - max_hops=[1], - extra_predicates=extra_predicates, - row_retention_probability=[0.2], con=None, - unique_joins=unique_joins, + user_input=user_input, ), ) assert mock_writer.call_count == expected_call_count, ( f"Expected {expected_call_count} calls to write_query, " f"but got {mock_writer.call_count}" ) Index: tests/duckdb/test_duckdb_utils.py ================================================================== --- tests/duckdb/test_duckdb_utils.py +++ tests/duckdb/test_duckdb_utils.py @@ -1,15 +1,16 @@ +import datetime + from query_generator.duckdb_connection.setup import setup_duckdb from query_generator.duckdb_connection.utils import ( get_distinct_count, get_equi_height_histogram, get_frequent_non_null_values, ) from query_generator.tools.histograms import DuckDBHistogramParser from query_generator.utils.definitions import Dataset -from tests.utils import is_date, is_float -import datetime +from tests.utils import is_float def test_distinct_values(): """Test the setup of DuckDB.""" # Setup DuckDB Index: tests/file_management/test_read_histograms.py ================================================================== --- tests/file_management/test_read_histograms.py +++ tests/file_management/test_read_histograms.py @@ -1,22 +1,22 @@ from unittest import mock +import polars as pl import pytest -import polars as pl from query_generator.predicate_generator.predicate_generator import ( HistogramDataType, PredicateGenerator, ) from query_generator.tools.histograms import HistogramColumns -from query_generator.utils.definitions import Dataset +from query_generator.utils.definitions import Dataset, PredicateParameters from query_generator.utils.exceptions import InvalidHistogramTypeError def test_read_histograms(): for dataset in Dataset: - predicate_generator = PredicateGenerator(dataset) + predicate_generator = PredicateGenerator(dataset, None) histogram = predicate_generator.read_histogram() assert not histogram.is_empty() assert histogram[HistogramColumns.DTYPE.value].dtype == pl.Utf8 assert histogram[HistogramColumns.COLUMN.value].dtype == pl.Utf8 @@ -76,20 +76,29 @@ ): with mock.patch( "query_generator.predicate_generator.predicate_generator.random.randint", return_value=mock_rand, ): - predicate_generator = PredicateGenerator(Dataset.TPCH) + predicate_generator = PredicateGenerator( + Dataset.TPCH, + PredicateParameters( + extra_predicates=None, + row_retention_probability=row_retention_probability, + operator_weights=None, + equality_lower_bound_probability=None, + extra_values_for_in=None, + ), + ) min_value, max_value = predicate_generator._get_min_max_from_bins( - bins_array, row_retention_probability, dtype + bins_array, dtype ) assert min_value == bins_array[min_index] assert max_value == bins_array[max_index] def test_get_invalid_histogram_type(): - predicate_generator = PredicateGenerator(Dataset.TPCH) + predicate_generator = PredicateGenerator(Dataset.TPCH, None) with pytest.raises(InvalidHistogramTypeError): predicate_generator._get_histogram_type("not_supported_type") @pytest.mark.parametrize( @@ -102,7 +111,7 @@ ("DATE", HistogramDataType.DATE), ("VARCHAR", HistogramDataType.STRING), ], ) 
def test_get_valid_histogram_type(input_type, expected_type): - predicate_generator = PredicateGenerator(Dataset.TPCH) + predicate_generator = PredicateGenerator(Dataset.TPCH, None) assert predicate_generator._get_histogram_type(input_type) == expected_type Index: tests/query_generation/test_make_queries.py ================================================================== --- tests/query_generation/test_make_queries.py +++ tests/query_generation/test_make_queries.py @@ -1,26 +1,28 @@ from unittest import mock import pytest -from pypika import functions as fn from query_generator.database_schemas.schemas import get_schema from query_generator.join_based_query_generator.snowflake import ( QueryBuilder, generate_and_write_queries, ) from query_generator.predicate_generator.predicate_generator import ( HistogramDataType, - PredicateGenerator, + PredicateRange, ) from query_generator.utils.definitions import ( Dataset, + PredicateOperatorProbability, + PredicateParameters, QueryGenerationParameters, ) -from query_generator.utils.exceptions import UnkwonDatasetError +from query_generator.utils.exceptions import UnknownDatasetError from pypika import OracleQuery +from pypika import functions as fn def test_tpch_query_generation(): with mock.patch( "query_generator.join_based_query_generator.snowflake.Writer.write_query", @@ -29,15 +31,24 @@ QueryGenerationParameters( dataset=Dataset.TPCDS, max_hops=1, max_queries_per_fact_table=1, max_queries_per_signature=1, - keep_edge_prob=0.2, - row_retention_probability=0.2, - extra_predicates=1, + keep_edge_probability=0.2, seen_subgraphs={}, - ), + predicate_parameters=PredicateParameters( + operator_weights=PredicateOperatorProbability( + operator_in=0.4, + operator_equal=0.4, + operator_range=0.2, + ), + extra_predicates=1, + row_retention_probability=0.2, + equality_lower_bound_probability=0, + extra_values_for_in=3, + ), + ) ) assert mock_writer.call_count > 5 @@ -49,14 +60,23 @@ QueryGenerationParameters( dataset=Dataset.TPCDS, max_hops=1, max_queries_per_fact_table=1, max_queries_per_signature=1, - keep_edge_prob=0.2, - row_retention_probability=0.2, - extra_predicates=1, + keep_edge_probability=0.2, seen_subgraphs={}, + predicate_parameters=PredicateParameters( + operator_weights=PredicateOperatorProbability( + operator_in=0.4, + operator_equal=0.4, + operator_range=0.2, + ), + extra_predicates=1, + row_retention_probability=0.2, + equality_lower_bound_probability=0, + extra_values_for_in=3, + ), ), ) assert mock_writer.call_count > 5 @@ -63,37 +83,57 @@ def test_non_implemented_dataset(): with mock.patch( "query_generator.join_based_query_generator.snowflake.Writer.write_query", ) as mock_writer: - with pytest.raises(UnkwonDatasetError): + with pytest.raises(UnknownDatasetError): generate_and_write_queries( QueryGenerationParameters( dataset="non_implemented_dataset", max_hops=1, max_queries_per_fact_table=1, max_queries_per_signature=1, - keep_edge_prob=0.2, - row_retention_probability=0.2, - extra_predicates=1, + keep_edge_probability=0.2, seen_subgraphs={}, + predicate_parameters=PredicateParameters( + operator_weights=PredicateOperatorProbability( + operator_in=0.4, + operator_equal=0.4, + operator_range=0.2, + ), + extra_predicates=1, + row_retention_probability=0.2, + equality_lower_bound_probability=0, + extra_values_for_in=3, + ), ), ) assert mock_writer.call_count == 0 def test_add_rage_supports_all_histogram_types(): tables_schema, _ = get_schema(Dataset.TPCH) - query_builder = QueryBuilder(None, tables_schema, Dataset.TPCH) +
query_builder = QueryBuilder( + None, + tables_schema, + Dataset.TPCH, + PredicateParameters( + extra_predicates=None, + row_retention_probability=0.2, + operator_weights=None, + equality_lower_bound_probability=None, + extra_values_for_in=None, + ), + ) for dtype in HistogramDataType: query_builder._add_range( OracleQuery() .from_(query_builder.table_to_pypika_table["lineitem"]) .select(fn.Count("*")), - PredicateGenerator.Predicate( + PredicateRange( table="lineitem", column="foo", min_value=2020, max_value=2020, dtype=dtype, ), )
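A usage sketch (not part of the patch) for the new TOML endpoints; it assumes the working directory is the repository root so the params_config/ paths resolve:

    from pathlib import Path

    from query_generator.utils.params import (
        SearchParametersEndpoint,
        read_and_parse_toml,
    )

    # cattrs.structure() builds the dataclass recursively: the [operator_weights]
    # table becomes a PredicateOperatorProbability, and "TPCDS" is matched to the
    # Dataset enum by value.
    params = read_and_parse_toml(
        Path("params_config/search_params/tpcds_dev.toml"),
        SearchParametersEndpoint,
    )
    assert params.dev is True
    assert params.operator_weights.operator_range == 3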
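And a worked sketch (hypothetical numbers) of the most-common-values filter behind the `=` and `IN` predicates: a value is eligible only if its frequency, count / table_size, exceeds equality_lower_bound_probability:

    import numpy as np

    table_size = 1_000
    most_common_values = [
        {"value": "AAA", "count": 150},  # frequency 0.15
        {"value": "BBB", "count": 50},   # frequency 0.05
    ]
    frequencies = np.array(
        [float(v["count"]) / table_size for v in most_common_values]
    )
    # With equality_lower_bound_probability = 0.1 only "AAA" survives, so
    # rare values never become equality predicates.
    eligible = np.where(frequencies > 0.1)[0]
    assert eligible.tolist() == [0]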