Many hyperlinks are disabled.
Use anonymous login
to enable hyperlinks.
Changes In Branch new-predicates Excluding Merge-Ins
This is equivalent to a diff from 0f0856db5a to 4162607284
|
2025-05-28
| ||
| 09:27 | Merges IN and = predicate check-in: 0a0518ab14 user: mathos tags: trunk | |
| 09:24 | Adds equality lower bound as an array. Leaf check-in: 4162607284 user: mathos tags: new-predicates | |
| 00:05 | Finally. A stable version check-in: 48e17f1cde user: mathos tags: new-predicates | |
|
2025-05-26
| ||
| 11:22 | Minor refactor to predicate class. Ticket [1e726428f6e719fb] check-in: c93b2b766c user: mathos tags: new-predicates | |
| 11:12 | Fix to be able to save the CSV check-in: 0f0856db5a user: mathos tags: trunk | |
| 09:57 | adds table_size to histogram check-in: 288ba9b582 user: mathos tags: trunk | |
Changes to data/histograms/histogram_tpcds.parquet.
cannot compute difference between binary files
Added params_config/search_params/tpcds.toml.
> > > > > > > > > > > > > > > > | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | dataset = "TPCDS" dev = true max_hops = [1,2,4] extra_predicates = [1,3,5] row_retention_probability = [0.2, 0.3, 0.4, 0.6, 0.8, 0.85, 0.9, 1.0] unique_joins = true max_queries_per_fact_table = 10 max_queries_per_signature = 2 keep_edge_probability = 0.2 equality_lower_bound_probability = [0,0.1] extra_values_for_in = 3 [operator_weights] operator_in = 1 operator_range = 3 operator_equal = 3 |
Added params_config/search_params/tpcds_dev.toml.
> > > > > > > > > > > > > > > > | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | dataset = "TPCDS" dev = true max_hops = [1] extra_predicates = [5] row_retention_probability = [0.2, 0.9] unique_joins = true max_queries_per_fact_table = 1 max_queries_per_signature = 2 keep_edge_probability = 0.2 equality_lower_bound_probability = [0,0.1] extra_values_for_in = 3 [operator_weights] operator_in = 1 operator_range = 3 operator_equal = 3 |
Added params_config/snowflake/tpcds.toml.
> > > > > > > > > > > > > > > > > > | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 | dataset = "TPCDS" max_hops = 3 max_queries_per_fact_table = 100 max_queries_per_signature = 1 keep_edge_probability = 0.2 [predicate_parameters] row_retention_probability = 0.2 extra_predicates = 3 equality_lower_bound_probability = 0.00 extra_values_for_in = 3 [predicate_parameters.operator_weights] operator_in = 1 operator_range = 3 operator_equal = 3 |
Changes to pyproject.toml.
| ︙ | ︙ | |||
26 27 28 29 30 31 32 33 34 35 36 37 38 39 | typer = ">=0.15.2,<0.16" rich = ">=14.0.0,<15" pypika = ">=0.48.9,<0.49" numpy = ">=2.2.5,<3" duckdb = ">=1.2.2,<2" polars = ">=1.27.1,<2" tqdm = "*" [tool.pixi.feature.test.dependencies] pytest = ">=8.3.5,<9" [tool.pixi.feature.lint.dependencies] ruff = ">=0.11.7,<0.12" | > | 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 | typer = ">=0.15.2,<0.16" rich = ">=14.0.0,<15" pypika = ">=0.48.9,<0.49" numpy = ">=2.2.5,<3" duckdb = ">=1.2.2,<2" polars = ">=1.27.1,<2" tqdm = "*" cattrs = ">=24.1.2,<25" [tool.pixi.feature.test.dependencies] pytest = ">=8.3.5,<9" [tool.pixi.feature.lint.dependencies] ruff = ">=0.11.7,<0.12" |
| ︙ | ︙ |
Changes to src/query_generator/database_schemas/schemas.py.
1 2 3 4 5 6 7 | from typing import Any from query_generator.database_schemas.tpcds import get_tpcds_table_info from query_generator.database_schemas.tpch import get_tpch_table_info from query_generator.utils.definitions import Dataset from query_generator.utils.exceptions import ( PartiallySupportedDatasetError, | | | | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
from typing import Any
from query_generator.database_schemas.tpcds import get_tpcds_table_info
from query_generator.database_schemas.tpch import get_tpch_table_info
from query_generator.utils.definitions import Dataset
from query_generator.utils.exceptions import (
PartiallySupportedDatasetError,
UnkownDatasetError,
)
def get_schema(dataset: Dataset) -> tuple[dict[str, dict[str, Any]], list[str]]:
"""Get the schema of the database based on the dataset.
Args:
dataset (Dataset): The dataset to get the schema for.
Returns:
Tuple[Dict[str, Dict[str, Any]], List[str]]: A tuple containing the schema
as a dictionary and a list of fact tables
"""
if dataset == Dataset.TPCDS:
return get_tpcds_table_info()
if dataset == Dataset.TPCH:
return get_tpch_table_info()
if dataset == Dataset.JOB:
raise PartiallySupportedDatasetError(dataset.value)
raise UnkownDatasetError(dataset)
|
Changes to src/query_generator/database_schemas/tpcds.py.
| ︙ | ︙ | |||
436 437 438 439 440 441 442 |
"s_floor_space": {"max": 9917607, "min": 5010719},
"s_gmt_offset": {"max": -5.0, "min": -6.0},
"s_market_id": {"max": 10, "min": 1},
"s_number_employees": {"max": 300, "min": 200},
"s_rec_end_date": {"max": "2001-03-12", "min": "1999-03-13"},
"s_rec_start_date": {"max": "2001-03-13", "min": "1997-03-13"},
"s_store_sk": {"max": 402, "min": 1},
| < | 436 437 438 439 440 441 442 443 444 445 446 447 448 449 |
"s_floor_space": {"max": 9917607, "min": 5010719},
"s_gmt_offset": {"max": -5.0, "min": -6.0},
"s_market_id": {"max": 10, "min": 1},
"s_number_employees": {"max": 300, "min": 200},
"s_rec_end_date": {"max": "2001-03-12", "min": "1999-03-13"},
"s_rec_start_date": {"max": "2001-03-13", "min": "1997-03-13"},
"s_store_sk": {"max": 402, "min": 1},
},
"foreign_keys": [],
},
"store_returns": {
"alias": "sr",
"columns": {
"sr_addr_sk": {"max": 1000000, "min": 1},
|
| ︙ | ︙ |
Changes to src/query_generator/duckdb_connection/binning.py.
| ︙ | ︙ | |||
9 10 11 12 13 14 15 | QueryGenerator, ) from query_generator.join_based_query_generator.utils.query_writer import ( Writer, ) from query_generator.utils.definitions import ( BatchGeneratedQueryFeatures, | < > > | < < < < | > | | > > > | > > | | | > | | | | > > | | < > > > > | 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
QueryGenerator,
)
from query_generator.join_based_query_generator.utils.query_writer import (
Writer,
)
from query_generator.utils.definitions import (
BatchGeneratedQueryFeatures,
Extension,
PredicateParameters,
QueryGenerationParameters,
)
from query_generator.utils.params import SearchParametersEndpoint
@dataclass
class SearchParameters:
    """Inputs of ``run_snowflake_param_seach``.

    Bundles the parsed search configuration with the DuckDB connection that
    each generated query is executed against.
    """

    # Parsed search configuration (read from a TOML file by ``param_search``).
    user_input: SearchParametersEndpoint
    # Scale factor that was passed to ``setup_duckdb`` for ``con``.
    scale_factor: int | float
    # Open DuckDB connection used to execute the generated queries.
    con: duckdb.DuckDBPyConnection
def get_result_from_duckdb(query: str, con: duckdb.DuckDBPyConnection) -> int:
    """Execute ``query`` and return the first column of its first row as int.

    When DuckDB rejects the query with a ``BinderException`` the error and
    the query are printed and ``-1`` is returned; any other exception
    propagates to the caller.
    """
    try:
        fetched_rows = con.sql(query).fetchall()
    except duckdb.BinderException as binder_error:
        print(f"Invalid query, exception: {binder_error},\n{query}")
        return -1
    return int(fetched_rows[0][0])
def get_total_iterations(search_params: SearchParametersEndpoint) -> int:
    """Count the configurations explored by the parameter search.

    The search walks the Cartesian product of four parameter lists, so the
    number of iterations is the product of their lengths.

    Args:
        search_params (SearchParametersEndpoint): Parsed search configuration.

    Returns:
        int: Number of (max_hops, extra_predicates, row_retention_probability,
        equality_lower_bound_probability) combinations.
    """
    searched_axes = (
        search_params.max_hops,
        search_params.extra_predicates,
        search_params.row_retention_probability,
        search_params.equality_lower_bound_probability,
    )
    combinations = 1
    for axis_values in searched_axes:
        combinations *= len(axis_values)
    return combinations
def run_snowflake_param_seach(
search_params: SearchParameters,
) -> None:
"""Run the Snowflake binning process. Binning is equiwidth binning.
Args:
parameters (BinningSnowflakeParameters): The parameters for
the Snowflake binning process.
"""
query_writer = Writer(
search_params.user_input.dataset,
Extension.SNOWFLAKE_SEARCH_PARAMS,
)
rows: list[dict[str, str | int | float]] = []
total_iterations = get_total_iterations(search_params.user_input)
batch_number = 0
seen_subgraphs: dict[int, bool] = {}
for (
max_hops,
extra_predicates,
row_retention_probability,
equality_lower_bound_probability,
) in tqdm(
product(
search_params.user_input.max_hops,
search_params.user_input.extra_predicates,
search_params.user_input.row_retention_probability,
search_params.user_input.equality_lower_bound_probability,
),
total=total_iterations,
desc="Progress",
):
batch_number += 1
query_generator = QueryGenerator(
QueryGenerationParameters(
dataset=search_params.user_input.dataset,
max_hops=max_hops,
max_queries_per_fact_table=search_params.user_input.max_queries_per_fact_table,
max_queries_per_signature=search_params.user_input.max_queries_per_signature,
keep_edge_probability=search_params.user_input.keep_edge_probability,
seen_subgraphs=seen_subgraphs,
predicate_parameters=PredicateParameters(
extra_predicates=extra_predicates,
row_retention_probability=row_retention_probability,
operator_weights=search_params.user_input.operator_weights,
equality_lower_bound_probability=equality_lower_bound_probability,
extra_values_for_in=search_params.user_input.extra_values_for_in,
),
)
)
for query in query_generator.generate_queries():
selected_rows = get_result_from_duckdb(query.query, search_params.con)
if selected_rows == -1:
continue # invalid query
|
| ︙ | ︙ | |||
124 125 126 127 128 129 130 |
"predicate_number": query.predicate_number,
"fact_table": query.fact_table,
"max_hops": max_hops,
"row_retention_probability": row_retention_probability,
},
)
# Update the seen subgraphs with the new ones
| | | 133 134 135 136 137 138 139 140 141 142 143 |
"predicate_number": query.predicate_number,
"fact_table": query.fact_table,
"max_hops": max_hops,
"row_retention_probability": row_retention_probability,
},
)
# Update the seen subgraphs with the new ones
if search_params.user_input.unique_joins:
seen_subgraphs = query_generator.subgraph_generator.seen_subgraphs
df_queries = pl.DataFrame(rows)
query_writer.write_dataframe(df_queries)
|
Changes to src/query_generator/duckdb_connection/setup.py.
1 2 3 4 5 6 7 8 | import os import duckdb from query_generator.utils.definitions import Dataset from query_generator.utils.exceptions import ( MissingScaleFactorError, PartiallySupportedDatasetError, | | | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
import os
import duckdb
from query_generator.utils.definitions import Dataset
from query_generator.utils.exceptions import (
MissingScaleFactorError,
PartiallySupportedDatasetError,
UnkownDatasetError,
)
def load_and_install_libraries() -> None:
duckdb.install_extension("TPCDS")
duckdb.install_extension("TPCH")
duckdb.load_extension("TPCDS")
|
| ︙ | ︙ | |||
25 26 27 28 29 30 31 |
if dataset == Dataset.TPCDS:
con.execute(f"CALL dsdgen(sf = {scale_factor})")
elif dataset == Dataset.TPCH:
con.execute(f"CALL dbgen(sf = {scale_factor})")
elif dataset == Dataset.JOB:
raise PartiallySupportedDatasetError(dataset.value)
else:
| | | | 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
if dataset == Dataset.TPCDS:
con.execute(f"CALL dsdgen(sf = {scale_factor})")
elif dataset == Dataset.TPCH:
con.execute(f"CALL dbgen(sf = {scale_factor})")
elif dataset == Dataset.JOB:
raise PartiallySupportedDatasetError(dataset.value)
else:
raise UnkownDatasetError(dataset)
def get_path(
    dataset: Dataset,
    scale_factor: float | int | None,
) -> str:
    """Build the on-disk DuckDB database path for ``dataset``.

    TPCDS and TPCH databases get one file per scale factor; JOB always uses
    ``job.db``.

    Raises:
        UnkownDatasetError: If ``dataset`` is not TPCDS, TPCH or JOB.
    """
    dataset_dir = f"data/duckdb/{dataset.value}"
    if dataset == Dataset.JOB:
        return f"{dataset_dir}/job.db"
    if dataset in [Dataset.TPCDS, Dataset.TPCH]:
        return f"{dataset_dir}/{scale_factor}.db"
    raise UnkownDatasetError(dataset.value)
def setup_duckdb(
dataset: Dataset,
scale_factor: int | float | None = None,
) -> duckdb.DuckDBPyConnection:
"""Installs TPCDS and TPCH datasets in DuckDB.
|
| ︙ | ︙ |
Changes to src/query_generator/join_based_query_generator/snowflake.py.
| ︙ | ︙ | |||
16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
# fmt: on
from query_generator.join_based_query_generator.utils.query_writer import (
Writer,
)
from query_generator.predicate_generator.predicate_generator import (
HistogramDataType,
PredicateGenerator,
)
from query_generator.utils.definitions import (
Dataset,
Extension,
GeneratedQueryFeatures,
QueryGenerationParameters,
)
from query_generator.utils.exceptions import InvalidHistogramTypeError
from query_generator.utils.utils import set_seed
class QueryBuilder:
def __init__(
self,
subgraph_generator: SubGraphGenerator,
# TODO(Gabriel): http://localhost:8080/tktview/b9400c203a38f3aef46ec250d98563638ba7988b
tables_schema: Any,
dataset: Dataset,
) -> None:
self.sub_graph_gen = subgraph_generator
self.table_to_pypika_table = {
i: Table(i, alias=tables_schema[i]["alias"]) for i in tables_schema
}
| > > > > > > | | 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
# fmt: on
from query_generator.join_based_query_generator.utils.query_writer import (
Writer,
)
from query_generator.predicate_generator.predicate_generator import (
HistogramDataType,
PredicateEquality,
PredicateGenerator,
PredicateIn,
PredicateRange,
SupportedHistogramType,
)
from query_generator.utils.definitions import (
Dataset,
Extension,
GeneratedQueryFeatures,
PredicateParameters,
QueryGenerationParameters,
)
from query_generator.utils.exceptions import InvalidHistogramTypeError
from query_generator.utils.utils import set_seed
class QueryBuilder:
def __init__(
    self,
    subgraph_generator: SubGraphGenerator,
    # TODO(Gabriel): http://localhost:8080/tktview/b9400c203a38f3aef46ec250d98563638ba7988b
    tables_schema: Any,
    dataset: Dataset,
    predicate_params: PredicateParameters,
) -> None:
    """Prepare the query builder.

    Args:
        subgraph_generator: Generator whose subgraphs queries are built from.
        tables_schema: Schema keyed by table name; every entry must provide
            an ``"alias"`` key.
        dataset: Dataset whose histograms the predicate generator loads.
        predicate_params: Parameters controlling random predicate generation.
    """
    self.sub_graph_gen = subgraph_generator
    pypika_tables = {}
    for table_name in tables_schema:
        pypika_tables[table_name] = Table(
            table_name, alias=tables_schema[table_name]["alias"]
        )
    self.table_to_pypika_table = pypika_tables
    self.predicate_gen = PredicateGenerator(dataset, predicate_params)
    self.tables_schema = tables_schema
def get_subgraph_tables(
self,
subgraph: list[ForeignKeyGraph.Edge],
) -> list[str]:
return list(
|
| ︙ | ︙ | |||
82 83 84 85 86 87 88 |
)
return query
def add_predicates(
self,
subgraph: list[ForeignKeyGraph.Edge],
query: OracleQuery,
| < < < < > | > > > > > | | | < < > | | < | < | | < < < > > > | | | | < < < < | < | | < < < > | > < < | 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
)
return query
def add_predicates(
    self,
    subgraph: list[ForeignKeyGraph.Edge],
    query: OracleQuery,
) -> OracleQuery:
    """Attach randomly generated predicates to ``query``.

    Every predicate the predicate generator yields for the tables of
    ``subgraph`` is added to the WHERE clause.

    Args:
        subgraph: Edges of the join subgraph the query was built from.
        query: Query to extend.

    Returns:
        OracleQuery: ``query`` with all generated predicates applied.

    Raises:
        InvalidHistogramTypeError: If the generator yields a predicate of an
            unsupported kind.
    """
    subgraph_tables = self.get_subgraph_tables(subgraph)
    for predicate in self.predicate_gen.get_random_predicates(subgraph_tables):
        # Accumulate instead of returning inside the loop. Returning added only
        # the first predicate and discarded the remaining ones, making the
        # ``extra_predicates`` setting ineffective.
        if isinstance(predicate, PredicateRange):
            query = self._add_range(query, predicate)
        elif isinstance(predicate, PredicateEquality):
            query = self._add_equality(query, predicate)
        elif isinstance(predicate, PredicateIn):
            query = self._add_in(query, predicate)
        else:
            raise InvalidHistogramTypeError(str(predicate.dtype))
    return query
def _cast_if_needed(
    self, value: SupportedHistogramType, dtype: HistogramDataType
) -> Any:
    """Wrap date values in an SQL ``CAST(... AS date)``; pass others through."""
    return fn.Cast(value, "date") if dtype == HistogramDataType.DATE else value
def _add_range(
    self, query: OracleQuery, predicate: PredicateRange
) -> OracleQuery:
    """Restrict ``query`` to ``min_value <= column <= max_value`` (inclusive)."""
    target_column = self.table_to_pypika_table[predicate.table][predicate.column]
    lower_bound = self._cast_if_needed(predicate.min_value, predicate.dtype)
    upper_bound = self._cast_if_needed(predicate.max_value, predicate.dtype)
    return query.where(target_column >= lower_bound).where(
        target_column <= upper_bound
    )
def _add_equality(
    self, query: OracleQuery, predicate: PredicateEquality
) -> OracleQuery:
    """Restrict ``query`` to rows where the column equals ``equality_value``.

    Date values are cast through ``_cast_if_needed``, the same way
    ``_add_range`` and ``_add_in`` treat them, so that the comparison is
    typed consistently across all operators.
    """
    return query.where(
        self.table_to_pypika_table[predicate.table][predicate.column]
        == self._cast_if_needed(predicate.equality_value, predicate.dtype)
    )
def _add_in(self, query: OracleQuery, predicate: PredicateIn) -> OracleQuery:
    """Restrict ``query`` to rows whose column is one of ``in_values``."""
    target_column = self.table_to_pypika_table[predicate.table][predicate.column]
    allowed_values = [
        self._cast_if_needed(candidate, predicate.dtype)
        for candidate in predicate.in_values
    ]
    return query.where(target_column.isin(allowed_values))
class QueryGenerator:
def __init__(self, params: QueryGenerationParameters) -> None:
    """Build the schema graph and the helpers used to generate queries.

    Args:
        params: Generation settings; ``params.dataset`` selects the schema.
    """
    # set_seed() runs before anything else (presumably so that the random
    # choices below are reproducible).
    set_seed()
    self.params = params
    schema, fact_tables = get_schema(params.dataset)
    self.tables_schema = schema
    self.fact_tables = fact_tables
    self.foreign_key_graph = ForeignKeyGraph(schema)
    self.subgraph_generator = SubGraphGenerator(
        graph=self.foreign_key_graph,
        keep_edge_probability=params.keep_edge_probability,
        max_hops=params.max_hops,
        seen_subgraphs=params.seen_subgraphs,
    )
    self.query_builder = QueryBuilder(
        subgraph_generator=self.subgraph_generator,
        tables_schema=schema,
        dataset=params.dataset,
        predicate_params=params.predicate_parameters,
    )
def generate_queries(self) -> Iterator[GeneratedQueryFeatures]:
for fact_table in self.fact_tables:
for cnt, subgraph in enumerate(
self.subgraph_generator.generate_subgraph(
fact_table,
self.params.max_queries_per_fact_table,
),
):
query = self.query_builder.generate_query_from_subgraph(subgraph)
for idx in range(1, self.params.max_queries_per_signature + 1):
query = self.query_builder.add_predicates(
subgraph,
query,
)
yield GeneratedQueryFeatures(
query=query.get_sql(),
template_number=cnt,
predicate_number=idx,
fact_table=fact_table,
|
| ︙ | ︙ |
Changes to src/query_generator/join_based_query_generator/utils/subgraph_generator.py.
| ︙ | ︙ | |||
9 10 11 12 13 14 15 |
MAX_ATTEMPTS_FOR_NEW_SUBGRAPH = 1000
class SubGraphGenerator:
def __init__(
self,
graph: ForeignKeyGraph,
| | | | > | | 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
MAX_ATTEMPTS_FOR_NEW_SUBGRAPH = 1000
class SubGraphGenerator:
def __init__(
self,
graph: ForeignKeyGraph,
keep_edge_probability: float,
max_hops: int,
seen_subgraphs: dict[int, bool],
) -> None:
self.hops = max_hops
self.keep_edge_probability = keep_edge_probability
self.graph = graph
self.seen_subgraphs: dict[int, bool] = seen_subgraphs.copy()
def get_random_subgraph(self, fact_table: str) -> list[ForeignKeyGraph.Edge]:
"""Starting from the fact table, for each edge of the current table we
decide based on the keep_edge_probabilityability whether to keep the
edge or not.
We repeat this process up until the maximum number of hops.
"""
@dataclass
class JoinDepthNode:
table: str
depth: int
queue: deque[JoinDepthNode] = deque()
queue.append(JoinDepthNode(fact_table, 0))
edges_subgraph = []
while queue:
current_node = queue.popleft()
if current_node.depth >= self.hops:
continue
current_edges = self.graph.get_edges(current_node.table)
for current_edge in current_edges:
if random.random() < self.keep_edge_probability:
edges_subgraph.append(current_edge)
queue.append(
JoinDepthNode(
current_edge.reference_table.name,
current_node.depth + 1,
),
)
|
| ︙ | ︙ |
Changes to src/query_generator/main.py.
| ︙ | ︙ | |||
26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 | make_redundant_histograms, query_histograms, ) from query_generator.utils.definitions import ( Dataset, Extension, QueryGenerationParameters, ) from query_generator.utils.show_messages import show_dev_warning from query_generator.utils.utils import validate_file_path app = typer.Typer(name="Query Generation") @app.command() def snowflake( | > > > > > | < < < < | < | < < < < < < < < | < | < < < < < < < < < < < < < < < < < | < < < < < < < < < < < < < < < < < < < < < | > | | | | | < < > < < < < < | < < < < < < < < < < < < < < < < | < | < < < < < < < < | < < < < < < < < < | < < > | | | | < < < > | | | | < < < < | 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
make_redundant_histograms,
query_histograms,
)
from query_generator.utils.definitions import (
Dataset,
Extension,
QueryGenerationParameters,
)
from query_generator.utils.params import (
SearchParametersEndpoint,
SnowflakeEndpoint,
read_and_parse_toml,
)
from query_generator.utils.show_messages import show_dev_warning
from query_generator.utils.utils import validate_file_path
# Typer CLI application; functions decorated with ``@app.command()`` are its
# subcommands.
app = typer.Typer(name="Query Generation")
@app.command()
def snowflake(
    config_path: Annotated[
        str,
        typer.Option(
            "-c",
            "--config",
            # Adjacent literals are concatenated, so the first one needs its
            # own trailing separator. The snowflake configs (SnowflakeEndpoint
            # format with a [predicate_parameters] table) live in
            # params_config/snowflake/.
            help="The path to the configuration file. "
            "They can be found in the params_config/snowflake/ folder",
        ),
    ],
) -> None:
    """Generate queries using a random subgraph.

    The TOML configuration is parsed into a ``SnowflakeEndpoint`` and turned
    into ``QueryGenerationParameters``; generation starts with no previously
    seen subgraphs.
    """
    params_endpoint = read_and_parse_toml(Path(config_path), SnowflakeEndpoint)
    params = QueryGenerationParameters(
        dataset=params_endpoint.dataset,
        max_hops=params_endpoint.max_hops,
        max_queries_per_fact_table=params_endpoint.max_queries_per_fact_table,
        max_queries_per_signature=params_endpoint.max_queries_per_signature,
        keep_edge_probability=params_endpoint.keep_edge_probability,
        seen_subgraphs={},
        predicate_parameters=params_endpoint.predicate_parameters,
    )
    generate_and_write_queries(params)
@app.command()
def param_search(
    config_path: Annotated[
        str,
        typer.Option(
            "-c",
            "--config",
            # Adjacent literals are concatenated, so the first one needs its
            # own trailing separator.
            help="The path to the configuration file. "
            "They can be found in the params_config/search_params/ folder",
        ),
    ],
) -> None:
    """This is an extension of the Snowflake algorithm.

    It runs multiple batches with different configurations of the algorithm.
    This allows us to get multiple results.

    In dev mode the DuckDB database is built at scale factor 0.1, otherwise
    at scale factor 100.
    """
    params = read_and_parse_toml(
        Path(config_path),
        SearchParametersEndpoint,
    )
    show_dev_warning(dev=params.dev)
    scale_factor = 0.1 if params.dev else 100
    con = setup_duckdb(params.dataset, scale_factor)
    run_snowflake_param_seach(
        SearchParameters(
            scale_factor=scale_factor,
            con=con,
            user_input=params,
        ),
    )
@app.command()
def cherry_pick(
dataset: Annotated[
|
| ︙ | ︙ |
Changes to src/query_generator/predicate_generator/predicate_generator.py.
1 2 3 4 5 6 7 8 | import math import random from collections.abc import Iterator from dataclasses import dataclass from enum import Enum import polars as pl | > > | > > > | > > > > | > > > > > > > > > > > < | | | | > > > > > | | | > > > > > > > > > > > | > | | | | > > > > > > > > > > > > > | | | > > > > > > > > > > > > > > > > > < < | | | > | < > > > > > > > | > > > > | > > > > > > > | > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > | > > > > > > > > | > > > > > > > > | | | | | | | < < | | | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 |
import math
import random
from abc import ABC
from collections.abc import Iterator
from dataclasses import dataclass
from enum import Enum
import numpy as np
import polars as pl
from query_generator.tools.histograms import (
HistogramColumns,
MostCommonValuesColumns,
)
from query_generator.utils.definitions import (
Dataset,
PredicateOperatorProbability,
PredicateParameters,
)
from query_generator.utils.exceptions import (
InvalidHistogramTypeError,
UnkownDatasetError,
)
# Scalar type a single histogram value is cast to.
SupportedHistogramType = float | int | str
# List types produced when a whole histogram array is cast.
SuportedHistogramArrayType = list[float] | list[int] | list[str]
# NOTE(review): the three constants below are not referenced anywhere else in
# this module; confirm whether other modules still import them before removing.
MAX_DISTINCT_COUNT_FOR_RANGE = 500
PROBABILITY_TO_CHOOSE_EQUALITY = 0.8
PREDICATE_IN_SIZE = 5
class PredicateTypes(Enum):
    """Operator kind chosen for a generated predicate."""

    IN = "in"
    RANGE = "range"
    EQUALITY = "equality"


class HistogramDataType(Enum):
    """Value type of a histogram column.

    It is derived from the column's SQL type name by
    ``PredicateGenerator._get_histogram_type``.
    """

    INT = "int"
    FLOAT = "float"
    DATE = "date"
    STRING = "string"
@dataclass
class Predicate(ABC):
    """Base class of a generated predicate on ``table.column``."""

    table: str
    column: str
    # Type the predicate values were cast to.
    dtype: HistogramDataType


@dataclass
class PredicateRange(Predicate):
    """Range predicate with bounds ``min_value`` and ``max_value``."""

    min_value: SupportedHistogramType
    max_value: SupportedHistogramType


@dataclass
class PredicateEquality(Predicate):
    """Predicate ``column = equality_value``."""

    equality_value: SupportedHistogramType


@dataclass
class PredicateIn(Predicate):
    """Predicate ``column IN (in_values)``."""

    in_values: SuportedHistogramArrayType
class PredicateGenerator:
    """Generate random predicates from the precomputed histograms of a dataset.

    The dataset's histogram parquet file is read once in ``__init__``; each
    call to ``get_random_predicates`` then draws predicates for the given
    tables according to ``predicate_params``.
    """

    def __init__(self, dataset: Dataset, predicate_params: PredicateParameters):
        self.dataset = dataset
        self.histogram: pl.DataFrame = self.read_histogram()
        self.predicate_params = predicate_params

    def _cast_array(
        self, str_array: list[str], dtype: HistogramDataType
    ) -> SuportedHistogramArrayType:
        """Cast a list of string values to ``dtype``.

        Args:
            str_array (list[str]): Values as stored in the histogram file.
            dtype (HistogramDataType): Target data type.

        Returns:
            SuportedHistogramArrayType: The cast values; dates and strings are
            returned unchanged.

        Raises:
            InvalidHistogramTypeError: If ``dtype`` is not supported.
        """
        if dtype == HistogramDataType.INT:
            # Go through float() so that values stored like "3.0" also parse.
            return [int(float(x)) for x in str_array]
        if dtype == HistogramDataType.FLOAT:
            return [float(x) for x in str_array]
        if dtype == HistogramDataType.DATE:
            return str_array
        if dtype == HistogramDataType.STRING:
            return str_array
        raise InvalidHistogramTypeError(dtype)

    def _cast_element(
        self, value: str, dtype: HistogramDataType
    ) -> SupportedHistogramType:
        """Cast a single string value to ``dtype`` (see ``_cast_array``).

        Raises:
            InvalidHistogramTypeError: If ``dtype`` is not supported.
        """
        if dtype == HistogramDataType.INT:
            return int(float(value))
        if dtype == HistogramDataType.FLOAT:
            return float(value)
        if dtype == HistogramDataType.DATE:
            return value
        if dtype == HistogramDataType.STRING:
            return value
        raise InvalidHistogramTypeError(dtype)

    def read_histogram(self) -> pl.DataFrame:
        """Read the histogram parquet file of ``self.dataset``.

        Rows whose ``histogram`` list is empty are dropped, so every remaining
        row has at least one bin boundary.

        Returns:
            pl.DataFrame: The histogram data.

        Raises:
            UnkownDatasetError: If ``self.dataset`` has no histogram file.
        """
        if self.dataset == Dataset.TPCH:
            path = "data/histograms/histogram_tpch.parquet"
        elif self.dataset == Dataset.TPCDS:
            path = "data/histograms/histogram_tpcds.parquet"
        elif self.dataset == Dataset.JOB:
            path = "data/histograms/histogram_job.parquet"
        else:
            raise UnkownDatasetError(self.dataset.value)
        return pl.read_parquet(path).filter(pl.col("histogram") != [])

    def _get_histogram_type(self, dtype: str) -> HistogramDataType:
        """Map an SQL type name from the histogram file to a HistogramDataType.

        Raises:
            InvalidHistogramTypeError: For unsupported type names.
        """
        if dtype in ["INTEGER", "BIGINT"]:
            return HistogramDataType.INT
        if dtype.startswith("DECIMAL"):
            return HistogramDataType.FLOAT
        if dtype == "DATE":
            return HistogramDataType.DATE
        if dtype == "VARCHAR":
            return HistogramDataType.STRING
        raise InvalidHistogramTypeError(dtype)

    def _choose_predicate_type(
        self, operator_weights: PredicateOperatorProbability
    ) -> PredicateTypes:
        """Draw a predicate kind with probability proportional to its weight."""
        weights = [
            operator_weights.operator_equal,
            operator_weights.operator_in,
            operator_weights.operator_range,
        ]
        return random.choices(
            [
                PredicateTypes.EQUALITY,
                PredicateTypes.IN,
                PredicateTypes.RANGE,
            ],
            weights=weights,
        )[0]

    def get_random_predicates(
        self,
        tables: list[str],
    ) -> Iterator[Predicate]:
        """Yield random predicates on columns of ``tables``.

        Up to ``predicate_params.extra_predicates`` histogram rows belonging
        to ``tables`` are sampled, and an operator is drawn for each one
        according to ``predicate_params.operator_weights``. IN and equality
        candidates are skipped when no most-common value passes
        ``equality_lower_bound_probability``, so fewer predicates than
        requested may be yielded.

        Args:
            tables (list[str]): Tables the predicates may be placed on.

        Yields:
            Predicate: A ``PredicateRange``, ``PredicateIn`` or
            ``PredicateEquality``.
        """
        selected_tables_histogram = self.histogram.filter(
            pl.col(HistogramColumns.TABLE.value).is_in(tables)
        )
        # polars raises when sampling more rows than exist without
        # replacement, which can happen for subgraphs with few columns.
        sample_size = min(
            self.predicate_params.extra_predicates,
            selected_tables_histogram.height,
        )
        # Seed polars from Python's ``random`` so that the sampled columns
        # are reproducible whenever ``random`` has been seeded.
        sampled_rows = selected_tables_histogram.sample(
            n=sample_size, seed=random.getrandbits(32)
        )
        for row in sampled_rows.iter_rows(named=True):
            table = row[HistogramColumns.TABLE.value]
            column = row[HistogramColumns.COLUMN.value]
            dtype = self._get_histogram_type(row[HistogramColumns.DTYPE.value])
            predicate_type = self._choose_predicate_type(
                self.predicate_params.operator_weights
            )
            if predicate_type == PredicateTypes.RANGE:
                yield self._get_range_predicate(
                    table, column, row[HistogramColumns.HISTOGRAM.value], dtype
                )
            elif predicate_type == PredicateTypes.IN:
                array = self._get_in_array(
                    row[HistogramColumns.MOST_COMMON_VALUES.value],
                    row[HistogramColumns.TABLE_SIZE.value],
                    row[HistogramColumns.HISTOGRAM_MCV.value],
                )
                if array is not None:
                    yield self._get_in_predicate(array, table, column, dtype)
            elif predicate_type == PredicateTypes.EQUALITY:
                value = self._get_equality_value(
                    row[HistogramColumns.MOST_COMMON_VALUES.value],
                    row[HistogramColumns.TABLE_SIZE.value],
                )
                if value is not None:
                    yield self._get_equality_predicate(value, table, column, dtype)

    def _get_in_predicate(
        self, array: list[str], table: str, column: str, dtype: HistogramDataType
    ) -> PredicateIn:
        """Build a ``PredicateIn`` from string values cast to ``dtype``."""
        cast_array = self._cast_array(array, dtype)
        return PredicateIn(table, column, dtype, cast_array)

    def _get_in_array(
        self,
        most_common_values: list[dict[str, int | str]],
        table_size: int,
        histogram: list[str],
    ) -> list[str] | None:
        """Build the value list of an IN predicate.

        One most-common value is chosen as in ``_get_equality_value``. It is
        padded with up to ``extra_values_for_in`` values sampled from
        ``histogram``.

        Returns:
            list[str] | None: The IN values, or ``None`` when no most-common
            value qualifies.
        """
        value = self._get_equality_value(most_common_values, table_size)
        if value is None:
            return None
        noise_values = random.sample(
            histogram,
            k=min(self.predicate_params.extra_values_for_in, len(histogram)),
        )
        return [value] + noise_values

    def _get_equality_predicate(
        self, value: str, table: str, column: str, dtype: HistogramDataType
    ) -> PredicateEquality:
        """Build a ``PredicateEquality`` from a string value cast to ``dtype``."""
        cast_value = self._cast_element(value, dtype)
        return PredicateEquality(
            table=table, column=column, dtype=dtype, equality_value=cast_value
        )

    def _get_equality_value(
        self,
        most_common_values: list[dict[str, int | str]],
        table_size: int,
    ) -> str | None:
        """Pick a random most-common value whose frequency exceeds the bound.

        The frequency of a value is its count divided by ``table_size``, i.e.
        the fraction of rows an equality predicate on it retains. Only values
        whose frequency is strictly greater than
        ``predicate_params.equality_lower_bound_probability`` are eligible.

        Returns:
            str | None: The chosen value, or ``None`` when there are no
            most-common values, the table size is zero, or no value passes
            the bound.
        """
        if not most_common_values or not table_size:
            return None
        # count / table_size is a probability in (0, 1]. The inverse ratio is
        # always >= 1, which would make the lower bound a no-op.
        mcv_probabilities_np = np.array(
            [
                float(v[MostCommonValuesColumns.COUNT.value]) / float(table_size)
                for v in most_common_values
            ]
        )
        filtered_indices = np.where(
            mcv_probabilities_np
            > self.predicate_params.equality_lower_bound_probability
        )[0]
        if len(filtered_indices) == 0:
            return None
        idx = int(random.choice(filtered_indices))
        value = most_common_values[idx][MostCommonValuesColumns.VALUE.value]
        assert isinstance(value, str)
        return value

    def _get_range_predicate(
        self,
        table: str,
        column: str,
        bins: list[str],
        dtype: HistogramDataType,
    ) -> PredicateRange:
        """Build a ``PredicateRange`` over a random window of histogram bins."""
        min_value, max_value = self._get_min_max_from_bins(bins, dtype)
        return PredicateRange(
            table=table,
            column=column,
            min_value=min_value,
            max_value=max_value,
            dtype=dtype,
        )

    def _get_min_max_from_bins(
        self,
        bins: list[str],
        dtype: HistogramDataType,
    ) -> tuple[SupportedHistogramType, SupportedHistogramType]:
        """Choose a random contiguous window of bin boundaries.

        The window spans ``ceil(row_retention_probability * len(bins))``
        boundaries, clamped to ``[1, len(bins)]``. Its first and last
        boundaries are returned as the range bounds.

        Args:
            bins (list[str]): Histogram bin boundaries as strings; must be
                non-empty (``read_histogram`` drops empty histograms).
            dtype (HistogramDataType): Type the boundaries are cast to.

        Returns:
            tuple: ``(min_value, max_value)`` cast to ``dtype``.
        """
        histogram_array: SuportedHistogramArrayType = self._cast_array(bins, dtype)
        boundary_count = len(histogram_array)
        # Clamp so that a retention probability of 0 (which would let
        # start_index point past the end) or above 1 (which would make the
        # randint range negative) cannot crash.
        subrange_length = min(
            boundary_count,
            max(
                1,
                math.ceil(
                    self.predicate_params.row_retention_probability * boundary_count
                ),
            ),
        )
        start_index = random.randint(0, boundary_count - subrange_length)
        min_value = histogram_array[start_index]
        max_value = histogram_array[
            min(start_index + subrange_length, boundary_count - 1)
        ]
        return min_value, max_value
|
Changes to src/query_generator/tools/histograms.py.
| ︙ | ︙ | |||
16 17 18 19 20 21 22 | get_equi_height_histogram, get_frequent_non_null_values, get_histogram_excluding_common_values, get_tables, ) from query_generator.utils.exceptions import InvalidHistogramTypeError | | > > > | 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 | get_equi_height_histogram, get_frequent_non_null_values, get_histogram_excluding_common_values, get_tables, ) from query_generator.utils.exceptions import InvalidHistogramTypeError class MostCommonValuesColumns(Enum): VALUE = "value" COUNT = "count" class RedundantHistogramsDataType(Enum): """ This class was made for compatibility with old code that generated this histogram: https://github.com/udao-moo/udao-spark-optimizer-dev/blob/main |
| ︙ | ︙ | |||
87 88 89 90 91 92 93 | def get_most_common_values( con: duckdb.DuckDBPyConnection, table: str, column: str, common_value_size: int, | < < < | < | 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
def get_most_common_values(
    con: duckdb.DuckDBPyConnection,
    table: str,
    column: str,
    common_value_size: int,
) -> list[RawDuckDBMostCommonValues]:
    """Look up the most frequent non-null values of ``table.column``.

    Thin wrapper: every argument is forwarded, in order, to
    ``get_frequent_non_null_values`` and its result is returned unchanged.

    Args:
        con: Open DuckDB connection.
        table: Name of the table to inspect.
        column: Name of the column to inspect.
        common_value_size: Passed through to the helper; presumably the
            number of most common values to fetch (confirm in the helper).

    Returns:
        Whatever ``get_frequent_non_null_values`` returns.
    """
    frequent_values = get_frequent_non_null_values(
        con, table, column, common_value_size
    )
    return frequent_values
def get_histogram_array(histogram_params: HistogramParams) -> list[str]:
histogram_raw = get_equi_height_histogram(
histogram_params.con,
histogram_params.table,
histogram_params.column.column_name,
|
| ︙ | ︙ | |||
114 115 116 117 118 119 120 | def get_histogram_array_excluding_common_values( histogram_params: HistogramParams, common_values_size: int, distinct_count: int, ) -> list[str]: histogram_array: list[RawDuckDBHistograms] = [] | < < | < | 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
def get_histogram_array_excluding_common_values(
histogram_params: HistogramParams,
common_values_size: int,
distinct_count: int,
) -> list[str]:
histogram_array: list[RawDuckDBHistograms] = []
if distinct_count > common_values_size:
histogram_array = get_histogram_excluding_common_values(
histogram_params.con,
histogram_params.table,
histogram_params.column.column_name,
histogram_params.histogram_size,
common_values_size,
)
|
| ︙ | ︙ | |||
180 181 182 183 184 185 186 |
if include_mvc:
# Get most common values
most_common_values = get_most_common_values(
con,
table,
column.column_name,
common_values_size,
| < > | > > | 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
if include_mvc:
# Get most common values
most_common_values = get_most_common_values(
con,
table,
column.column_name,
common_values_size,
)
# Get histogram array excluding common values
histogram_array_excluding_mcv = (
get_histogram_array_excluding_common_values(
histogram_params,
common_values_size,
distinct_count,
)
)
row_dict |= {
HistogramColumns.MOST_COMMON_VALUES.value: [
{
MostCommonValuesColumns.VALUE.value: value.value,
MostCommonValuesColumns.COUNT.value: value.count,
}
for value in most_common_values
],
HistogramColumns.HISTOGRAM_MCV.value: histogram_array_excluding_mcv,
}
rows.append(row_dict)
return pl.DataFrame(rows)
|
| ︙ | ︙ |
Changes to src/query_generator/utils/definitions.py.
| ︙ | ︙ | |||
15 16 17 18 19 20 21 22 23 24 25 | class Dataset(Enum): TPCDS = "TPCDS" TPCH = "TPCH" JOB = "JOB" @dataclass class QueryGenerationParameters: max_hops: int max_queries_per_signature: int max_queries_per_fact_table: int | > > > > > > > > > > > > > > > > > > > > > > > | < < < > | 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 | class Dataset(Enum): TPCDS = "TPCDS" TPCH = "TPCH" JOB = "JOB" @dataclass class PredicateOperatorProbability: """Probability of using a specific predicate operator. They are based on choice with weights for each operator. """ operator_in: float operator_equal: float operator_range: float @dataclass class PredicateParameters: extra_predicates: int row_retention_probability: float operator_weights: PredicateOperatorProbability equality_lower_bound_probability: float extra_values_for_in: int # TODO(Gabriel): http://localhost:8080/tktview/205e90a1fa @dataclass class QueryGenerationParameters: dataset: Dataset max_hops: int max_queries_per_signature: int max_queries_per_fact_table: int keep_edge_probability: float seen_subgraphs: dict[int, bool] predicate_parameters: PredicateParameters @dataclass class GeneratedQueryFeatures: query: str template_number: int predicate_number: int |
| ︙ | ︙ |
Changes to src/query_generator/utils/exceptions.py.
| ︙ | ︙ | |||
23 24 25 26 27 28 29 |
class DuplicateEdgesError(Exception):
    """Signals that duplicate edges were found for a table."""

    def __init__(self, table: str) -> None:
        message = f"Duplicate edges found for table {table}."
        super().__init__(message)
| | | 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
class DuplicateEdgesError(Exception):
    """Signals that duplicate edges were found for a table."""

    def __init__(self, table: str) -> None:
        message = f"Duplicate edges found for table {table}."
        super().__init__(message)
class UnkownDatasetError(Exception):
    """Signals that the requested dataset is not recognised."""

    def __init__(self, dataset: str) -> None:
        message = f"Unknown dataset: {dataset}"
        super().__init__(message)
class MissingScaleFactorError(Exception):
def __init__(self, dataset: str) -> None:
super().__init__(
|
| ︙ | ︙ |
Added src/query_generator/utils/params.py.
> > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import tomllib
from dataclasses import dataclass
from pathlib import Path
from typing import TypeVar
from cattrs import structure
from query_generator.utils.definitions import (
Dataset,
PredicateOperatorProbability,
PredicateParameters,
)
@dataclass
class SearchParametersEndpoint:
    """
    Represents the parameters used for configuring search queries, including
    query builder, subgraph, and predicate options.

    This class is designed to support both the `IN` and `=` statements in
    query generation. Several options are lists (``max_hops``,
    ``extra_predicates``, ``row_retention_probability``,
    ``equality_lower_bound_probability``); presumably each entry is one
    value tried by the parameter search — confirm against the caller.

    Attributes:
        dataset (Dataset): The dataset to be queried.
        dev (bool): Flag indicating whether to use development settings.
        max_queries_per_fact_table (int): Maximum number of queries per fact
            table.
        max_queries_per_signature (int): Maximum number of queries per
            signature.
        unique_joins (bool): Whether to enforce unique joins in the subgraph.
        max_hops (list[int]): Values for the maximum number of hops allowed
            in the subgraph.
        keep_edge_probability (float): Probability of retaining an edge in the
            subgraph.
        extra_predicates (list[int]): Values for the number of additional
            predicates to include in the query.
        row_retention_probability (list[float]): Values for the probability
            of retaining a row for range predicates.
        operator_weights (PredicateOperatorProbability): Weights used to
            choose between the `IN`, `=` and range predicate operators.
        equality_lower_bound_probability (list[float]): Values for the lower
            bound probability used with the `=` and `IN` operators.
        extra_values_for_in (int): Extra-values setting for `IN` predicates;
            presumably how many additional values an `IN` list receives
            (confirm in the predicate generator).
    """

    # Query Builder
    dataset: Dataset
    dev: bool
    max_queries_per_fact_table: int
    max_queries_per_signature: int
    # Subgraph
    unique_joins: bool
    max_hops: list[int]
    keep_edge_probability: float
    # Predicates
    extra_predicates: list[int]
    row_retention_probability: list[float]
    operator_weights: PredicateOperatorProbability
    equality_lower_bound_probability: list[float]
    extra_values_for_in: int
@dataclass
class SnowflakeEndpoint:
    """
    Represents the parameters used for configuring query generation,
    including query builder, subgraph, and predicate options.

    Every option holds a single value. The predicate options are grouped in
    a nested ``PredicateParameters`` object, which is read from the
    ``[predicate_parameters]`` table of the TOML configuration.

    Attributes:
        dataset (Dataset): The dataset to be used for query generation.
        max_queries_per_signature (int): Maximum number of queries to generate
            per signature.
        max_queries_per_fact_table (int): Maximum number of queries to generate
            per fact table.
        max_hops (int): Maximum number of hops allowed in the subgraph.
        keep_edge_probability (float): Probability of retaining an edge in the
            subgraph.
        predicate_parameters (PredicateParameters): Predicate options. Its
            fields are:

            - extra_predicates (int): Number of extra predicates to add to
              the query.
            - row_retention_probability (float): Probability of retaining a
              row after applying predicates.
            - operator_weights (PredicateOperatorProbability): Weights used
              to choose between the `IN`, `=` and range operators.
            - equality_lower_bound_probability (float): Probability of using
              a lower bound for equality predicates.
            - extra_values_for_in (int): Extra-values setting for `IN`
              predicates.
    """

    # Query builder
    dataset: Dataset
    max_queries_per_signature: int
    max_queries_per_fact_table: int
    # Subgraph
    max_hops: int
    keep_edge_probability: float
    # Predicates
    predicate_parameters: PredicateParameters
T = TypeVar("T")


def read_and_parse_toml(path: Path, cls: type[T]) -> T:
    """Load a TOML file and structure its contents into an instance of ``cls``.

    Args:
        path (Path): Location of the TOML file.
        cls (type[T]): Target type handed to ``cattrs.structure``, e.g. a
            dataclass such as ``SearchParametersEndpoint``.

    Returns:
        T: The structured instance.

    Raises:
        OSError: If the file cannot be opened.
        tomllib.TOMLDecodeError: If the file is not valid TOML.
    """
    # tomllib.load needs a binary file and decodes it as UTF-8, as the TOML
    # spec mandates. Reading with read_text() would use the platform's locale
    # encoding instead and can mis-decode non-ASCII content.
    with path.open("rb") as toml_file:
        toml_dict = tomllib.load(toml_file)
    return structure(toml_dict, cls)
|
Changes to tests/duckdb/test_binning.py.
1 2 3 4 5 6 7 8 9 10 | from unittest import mock import polars as pl import pytest from query_generator.duckdb_connection.binning import ( SearchParameters, run_snowflake_param_seach, ) from query_generator.tools.cherry_pick_binning import make_bins_in_csv | > > | | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import tomllib
from unittest import mock
from cattrs import structure
import polars as pl
import pytest
from query_generator.duckdb_connection.binning import (
SearchParameters,
run_snowflake_param_seach,
)
from query_generator.tools.cherry_pick_binning import make_bins_in_csv
from query_generator.utils.params import SearchParametersEndpoint
@pytest.mark.parametrize(
"count_star, upper_bound, total_bins, expected_bin",
[
(5, 10, 5, 3),
(0, 10, 5, 0),
|
| ︙ | ︙ | |||
40 41 42 43 44 45 46 |
" but got {computed_bin}"
)
@pytest.mark.parametrize(
"extra_predicates, expected_call_count, unique_joins",
[
| | | | | > > > > > > > > > > > > > > > > > > > < < < < < > | 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
" but got {computed_bin}"
)
@pytest.mark.parametrize(
    "extra_predicates, expected_call_count, unique_joins",
    [
        ("[1]", 120 * 1 + 14, "false"),
        ("[1]", 120 * 1 + 14, "true"),
        # The inventory table is small and produces only 14 queries in total.
        ("[1, 2]", 120 * 2 + 14, "true"),
        ("[1, 2]", 120 * 2 + 14 * 2, "false"),
    ],
)
def test_binning_calls(extra_predicates, expected_call_count, unique_joins):
    """Count batched query writes for a one-hop TPC-DS parameter search.

    DuckDB execution is mocked to return 0 and the batch writer is mocked,
    so only the number of ``write_query_to_batch`` calls is checked.
    ``extra_predicates`` and ``unique_joins`` are TOML literals spliced into
    the search configuration.
    """
    config_toml = f"""
dataset = "TPCDS"
dev = true
max_hops = [1]
extra_predicates = {extra_predicates}
row_retention_probability = [0.2]
unique_joins = {unique_joins}
max_queries_per_fact_table = 10
max_queries_per_signature = 2
keep_edge_probability = 0.2
equality_lower_bound_probability = [0]
extra_values_for_in = 3
[operator_weights]
operator_in = 1
operator_range = 3
operator_equal = 3
"""
    search_input = structure(tomllib.loads(config_toml), SearchParametersEndpoint)
    with (
        mock.patch(
            "query_generator.duckdb_connection.binning.Writer.write_query_to_batch",
        ) as mock_writer,
        mock.patch(
            "query_generator.duckdb_connection.binning.get_result_from_duckdb",
            return_value=0,
        ),
    ):
        run_snowflake_param_seach(
            search_params=SearchParameters(
                scale_factor=0,
                con=None,
                user_input=search_input,
            ),
        )
    assert mock_writer.call_count == expected_call_count, (
        f"Expected {expected_call_count} calls to write_query, "
        f"but got {mock_writer.call_count}"
    )
|
Changes to tests/duckdb/test_duckdb_utils.py.
1 2 3 4 5 6 7 8 | from query_generator.duckdb_connection.setup import setup_duckdb from query_generator.duckdb_connection.utils import ( get_distinct_count, get_equi_height_histogram, get_frequent_non_null_values, ) from query_generator.tools.histograms import DuckDBHistogramParser from query_generator.utils.definitions import Dataset | > > | < | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 | import datetime from query_generator.duckdb_connection.setup import setup_duckdb from query_generator.duckdb_connection.utils import ( get_distinct_count, get_equi_height_histogram, get_frequent_non_null_values, ) from query_generator.tools.histograms import DuckDBHistogramParser from query_generator.utils.definitions import Dataset from tests.utils import is_float def test_distinct_values(): """Test the setup of DuckDB.""" # Setup DuckDB con = setup_duckdb(Dataset.TPCDS, 0.1) assert get_distinct_count(con, "call_center", "cc_call_center_sk") == 1 |
| ︙ | ︙ |
Changes to tests/file_management/test_read_histograms.py.
1 2 3 4 | from unittest import mock import pytest | > < | | | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
from unittest import mock
import polars as pl
import pytest
from query_generator.predicate_generator.predicate_generator import (
HistogramDataType,
PredicateGenerator,
)
from query_generator.tools.histograms import HistogramColumns
from query_generator.utils.definitions import Dataset, PredicateParameters
from query_generator.utils.exceptions import InvalidHistogramTypeError
def test_read_histograms():
    """Every dataset's histogram loads, is non-empty and has string columns."""
    for dataset in Dataset:
        predicate_generator = PredicateGenerator(dataset, None)
        histogram = predicate_generator.read_histogram()
        assert not histogram.is_empty()
        assert histogram[HistogramColumns.DTYPE.value].dtype == pl.Utf8
        assert histogram[HistogramColumns.COLUMN.value].dtype == pl.Utf8
        # NOTE(review): only DTYPE, COLUMN and HISTOGRAM are checked; a check
        # on another string column (e.g. the table name) may be missing —
        # confirm against HistogramColumns.
        assert histogram[HistogramColumns.HISTOGRAM.value].dtype == pl.List(
            pl.Utf8
        )
|
| ︙ | ︙ | |||
74 75 76 77 78 79 80 |
max_index,
dtype,
):
with mock.patch(
"query_generator.predicate_generator.predicate_generator.random.randint",
return_value=mock_rand,
):
| | > > > > > > > > > | | | | 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
max_index,
dtype,
):
with mock.patch(
"query_generator.predicate_generator.predicate_generator.random.randint",
return_value=mock_rand,
):
predicate_generator = PredicateGenerator(
Dataset.TPCH,
PredicateParameters(
extra_predicates=None,
row_retention_probability=row_retention_probability,
operator_weights=None,
equality_lower_bound_probability=None,
extra_values_for_in=None,
),
)
min_value, max_value = predicate_generator._get_min_max_from_bins(
bins_array, dtype
)
assert min_value == bins_array[min_index]
assert max_value == bins_array[max_index]
def test_get_invalid_histogram_type():
    """An unrecognised SQL type name raises InvalidHistogramTypeError."""
    generator = PredicateGenerator(Dataset.TPCH, None)
    unsupported = "not_supported_type"
    with pytest.raises(InvalidHistogramTypeError):
        generator._get_histogram_type(unsupported)
@pytest.mark.parametrize(
    "input_type, expected_type",
    [
        ("INTEGER", HistogramDataType.INT),
        ("BIGINT", HistogramDataType.INT),
        ("DECIMAL(10,2)", HistogramDataType.FLOAT),
        ("DECIMAL(7,4)", HistogramDataType.FLOAT),
        ("DATE", HistogramDataType.DATE),
        ("VARCHAR", HistogramDataType.STRING),
    ],
)
def test_get_valid_histogram_type(input_type, expected_type):
    """Each supported SQL type name maps to the expected HistogramDataType."""
    generator = PredicateGenerator(Dataset.TPCH, None)
    detected_type = generator._get_histogram_type(input_type)
    assert detected_type == expected_type
|
Changes to tests/query_generation/test_make_queries.py.
1 2 3 | from unittest import mock import pytest | < | > > | > | > > | > > > > | > > | | > | > > | > > > > | > > | > | | > > | > > > > | > > | > | > > > > > > > > > > > | | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
from unittest import mock
import pytest
from query_generator.database_schemas.schemas import get_schema
from query_generator.join_based_query_generator.snowflake import (
QueryBuilder,
generate_and_write_queries,
)
from query_generator.predicate_generator.predicate_generator import (
HistogramDataType,
PredicateRange,
)
from query_generator.utils.definitions import (
Dataset,
PredicateOperatorProbability,
PredicateParameters,
QueryGenerationParameters,
)
from query_generator.utils.exceptions import UnkownDatasetError
from pypika import OracleQuery
from pypika import functions as fn
def test_tpch_query_generation():
    """Generating TPC-H queries with one hop writes more than five queries.

    This test previously passed ``Dataset.TPCDS`` (a copy of
    ``test_tpcds_query_generation``), so the TPC-H path was never exercised.
    """
    with mock.patch(
        "query_generator.join_based_query_generator.snowflake.Writer.write_query",
    ) as mock_writer:
        generate_and_write_queries(
            QueryGenerationParameters(
                dataset=Dataset.TPCH,
                max_hops=1,
                max_queries_per_fact_table=1,
                max_queries_per_signature=1,
                keep_edge_probability=0.2,
                seen_subgraphs={},
                predicate_parameters=PredicateParameters(
                    operator_weights=PredicateOperatorProbability(
                        operator_in=0.4,
                        operator_equal=0.4,
                        operator_range=0.2,
                    ),
                    extra_predicates=1,
                    row_retention_probability=0.2,
                    equality_lower_bound_probability=0,
                    extra_values_for_in=3,
                ),
            )
        )
        assert mock_writer.call_count > 5
def test_tpcds_query_generation():
    """Generating TPC-DS queries with one hop writes more than five queries."""
    predicate_params = PredicateParameters(
        operator_weights=PredicateOperatorProbability(
            operator_in=0.4,
            operator_equal=0.4,
            operator_range=0.2,
        ),
        extra_predicates=1,
        row_retention_probability=0.2,
        equality_lower_bound_probability=0,
        extra_values_for_in=3,
    )
    generation_params = QueryGenerationParameters(
        dataset=Dataset.TPCDS,
        max_hops=1,
        max_queries_per_fact_table=1,
        max_queries_per_signature=1,
        keep_edge_probability=0.2,
        seen_subgraphs={},
        predicate_parameters=predicate_params,
    )
    with mock.patch(
        "query_generator.join_based_query_generator.snowflake.Writer.write_query",
    ) as mock_writer:
        generate_and_write_queries(generation_params)
    assert mock_writer.call_count > 5
def test_non_implemented_dataset():
    """An unrecognised dataset raises UnkownDatasetError and writes nothing."""
    operator_weights = PredicateOperatorProbability(
        operator_in=0.4,
        operator_equal=0.4,
        operator_range=0.2,
    )
    predicate_params = PredicateParameters(
        operator_weights=operator_weights,
        extra_predicates=1,
        row_retention_probability=0.2,
        equality_lower_bound_probability=0,
        extra_values_for_in=3,
    )
    bogus_params = QueryGenerationParameters(
        dataset="non_implemented_dataset",
        max_hops=1,
        max_queries_per_fact_table=1,
        max_queries_per_signature=1,
        keep_edge_probability=0.2,
        seen_subgraphs={},
        predicate_parameters=predicate_params,
    )
    with (
        mock.patch(
            "query_generator.join_based_query_generator.snowflake.Writer.write_query",
        ) as mock_writer,
        pytest.raises(UnkownDatasetError),
    ):
        generate_and_write_queries(bogus_params)
    assert mock_writer.call_count == 0
def test_add_rage_supports_all_histogram_types():
    """``_add_range`` accepts a PredicateRange of every HistogramDataType."""
    tables_schema, _ = get_schema(Dataset.TPCH)
    predicate_params = PredicateParameters(
        extra_predicates=None,
        row_retention_probability=0.2,
        operator_weights=None,
        equality_lower_bound_probability=None,
        extra_values_for_in=None,
    )
    builder = QueryBuilder(None, tables_schema, Dataset.TPCH, predicate_params)
    for histogram_dtype in HistogramDataType:
        base_query = (
            OracleQuery()
            .from_(builder.table_to_pypika_table["lineitem"])
            .select(fn.Count("*"))
        )
        range_predicate = PredicateRange(
            table="lineitem",
            column="foo",
            min_value=2020,
            max_value=2020,
            dtype=histogram_dtype,
        )
        builder._add_range(base_query, range_predicate)
|