Changes In Branch new-predicates Excluding Merge-Ins
This is equivalent to a diff from 0f0856db5a to 4162607284
2025-05-28
09:27  Merges IN and = predicate  (check-in: 0a0518ab14, user: mathos, tags: trunk)
09:24  Adds equality lower bound as an array.  (Leaf check-in: 4162607284, user: mathos, tags: new-predicates)
00:05  Finally. A stable version  (check-in: 48e17f1cde, user: mathos, tags: new-predicates)

2025-05-26
11:22  Minor refactor to predicate class. Ticket [1e726428f6e719fb]  (check-in: c93b2b766c, user: mathos, tags: new-predicates)
11:12  Fix to be able to save the CSV  (check-in: 0f0856db5a, user: mathos, tags: trunk)
09:57  adds table_size to histogram  (check-in: 288ba9b582, user: mathos, tags: trunk)
Changes to data/histograms/histogram_tpcds.parquet.
cannot compute difference between binary files
Added params_config/search_params/tpcds.toml.
dataset = "TPCDS"
dev = true
max_hops = [1,2,4]
extra_predicates = [1,3,5]
row_retention_probability = [0.2, 0.3, 0.4, 0.6, 0.8, 0.85, 0.9, 1.0]
unique_joins = true
max_queries_per_fact_table = 10
max_queries_per_signature = 2
keep_edge_probability = 0.2
equality_lower_bound_probability = [0,0.1]
extra_values_for_in = 3

[operator_weights]
operator_in = 1
operator_range = 3
operator_equal = 3
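Each list-valued key above contributes one axis to the parameter sweep, so the number of batches is the size of the Cartesian product over max_hops, extra_predicates, row_retention_probability and equality_lower_bound_probability. A minimal standalone sketch of that count, assuming the file is read with Python's tomllib (the repository computes the same number in get_total_iterations() in duckdb_connection/binning.py):

import tomllib
from itertools import product
from pathlib import Path

cfg = tomllib.loads(Path("params_config/search_params/tpcds.toml").read_text())

axes = [
    cfg["max_hops"],                          # 3 values
    cfg["extra_predicates"],                  # 3 values
    cfg["row_retention_probability"],         # 8 values
    cfg["equality_lower_bound_probability"],  # 2 values
]
combos = list(product(*axes))
print(len(combos))  # 3 * 3 * 8 * 2 = 144 parameter batches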
Added params_config/search_params/tpcds_dev.toml.
dataset = "TPCDS"
dev = true
max_hops = [1]
extra_predicates = [5]
row_retention_probability = [0.2, 0.9]
unique_joins = true
max_queries_per_fact_table = 1
max_queries_per_signature = 2
keep_edge_probability = 0.2
equality_lower_bound_probability = [0,0.1]
extra_values_for_in = 3

[operator_weights]
operator_in = 1
operator_range = 3
operator_equal = 3
Added params_config/snowflake/tpcds.toml.
dataset = "TPCDS"
max_hops = 3
max_queries_per_fact_table = 100
max_queries_per_signature = 1
keep_edge_probability = 0.2

[predicate_parameters]
row_retention_probability = 0.2
extra_predicates = 3
equality_lower_bound_probability = 0.00
extra_values_for_in = 3

[predicate_parameters.operator_weights]
operator_in = 1
operator_range = 3
operator_equal = 3
Changes to pyproject.toml.
The only change adds the cattrs dependency:

︙
  typer = ">=0.15.2,<0.16"
  rich = ">=14.0.0,<15"
  pypika = ">=0.48.9,<0.49"
  numpy = ">=2.2.5,<3"
  duckdb = ">=1.2.2,<2"
  polars = ">=1.27.1,<2"
  tqdm = "*"
+ cattrs = ">=24.1.2,<25"

  [tool.pixi.feature.test.dependencies]
  pytest = ">=8.3.5,<9"

  [tool.pixi.feature.lint.dependencies]
  ruff = ">=0.11.7,<0.12"
︙
Changes to src/query_generator/database_schemas/schemas.py.
New text of the changed region:

from typing import Any

from query_generator.database_schemas.tpcds import get_tpcds_table_info
from query_generator.database_schemas.tpch import get_tpch_table_info
from query_generator.utils.definitions import Dataset
from query_generator.utils.exceptions import (
    PartiallySupportedDatasetError,
    UnkownDatasetError,
)


def get_schema(dataset: Dataset) -> tuple[dict[str, dict[str, Any]], list[str]]:
    """Get the schema of the database based on the dataset.

    Args:
        dataset (Dataset): The dataset to get the schema for.

    Returns:
        Tuple[Dict[str, Dict[str, Any]], List[str]]: A tuple containing the
        schema as a dictionary and a list of fact tables
    """
    if dataset == Dataset.TPCDS:
        return get_tpcds_table_info()
    if dataset == Dataset.TPCH:
        return get_tpch_table_info()
    if dataset == Dataset.JOB:
        raise PartiallySupportedDatasetError(dataset.value)
    raise UnkownDatasetError(dataset)
Changes to src/query_generator/database_schemas/tpcds.py.
New text of the changed region (a single line was removed from the "store" column statistics):

︙
        "s_floor_space": {"max": 9917607, "min": 5010719},
        "s_gmt_offset": {"max": -5.0, "min": -6.0},
        "s_market_id": {"max": 10, "min": 1},
        "s_number_employees": {"max": 300, "min": 200},
        "s_rec_end_date": {"max": "2001-03-12", "min": "1999-03-13"},
        "s_rec_start_date": {"max": "2001-03-13", "min": "1997-03-13"},
        "s_store_sk": {"max": 402, "min": 1},
    },
    "foreign_keys": [],
},
"store_returns": {
    "alias": "sr",
    "columns": {
        "sr_addr_sk": {"max": 1000000, "min": 1},
︙
Changes to src/query_generator/duckdb_connection/binning.py.
New text of the changed regions:

︙
    QueryGenerator,
)
from query_generator.join_based_query_generator.utils.query_writer import (
    Writer,
)
from query_generator.utils.definitions import (
    BatchGeneratedQueryFeatures,
    Extension,
    PredicateParameters,
    QueryGenerationParameters,
)
from query_generator.utils.params import SearchParametersEndpoint


@dataclass
class SearchParameters:
    user_input: SearchParametersEndpoint
    scale_factor: int | float
    con: duckdb.DuckDBPyConnection


def get_result_from_duckdb(query: str, con: duckdb.DuckDBPyConnection) -> int:
    try:
        result = int(con.sql(query).fetchall()[0][0])
    except duckdb.BinderException as e:
        print(f"Invalid query, exception: {e},\n{query}")
        return -1
    return result


def get_total_iterations(search_params: SearchParametersEndpoint) -> int:
    """Get the total number of iterations for the Snowflake binning process.

    Args:
        search_params (SearchParameters): The parameters for the Snowflake
            binning process.

    Returns:
        int: The total number of iterations.
    """
    return (
        len(search_params.max_hops)
        * len(search_params.extra_predicates)
        * len(search_params.row_retention_probability)
        * len(search_params.equality_lower_bound_probability)
    )


def run_snowflake_param_seach(
    search_params: SearchParameters,
) -> None:
    """Run the Snowflake binning process.

    Binning is equiwidth binning.

    Args:
        parameters (BinningSnowflakeParameters): The parameters for the
            Snowflake binning process.
    """
    query_writer = Writer(
        search_params.user_input.dataset,
        Extension.SNOWFLAKE_SEARCH_PARAMS,
    )
    rows: list[dict[str, str | int | float]] = []
    total_iterations = get_total_iterations(search_params.user_input)
    batch_number = 0
    seen_subgraphs: dict[int, bool] = {}
    for (
        max_hops,
        extra_predicates,
        row_retention_probability,
        equality_lower_bound_probability,
    ) in tqdm(
        product(
            search_params.user_input.max_hops,
            search_params.user_input.extra_predicates,
            search_params.user_input.row_retention_probability,
            search_params.user_input.equality_lower_bound_probability,
        ),
        total=total_iterations,
        desc="Progress",
    ):
        batch_number += 1
        query_generator = QueryGenerator(
            QueryGenerationParameters(
                dataset=search_params.user_input.dataset,
                max_hops=max_hops,
                max_queries_per_fact_table=search_params.user_input.max_queries_per_fact_table,
                max_queries_per_signature=search_params.user_input.max_queries_per_signature,
                keep_edge_probability=search_params.user_input.keep_edge_probability,
                seen_subgraphs=seen_subgraphs,
                predicate_parameters=PredicateParameters(
                    extra_predicates=extra_predicates,
                    row_retention_probability=row_retention_probability,
                    operator_weights=search_params.user_input.operator_weights,
                    equality_lower_bound_probability=equality_lower_bound_probability,
                    extra_values_for_in=search_params.user_input.extra_values_for_in,
                ),
            )
        )
        for query in query_generator.generate_queries():
            selected_rows = get_result_from_duckdb(query.query, search_params.con)
            if selected_rows == -1:
                continue  # invalid query
︙
                    "predicate_number": query.predicate_number,
                    "fact_table": query.fact_table,
                    "max_hops": max_hops,
                    "row_retention_probability": row_retention_probability,
                },
            )
        # Update the seen subgraphs with the new ones
        if search_params.user_input.unique_joins:
            seen_subgraphs = query_generator.subgraph_generator.seen_subgraphs

    df_queries = pl.DataFrame(rows)
    query_writer.write_dataframe(df_queries)
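For reference, a minimal sketch of driving this function directly, assuming the dev search-params config added above and a small in-memory TPC-DS instance; this mirrors what the param_search CLI command in main.py does:

from pathlib import Path

from query_generator.duckdb_connection.binning import (
    SearchParameters,
    run_snowflake_param_seach,
)
from query_generator.duckdb_connection.setup import setup_duckdb
from query_generator.utils.definitions import Dataset
from query_generator.utils.params import SearchParametersEndpoint, read_and_parse_toml

# Parse the sweep configuration into the endpoint dataclass.
user_input = read_and_parse_toml(
    Path("params_config/search_params/tpcds_dev.toml"), SearchParametersEndpoint
)
# Small scale factor, matching the dev setting used by param_search.
con = setup_duckdb(Dataset.TPCDS, 0.1)
run_snowflake_param_seach(
    SearchParameters(user_input=user_input, scale_factor=0.1, con=con)
)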
Changes to src/query_generator/duckdb_connection/setup.py.
New text of the changed regions:

import os

import duckdb

from query_generator.utils.definitions import Dataset
from query_generator.utils.exceptions import (
    MissingScaleFactorError,
    PartiallySupportedDatasetError,
    UnkownDatasetError,
)


def load_and_install_libraries() -> None:
    duckdb.install_extension("TPCDS")
    duckdb.install_extension("TPCH")
    duckdb.load_extension("TPCDS")
︙
    if dataset == Dataset.TPCDS:
        con.execute(f"CALL dsdgen(sf = {scale_factor})")
    elif dataset == Dataset.TPCH:
        con.execute(f"CALL dbgen(sf = {scale_factor})")
    elif dataset == Dataset.JOB:
        raise PartiallySupportedDatasetError(dataset.value)
    else:
        raise UnkownDatasetError(dataset)


def get_path(
    dataset: Dataset,
    scale_factor: float | int | None,
) -> str:
    if dataset in [Dataset.TPCDS, Dataset.TPCH]:
        return f"data/duckdb/{dataset.value}/{scale_factor}.db"
    if dataset == Dataset.JOB:
        return f"data/duckdb/{dataset.value}/job.db"
    raise UnkownDatasetError(dataset.value)


def setup_duckdb(
    dataset: Dataset,
    scale_factor: int | float | None = None,
) -> duckdb.DuckDBPyConnection:
    """Installs TPCDS and TPCH datasets in DuckDB.
︙
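A short usage sketch, assuming the connection helper behaves as above (the database path comes from get_path and the data is generated with dsdgen for TPC-DS); the 0.1 scale factor mirrors what the tests use:

from query_generator.duckdb_connection.setup import setup_duckdb
from query_generator.utils.definitions import Dataset

# Opens data/duckdb/TPCDS/0.1.db and populates it at scale factor 0.1.
con = setup_duckdb(Dataset.TPCDS, 0.1)
print(con.sql("SELECT count(*) FROM store_sales").fetchall())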
Changes to src/query_generator/join_based_query_generator/snowflake.py.
New text of the changed regions:

︙
# fmt: on
from query_generator.join_based_query_generator.utils.query_writer import (
    Writer,
)
from query_generator.predicate_generator.predicate_generator import (
    HistogramDataType,
    PredicateEquality,
    PredicateGenerator,
    PredicateIn,
    PredicateRange,
    SupportedHistogramType,
)
from query_generator.utils.definitions import (
    Dataset,
    Extension,
    GeneratedQueryFeatures,
    PredicateParameters,
    QueryGenerationParameters,
)
from query_generator.utils.exceptions import InvalidHistogramTypeError
from query_generator.utils.utils import set_seed


class QueryBuilder:
    def __init__(
        self,
        subgraph_generator: SubGraphGenerator,
        # TODO(Gabriel): http://localhost:8080/tktview/b9400c203a38f3aef46ec250d98563638ba7988b
        tables_schema: Any,
        dataset: Dataset,
        predicate_params: PredicateParameters,
    ) -> None:
        self.sub_graph_gen = subgraph_generator
        self.table_to_pypika_table = {
            i: Table(i, alias=tables_schema[i]["alias"]) for i in tables_schema
        }
        self.predicate_gen = PredicateGenerator(dataset, predicate_params)
        self.tables_schema = tables_schema

    def get_subgraph_tables(
        self,
        subgraph: list[ForeignKeyGraph.Edge],
    ) -> list[str]:
        return list(
︙
        )
        return query

    def add_predicates(
        self,
        subgraph: list[ForeignKeyGraph.Edge],
        query: OracleQuery,
    ) -> OracleQuery:
        subgraph_tables = self.get_subgraph_tables(subgraph)
        for predicate in self.predicate_gen.get_random_predicates(
            subgraph_tables,
        ):
            if isinstance(predicate, PredicateRange):
                return self._add_range(query, predicate)
            if isinstance(predicate, PredicateEquality):
                return self._add_equality(query, predicate)
            if isinstance(predicate, PredicateIn):
                return self._add_in(query, predicate)
            raise InvalidHistogramTypeError(str(predicate.dtype))
        return query

    def _cast_if_needed(
        self, value: SupportedHistogramType, dtype: HistogramDataType
    ) -> Any:
        """Cast the value to the appropriate type if needed."""
        if dtype == HistogramDataType.DATE:
            return fn.Cast(value, "date")
        return value

    def _add_range(
        self, query: OracleQuery, predicate: PredicateRange
    ) -> OracleQuery:
        return query.where(
            self.table_to_pypika_table[predicate.table][predicate.column]
            >= self._cast_if_needed(predicate.min_value, predicate.dtype),
        ).where(
            self.table_to_pypika_table[predicate.table][predicate.column]
            <= self._cast_if_needed(predicate.max_value, predicate.dtype)
        )

    def _add_equality(
        self, query: OracleQuery, predicate: PredicateEquality
    ) -> OracleQuery:
        return query.where(
            self.table_to_pypika_table[predicate.table][predicate.column]
            == predicate.equality_value
        )

    def _add_in(self, query: OracleQuery, predicate: PredicateIn) -> OracleQuery:
        return query.where(
            self.table_to_pypika_table[predicate.table][predicate.column].isin(
                [self._cast_if_needed(i, predicate.dtype) for i in predicate.in_values]
            )
        )


class QueryGenerator:
    def __init__(self, params: QueryGenerationParameters) -> None:
        set_seed()
        self.params = params
        self.tables_schema, self.fact_tables = get_schema(params.dataset)
        self.foreign_key_graph = ForeignKeyGraph(self.tables_schema)
        self.subgraph_generator = SubGraphGenerator(
            self.foreign_key_graph,
            params.keep_edge_probability,
            params.max_hops,
            params.seen_subgraphs,
        )
        self.query_builder = QueryBuilder(
            self.subgraph_generator,
            self.tables_schema,
            params.dataset,
            params.predicate_parameters,
        )

    def generate_queries(self) -> Iterator[GeneratedQueryFeatures]:
        for fact_table in self.fact_tables:
            for cnt, subgraph in enumerate(
                self.subgraph_generator.generate_subgraph(
                    fact_table,
                    self.params.max_queries_per_fact_table,
                ),
            ):
                query = self.query_builder.generate_query_from_subgraph(subgraph)
                for idx in range(1, self.params.max_queries_per_signature + 1):
                    query = self.query_builder.add_predicates(
                        subgraph,
                        query,
                    )
                    yield GeneratedQueryFeatures(
                        query=query.get_sql(),
                        template_number=cnt,
                        predicate_number=idx,
                        fact_table=fact_table,
︙
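As a rough illustration of the kind of WHERE clause _add_in produces, here is a standalone pypika snippet; the table, column, and values are invented for the example and do not come from the generator:

from pypika import OracleQuery, Table
from pypika import functions as fn

dd = Table("date_dim", alias="d")  # hypothetical table and alias
q = (
    OracleQuery.from_(dd)
    .select(fn.Count("*"))
    .where(
        dd["d_date"].isin(
            [fn.Cast("1999-03-13", "date"), fn.Cast("2001-03-13", "date")]
        )
    )
)
# get_sql() renders roughly:
#   SELECT COUNT(*) FROM date_dim WHERE d_date IN (CAST('1999-03-13' AS date),CAST('2001-03-13' AS date))
print(q.get_sql())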
Changes to src/query_generator/join_based_query_generator/utils/subgraph_generator.py.
New text of the changed region:

︙
MAX_ATTEMPTS_FOR_NEW_SUBGRAPH = 1000


class SubGraphGenerator:
    def __init__(
        self,
        graph: ForeignKeyGraph,
        keep_edge_probability: float,
        max_hops: int,
        seen_subgraphs: dict[int, bool],
    ) -> None:
        self.hops = max_hops
        self.keep_edge_probability = keep_edge_probability
        self.graph = graph
        self.seen_subgraphs: dict[int, bool] = seen_subgraphs.copy()

    def get_random_subgraph(self, fact_table: str) -> list[ForeignKeyGraph.Edge]:
        """Starting from the fact table, for each edge of the current table
        we decide based on the keep_edge_probabilityability whether to keep the
        edge or not. We repeat this process up until the maximum number of hops.
        """

        @dataclass
        class JoinDepthNode:
            table: str
            depth: int

        queue: deque[JoinDepthNode] = deque()
        queue.append(JoinDepthNode(fact_table, 0))
        edges_subgraph = []
        while queue:
            current_node = queue.popleft()
            if current_node.depth >= self.hops:
                continue
            current_edges = self.graph.get_edges(current_node.table)
            for current_edge in current_edges:
                if random.random() < self.keep_edge_probability:
                    edges_subgraph.append(current_edge)
                    queue.append(
                        JoinDepthNode(
                            current_edge.reference_table.name,
                            current_node.depth + 1,
                        ),
                    )
︙
Changes to src/query_generator/main.py.
New text of the changed region:

︙
    make_redundant_histograms,
    query_histograms,
)
from query_generator.utils.definitions import (
    Dataset,
    Extension,
    QueryGenerationParameters,
)
from query_generator.utils.params import (
    SearchParametersEndpoint,
    SnowflakeEndpoint,
    read_and_parse_toml,
)
from query_generator.utils.show_messages import show_dev_warning
from query_generator.utils.utils import validate_file_path

app = typer.Typer(name="Query Generation")


@app.command()
def snowflake(
    config_path: Annotated[
        str,
        typer.Option(
            "-c",
            "--config",
            help="The path to the configuration file"
            "They can be found in the params_config/query_generation/ folder",
        ),
    ],
) -> None:
    """Generate queries using a random subgraph."""
    params_endpoint = read_and_parse_toml(Path(config_path), SnowflakeEndpoint)
    params = QueryGenerationParameters(
        dataset=params_endpoint.dataset,
        max_hops=params_endpoint.max_hops,
        max_queries_per_fact_table=params_endpoint.max_queries_per_fact_table,
        max_queries_per_signature=params_endpoint.max_queries_per_signature,
        keep_edge_probability=params_endpoint.keep_edge_probability,
        seen_subgraphs={},
        predicate_parameters=params_endpoint.predicate_parameters,
    )
    generate_and_write_queries(params)


@app.command()
def param_search(
    config_path: Annotated[
        str,
        typer.Option(
            "-c",
            "--config",
            help="The path to the configuration file"
            "They can be found in the params_config/search_params/ folder",
        ),
    ],
) -> None:
    """This is an extension of the Snowflake algorithm.

    It runs multiple batches with different configurations of the algorithm.
    This allows us to get multiple results.
    """
    params = read_and_parse_toml(
        Path(config_path),
        SearchParametersEndpoint,
    )
    show_dev_warning(dev=params.dev)
    scale_factor = 0.1 if params.dev else 100
    con = setup_duckdb(params.dataset, scale_factor)
    run_snowflake_param_seach(
        SearchParameters(
            scale_factor=scale_factor,
            con=con,
            user_input=params,
        ),
    )


@app.command()
def cherry_pick(
    dataset: Annotated[
︙
Changes to src/query_generator/predicate_generator/predicate_generator.py.
New text of the changed region:

import math
import random
from abc import ABC
from collections.abc import Iterator
from dataclasses import dataclass
from enum import Enum

import numpy as np
import polars as pl

from query_generator.tools.histograms import (
    HistogramColumns,
    MostCommonValuesColumns,
)
from query_generator.utils.definitions import (
    Dataset,
    PredicateOperatorProbability,
    PredicateParameters,
)
from query_generator.utils.exceptions import (
    InvalidHistogramTypeError,
    UnkownDatasetError,
)

SupportedHistogramType = float | int | str
SuportedHistogramArrayType = list[float] | list[int] | list[str]

MAX_DISTINCT_COUNT_FOR_RANGE = 500
PROBABILITY_TO_CHOOSE_EQUALITY = 0.8
PREDICATE_IN_SIZE = 5


class PredicateTypes(Enum):
    IN = "in"
    RANGE = "range"
    EQUALITY = "equality"


class HistogramDataType(Enum):
    INT = "int"
    FLOAT = "float"
    DATE = "date"
    STRING = "string"


@dataclass
class Predicate(ABC):
    table: str
    column: str
    dtype: HistogramDataType


@dataclass
class PredicateRange(Predicate):
    min_value: SupportedHistogramType
    max_value: SupportedHistogramType


@dataclass
class PredicateEquality(Predicate):
    equality_value: SupportedHistogramType


@dataclass
class PredicateIn(Predicate):
    in_values: SuportedHistogramArrayType


class PredicateGenerator:
    def __init__(self, dataset: Dataset, predicate_params: PredicateParameters):
        self.dataset = dataset
        self.histogram: pl.DataFrame = self.read_histogram()
        self.predicate_params = predicate_params

    def _cast_array(
        self, str_array: list[str], dtype: HistogramDataType
    ) -> SuportedHistogramArrayType:
        """Parse the bin string representation to a list of values.

        Args:
            bin_str (str): String representation of bins.
            dtype (str): Data type of the values.

        Returns:
            list: List of parsed values.
        """
        if dtype == HistogramDataType.INT:
            return [int(float(x)) for x in str_array]
        if dtype == HistogramDataType.FLOAT:
            return [float(x) for x in str_array]
        if dtype == HistogramDataType.DATE:
            return str_array
        if dtype == HistogramDataType.STRING:
            return str_array
        raise InvalidHistogramTypeError(dtype)

    def _cast_element(
        self, value: str, dtype: HistogramDataType
    ) -> SupportedHistogramType:
        if dtype == HistogramDataType.INT:
            return int(float(value))
        if dtype == HistogramDataType.FLOAT:
            return float(value)
        if dtype == HistogramDataType.DATE:
            return value
        if dtype == HistogramDataType.STRING:
            return value
        raise InvalidHistogramTypeError(dtype)

    def read_histogram(self) -> pl.DataFrame:
        """Read the histogram data for the specified dataset.

        Args:
            dataset: The dataset type (TPCH or TPCDS).

        Returns:
            pd.DataFrame: DataFrame containing the histogram data.
        """
        if self.dataset == Dataset.TPCH:
            path = "data/histograms/histogram_tpch.parquet"
        elif self.dataset == Dataset.TPCDS:
            path = "data/histograms/histogram_tpcds.parquet"
        elif self.dataset == Dataset.JOB:
            path = "data/histograms/histogram_job.parquet"
        else:
            raise UnkownDatasetError(self.dataset.value)
        return pl.read_parquet(path).filter(pl.col("histogram") != [])

    def _get_histogram_type(self, dtype: str) -> HistogramDataType:
        if dtype in ["INTEGER", "BIGINT"]:
            return HistogramDataType.INT
        if dtype.startswith("DECIMAL"):
            return HistogramDataType.FLOAT
        if dtype == "DATE":
            return HistogramDataType.DATE
        if dtype == "VARCHAR":
            return HistogramDataType.STRING
        raise InvalidHistogramTypeError(dtype)

    def _choose_predicate_type(
        self, operator_weights: PredicateOperatorProbability
    ) -> PredicateTypes:
        weights = [
            operator_weights.operator_equal,
            operator_weights.operator_in,
            operator_weights.operator_range,
        ]
        return random.choices(
            [
                PredicateTypes.EQUALITY,
                PredicateTypes.IN,
                PredicateTypes.RANGE,
            ],
            weights=weights,
        )[0]

    def get_random_predicates(
        self,
        tables: list[str],
    ) -> Iterator[Predicate]:
        """Generate random predicates based on the histogram data.

        Args:
            tables (str): List of tables to select predicates from.
            num_predicates (int): Number of predicates to generate.
            row_retention_probability (float): Probability of retaining rows.

        Returns:
            List[Predicate]: List of generated predicates.
        """
        selected_tables_histogram = self.histogram.filter(
            pl.col(HistogramColumns.TABLE.value).is_in(tables)
        )
        for row in selected_tables_histogram.sample(
            n=self.predicate_params.extra_predicates
        ).iter_rows(named=True):
            table = row[HistogramColumns.TABLE.value]
            column = row[HistogramColumns.COLUMN.value]
            dtype = self._get_histogram_type(row[HistogramColumns.DTYPE.value])
            predicate_type = self._choose_predicate_type(
                self.predicate_params.operator_weights
            )
            if predicate_type == PredicateTypes.RANGE:
                yield self._get_range_predicate(
                    table, column, row[HistogramColumns.HISTOGRAM.value], dtype
                )
            elif predicate_type == PredicateTypes.IN:
                array = self._get_in_array(
                    row[HistogramColumns.MOST_COMMON_VALUES.value],
                    row[HistogramColumns.TABLE_SIZE.value],
                    row[HistogramColumns.HISTOGRAM_MCV.value],
                )
                if array is not None:
                    yield self._get_in_predicate(array, table, column, dtype)
                else:
                    continue
            elif predicate_type == PredicateTypes.EQUALITY:
                value = self._get_equality_value(
                    row[HistogramColumns.MOST_COMMON_VALUES.value],
                    row[HistogramColumns.TABLE_SIZE.value],
                )
                if value is not None:
                    yield self._get_equality_predicate(value, table, column, dtype)
                else:
                    continue

    def _get_in_predicate(
        self, array: list[str], table: str, column: str, dtype: HistogramDataType
    ) -> PredicateIn:
        cast_array = self._cast_array(array, dtype)
        return PredicateIn(table, column, dtype, cast_array)

    def _get_in_array(
        self,
        most_common_values: list[dict[str, int | str]],
        table_size: int,
        histogram: list[str],
    ) -> list[str] | None:
        """Gets the array for the IN operator"""
        value = self._get_equality_value(most_common_values, table_size)
        if value is None:
            return None
        noise_values = random.sample(
            histogram,
            k=min(self.predicate_params.extra_values_for_in, len(histogram)),
        )
        return [value] + noise_values

    def _get_equality_predicate(
        self, value: str, table: str, column: str, dtype: HistogramDataType
    ) -> PredicateEquality:
        cast_value = self._cast_element(value, dtype)
        return PredicateEquality(
            table=table, column=column, dtype=dtype, equality_value=cast_value
        )

    def _get_equality_value(
        self,
        most_common_values: list[dict[str, int | str]],
        table_size: int,
    ) -> str | None:
        mcv_probabilities: list[float] = [
            float(table_size) / float(v[MostCommonValuesColumns.COUNT.value])
            for v in most_common_values
        ]
        mcv_probabilities_np = np.array(mcv_probabilities)
        filtered_indices = np.where(
            mcv_probabilities_np
            > self.predicate_params.equality_lower_bound_probability
        )[0]
        if len(filtered_indices) == 0:
            return None
        idx = random.choice(filtered_indices)
        value = most_common_values[idx][MostCommonValuesColumns.VALUE.value]
        assert isinstance(value, str)
        return value

    def _get_range_predicate(
        self,
        table: str,
        column: str,
        bins: list[str],
        dtype: HistogramDataType,
    ) -> PredicateRange:
        min_value, max_value = self._get_min_max_from_bins(bins, dtype)
        return PredicateRange(
            table=table,
            column=column,
            min_value=min_value,
            max_value=max_value,
            dtype=dtype,
        )

    def _get_min_max_from_bins(
        self,
        bins: list[str],
        dtype: HistogramDataType,
    ) -> tuple[SupportedHistogramType, SupportedHistogramType]:
        """Convert the bins string representation to a tuple of min and max values.

        Args:
            bins (str): String representation of bins.
            row_retention_probability (float): Probability of retaining rows.

        Returns:
            tuple: Tuple containing min and max values.
        """
        histogram_array: SuportedHistogramArrayType = self._cast_array(bins, dtype)
        subrange_length = math.ceil(
            self.predicate_params.row_retention_probability * len(histogram_array)
        )
        start_index = random.randint(0, len(histogram_array) - subrange_length)
        min_value = histogram_array[start_index]
        max_value = histogram_array[
            min(start_index + subrange_length, len(histogram_array) - 1)
        ]
        return min_value, max_value
Changes to src/query_generator/tools/histograms.py.
New text of the changed regions:

︙
    get_equi_height_histogram,
    get_frequent_non_null_values,
    get_histogram_excluding_common_values,
    get_tables,
)
from query_generator.utils.exceptions import InvalidHistogramTypeError


class MostCommonValuesColumns(Enum):
    VALUE = "value"
    COUNT = "count"


class RedundantHistogramsDataType(Enum):
    """
    This class was made for compatibility with old code
    that generated this histogram:
    https://github.com/udao-moo/udao-spark-optimizer-dev/blob/main
︙
def get_most_common_values(
    con: duckdb.DuckDBPyConnection,
    table: str,
    column: str,
    common_value_size: int,
) -> list[RawDuckDBMostCommonValues]:
    return get_frequent_non_null_values(con, table, column, common_value_size)


def get_histogram_array(histogram_params: HistogramParams) -> list[str]:
    histogram_raw = get_equi_height_histogram(
        histogram_params.con,
        histogram_params.table,
        histogram_params.column.column_name,
︙
def get_histogram_array_excluding_common_values(
    histogram_params: HistogramParams,
    common_values_size: int,
    distinct_count: int,
) -> list[str]:
    histogram_array: list[RawDuckDBHistograms] = []
    if distinct_count > common_values_size:
        histogram_array = get_histogram_excluding_common_values(
            histogram_params.con,
            histogram_params.table,
            histogram_params.column.column_name,
            histogram_params.histogram_size,
            common_values_size,
        )
︙
        if include_mvc:
            # Get most common values
            most_common_values = get_most_common_values(
                con,
                table,
                column.column_name,
                common_values_size,
            )

            # Get histogram array excluding common values
            histogram_array_excluding_mcv = (
                get_histogram_array_excluding_common_values(
                    histogram_params,
                    common_values_size,
                    distinct_count,
                )
            )
            row_dict |= {
                HistogramColumns.MOST_COMMON_VALUES.value: [
                    {
                        MostCommonValuesColumns.VALUE.value: value.value,
                        MostCommonValuesColumns.COUNT.value: value.count,
                    }
                    for value in most_common_values
                ],
                HistogramColumns.HISTOGRAM_MCV.value: histogram_array_excluding_mcv,
            }
        rows.append(row_dict)
    return pl.DataFrame(rows)
︙
Changes to src/query_generator/utils/definitions.py.
New text of the changed region:

︙
class Dataset(Enum):
    TPCDS = "TPCDS"
    TPCH = "TPCH"
    JOB = "JOB"


@dataclass
class PredicateOperatorProbability:
    """Probability of using a specific predicate operator.

    They are based on choice with weights for each operator.
    """

    operator_in: float
    operator_equal: float
    operator_range: float


@dataclass
class PredicateParameters:
    extra_predicates: int
    row_retention_probability: float
    operator_weights: PredicateOperatorProbability
    equality_lower_bound_probability: float
    extra_values_for_in: int


# TODO(Gabriel): http://localhost:8080/tktview/205e90a1fa
@dataclass
class QueryGenerationParameters:
    dataset: Dataset
    max_hops: int
    max_queries_per_signature: int
    max_queries_per_fact_table: int
    keep_edge_probability: float
    seen_subgraphs: dict[int, bool]
    predicate_parameters: PredicateParameters


@dataclass
class GeneratedQueryFeatures:
    query: str
    template_number: int
    predicate_number: int
︙
Changes to src/query_generator/utils/exceptions.py.
New text of the changed region:

︙
class DuplicateEdgesError(Exception):
    def __init__(self, table: str) -> None:
        super().__init__(f"Duplicate edges found for table {table}.")


class UnkownDatasetError(Exception):
    def __init__(self, dataset: str) -> None:
        super().__init__(f"Unknown dataset: {dataset}")


class MissingScaleFactorError(Exception):
    def __init__(self, dataset: str) -> None:
        super().__init__(
︙
Added src/query_generator/utils/params.py.
import tomllib
from dataclasses import dataclass
from pathlib import Path
from typing import TypeVar

from cattrs import structure

from query_generator.utils.definitions import (
    Dataset,
    PredicateOperatorProbability,
    PredicateParameters,
)


@dataclass
class SearchParametersEndpoint:
    """
    Represents the parameters used for configuring search queries, including
    query builder, subgraph, and predicate options. This class is designed to
    support both the `IN` and `=` statements in query generation.

    Attributes:
        dataset (Dataset): The dataset to be queried.
        dev (bool): Flag indicating whether to use development settings.
        max_queries_per_fact_table (int): Maximum number of queries per fact
            table.
        max_queries_per_signature (int): Maximum number of queries per
            signature.
        unique_joins (bool): Whether to enforce unique joins in the subgraph.
        max_hops (list[int]): Maximum number of hops allowed in the subgraph.
        keep_edge_probability (float): Probability of retaining an edge in the
            subgraph.
        extra_predicates (list[int]): Number of additional predicates to
            include in the query.
        row_retention_probability (list[float]): Probability of retaining a
            row for range predicates
        operator_weights (PredicateOperatorProbability): Probability
            distribution for predicate operators.
        equality_lower_bound_probability (float): Lower bound probability when
            using the `=` and the `IN` operators
    """

    # Query Builder
    dataset: Dataset
    dev: bool
    max_queries_per_fact_table: int
    max_queries_per_signature: int

    # Subgraph
    unique_joins: bool
    max_hops: list[int]
    keep_edge_probability: float

    # Predicates
    extra_predicates: list[int]
    row_retention_probability: list[float]
    operator_weights: PredicateOperatorProbability
    equality_lower_bound_probability: list[float]
    extra_values_for_in: int


@dataclass
class SnowflakeEndpoint:
    """
    Represents the parameters used for configuring query generation, including
    query builder, subgraph, and predicate options.

    Attributes:
        dataset (Dataset): The dataset to be used for query generation.
        max_queries_per_signature (int): Maximum number of queries to generate
            per signature.
        max_queries_per_fact_table (int): Maximum number of queries to
            generate per fact table.
        max_hops (int): Maximum number of hops allowed in the subgraph.
        keep_edge_probability (float): Probability of retaining an edge in the
            subgraph.
        extra_predicates (int): Number of extra predicates to add to the
            query.
        row_retention_probability (float): Probability of retaining a row
            after applying predicates.
        operator_weights (PredicateOperatorProbability): Probability
            distribution for predicate operators.
        equality_lower_bound_probability (float): Probability of using a lower
            bound for equality predicates.
    """

    # Query builder
    dataset: Dataset
    max_queries_per_signature: int
    max_queries_per_fact_table: int

    # Subgraph
    max_hops: int
    keep_edge_probability: float

    # Predicates
    predicate_parameters: PredicateParameters


T = TypeVar("T")


def read_and_parse_toml(path: Path, cls: type[T]) -> T:
    toml_dict = tomllib.loads(path.read_text())
    return structure(toml_dict, cls)
Changes to tests/duckdb/test_binning.py.
New text of the changed regions:

import tomllib
from unittest import mock

from cattrs import structure
import polars as pl
import pytest

from query_generator.duckdb_connection.binning import (
    SearchParameters,
    run_snowflake_param_seach,
)
from query_generator.tools.cherry_pick_binning import make_bins_in_csv
from query_generator.utils.params import SearchParametersEndpoint


@pytest.mark.parametrize(
    "count_star, upper_bound, total_bins, expected_bin",
    [
        (5, 10, 5, 3),
        (0, 10, 5, 0),
︙
        " but got {computed_bin}"
    )


@pytest.mark.parametrize(
    "extra_predicates, expected_call_count, unique_joins",
    [
        ("[1]", 120 * 1 + 14, "false"),
        ("[1]", 120 * 1 + 14, "true"),
        # Inventory is small and prooduces 14 queries total
        ("[1, 2]", 120 * 2 + 14, "true"),
        ("[1, 2]", 120 * 2 + 14 * 2, "false"),
    ],
)
def test_binning_calls(extra_predicates, expected_call_count, unique_joins):
    with mock.patch(
        "query_generator.duckdb_connection.binning.Writer.write_query_to_batch",
    ) as mock_writer:
        with mock.patch(
            "query_generator.duckdb_connection.binning.get_result_from_duckdb",
        ) as mock_connect:
            mock_connect.return_value = 0
            data_toml = f"""
dataset = "TPCDS"
dev = true
max_hops = [1]
extra_predicates = {extra_predicates}
row_retention_probability = [0.2]
unique_joins = {unique_joins}
max_queries_per_fact_table = 10
max_queries_per_signature = 2
keep_edge_probability = 0.2
equality_lower_bound_probability = [0]
extra_values_for_in = 3

[operator_weights]
operator_in = 1
operator_range = 3
operator_equal = 3
"""
            user_input = structure(tomllib.loads(data_toml), SearchParametersEndpoint)
            run_snowflake_param_seach(
                search_params=SearchParameters(
                    scale_factor=0,
                    con=None,
                    user_input=user_input,
                ),
            )
            assert mock_writer.call_count == expected_call_count, (
                f"Expected {expected_call_count} calls to write_query, "
                f"but got {mock_writer.call_count}"
            )
Changes to tests/duckdb/test_duckdb_utils.py.
New text of the changed region:

import datetime

from query_generator.duckdb_connection.setup import setup_duckdb
from query_generator.duckdb_connection.utils import (
    get_distinct_count,
    get_equi_height_histogram,
    get_frequent_non_null_values,
)
from query_generator.tools.histograms import DuckDBHistogramParser
from query_generator.utils.definitions import Dataset
from tests.utils import is_float


def test_distinct_values():
    """Test the setup of DuckDB."""
    # Setup DuckDB
    con = setup_duckdb(Dataset.TPCDS, 0.1)
    assert get_distinct_count(con, "call_center", "cc_call_center_sk") == 1
︙
Changes to tests/file_management/test_read_histograms.py.
New text of the changed regions:

from unittest import mock

import polars as pl
import pytest

from query_generator.predicate_generator.predicate_generator import (
    HistogramDataType,
    PredicateGenerator,
)
from query_generator.tools.histograms import HistogramColumns
from query_generator.utils.definitions import Dataset, PredicateParameters
from query_generator.utils.exceptions import InvalidHistogramTypeError


def test_read_histograms():
    for dataset in Dataset:
        predicate_generator = PredicateGenerator(dataset, None)
        histogram = predicate_generator.read_histogram()
        assert not histogram.is_empty()
        assert histogram[HistogramColumns.DTYPE.value].dtype == pl.Utf8
        assert histogram[HistogramColumns.COLUMN.value].dtype == pl.Utf8
        assert histogram[HistogramColumns.DTYPE.value].dtype == pl.Utf8
        assert histogram[HistogramColumns.HISTOGRAM.value].dtype == pl.List(pl.Utf8)
︙
    max_index,
    dtype,
):
    with mock.patch(
        "query_generator.predicate_generator.predicate_generator.random.randint",
        return_value=mock_rand,
    ):
        predicate_generator = PredicateGenerator(
            Dataset.TPCH,
            PredicateParameters(
                extra_predicates=None,
                row_retention_probability=row_retention_probability,
                operator_weights=None,
                equality_lower_bound_probability=None,
                extra_values_for_in=None,
            ),
        )
        min_value, max_value = predicate_generator._get_min_max_from_bins(
            bins_array, dtype
        )
        assert min_value == bins_array[min_index]
        assert max_value == bins_array[max_index]


def test_get_invalid_histogram_type():
    predicate_generator = PredicateGenerator(Dataset.TPCH, None)
    with pytest.raises(InvalidHistogramTypeError):
        predicate_generator._get_histogram_type("not_supported_type")


@pytest.mark.parametrize(
    "input_type, expected_type",
    [
        ("INTEGER", HistogramDataType.INT),
        ("BIGINT", HistogramDataType.INT),
        ("DECIMAL(10,2)", HistogramDataType.FLOAT),
        ("DECIMAL(7,4)", HistogramDataType.FLOAT),
        ("DATE", HistogramDataType.DATE),
        ("VARCHAR", HistogramDataType.STRING),
    ],
)
def test_get_valid_histogram_type(input_type, expected_type):
    predicate_generator = PredicateGenerator(Dataset.TPCH, None)
    assert predicate_generator._get_histogram_type(input_type) == expected_type
Changes to tests/query_generation/test_make_queries.py.
New text of the changed region:

from unittest import mock

import pytest
from pypika import OracleQuery
from pypika import functions as fn

from query_generator.database_schemas.schemas import get_schema
from query_generator.join_based_query_generator.snowflake import (
    QueryBuilder,
    generate_and_write_queries,
)
from query_generator.predicate_generator.predicate_generator import (
    HistogramDataType,
    PredicateRange,
)
from query_generator.utils.definitions import (
    Dataset,
    PredicateOperatorProbability,
    PredicateParameters,
    QueryGenerationParameters,
)
from query_generator.utils.exceptions import UnkownDatasetError


def test_tpch_query_generation():
    with mock.patch(
        "query_generator.join_based_query_generator.snowflake.Writer.write_query",
    ) as mock_writer:
        generate_and_write_queries(
            QueryGenerationParameters(
                dataset=Dataset.TPCDS,
                max_hops=1,
                max_queries_per_fact_table=1,
                max_queries_per_signature=1,
                keep_edge_probability=0.2,
                seen_subgraphs={},
                predicate_parameters=PredicateParameters(
                    operator_weights=PredicateOperatorProbability(
                        operator_in=0.4,
                        operator_equal=0.4,
                        operator_range=0.2,
                    ),
                    extra_predicates=1,
                    row_retention_probability=0.2,
                    equality_lower_bound_probability=0,
                    extra_values_for_in=3,
                ),
            )
        )
        assert mock_writer.call_count > 5


def test_tpcds_query_generation():
    with mock.patch(
        "query_generator.join_based_query_generator.snowflake.Writer.write_query",
    ) as mock_writer:
        generate_and_write_queries(
            QueryGenerationParameters(
                dataset=Dataset.TPCDS,
                max_hops=1,
                max_queries_per_fact_table=1,
                max_queries_per_signature=1,
                keep_edge_probability=0.2,
                seen_subgraphs={},
                predicate_parameters=PredicateParameters(
                    operator_weights=PredicateOperatorProbability(
                        operator_in=0.4,
                        operator_equal=0.4,
                        operator_range=0.2,
                    ),
                    extra_predicates=1,
                    row_retention_probability=0.2,
                    equality_lower_bound_probability=0,
                    extra_values_for_in=3,
                ),
            ),
        )
        assert mock_writer.call_count > 5


def test_non_implemented_dataset():
    with mock.patch(
        "query_generator.join_based_query_generator.snowflake.Writer.write_query",
    ) as mock_writer:
        with pytest.raises(UnkownDatasetError):
            generate_and_write_queries(
                QueryGenerationParameters(
                    dataset="non_implemented_dataset",
                    max_hops=1,
                    max_queries_per_fact_table=1,
                    max_queries_per_signature=1,
                    keep_edge_probability=0.2,
                    seen_subgraphs={},
                    predicate_parameters=PredicateParameters(
                        operator_weights=PredicateOperatorProbability(
                            operator_in=0.4,
                            operator_equal=0.4,
                            operator_range=0.2,
                        ),
                        extra_predicates=1,
                        row_retention_probability=0.2,
                        equality_lower_bound_probability=0,
                        extra_values_for_in=3,
                    ),
                ),
            )
        assert mock_writer.call_count == 0


def test_add_rage_supports_all_histogram_types():
    tables_schema, _ = get_schema(Dataset.TPCH)
    query_builder = QueryBuilder(
        None,
        tables_schema,
        Dataset.TPCH,
        PredicateParameters(
            extra_predicates=None,
            row_retention_probability=0.2,
            operator_weights=None,
            equality_lower_bound_probability=None,
            extra_values_for_in=None,
        ),
    )
    for dtype in HistogramDataType:
        query_builder._add_range(
            OracleQuery()
            .from_(query_builder.table_to_pypika_table["lineitem"])
            .select(fn.Count("*")),
            PredicateRange(
                table="lineitem",
                column="foo",
                min_value=2020,
                max_value=2020,
                dtype=dtype,
            ),
        )