Query-generation

Changes On Branch new-predicates

Changes In Branch new-predicates Excluding Merge-Ins

This is equivalent to a diff from 0f0856db5a to 4162607284
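
For readers with a clone of the repository, the same range can typically be reproduced locally with Fossil's diff command, e.g. "fossil diff --from 0f0856db5a --to 4162607284" (assuming both artifacts are present in the local repository).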

2025-05-28
  09:27  Merges IN and = predicate. check-in: 0a0518ab14 user: mathos tags: trunk
  09:24  Adds equality lower bound as an array. Leaf check-in: 4162607284 user: mathos tags: new-predicates
  00:05  Finally. A stable version. check-in: 48e17f1cde user: mathos tags: new-predicates

2025-05-26
  11:22  Minor refactor to predicate class. Ticket [1e726428f6e719fb] check-in: c93b2b766c user: mathos tags: new-predicates
  11:12  Fix to be able to save the CSV. check-in: 0f0856db5a user: mathos tags: trunk
  09:57  Adds table_size to histogram. check-in: 288ba9b582 user: mathos tags: trunk

Changes to data/histograms/histogram_tpcds.parquet.

cannot compute difference between binary files

Added params_config/search_params/tpcds.toml.

dataset = "TPCDS"
dev = true
max_hops = [1,2,4]
extra_predicates = [1,3,5]
row_retention_probability = [0.2, 0.3, 0.4, 0.6, 0.8, 0.85, 0.9, 1.0]
unique_joins = true
max_queries_per_fact_table = 10
max_queries_per_signature = 2
keep_edge_probability = 0.2
equality_lower_bound_probability = [0,0.1]
extra_values_for_in = 3

[operator_weights]
operator_in = 1
operator_range = 3
operator_equal = 3
Added params_config/search_params/tpcds_dev.toml.

dataset = "TPCDS"
dev = true
max_hops = [1]
extra_predicates = [5]
row_retention_probability = [0.2, 0.9]
unique_joins = true
max_queries_per_fact_table = 1
max_queries_per_signature = 2
keep_edge_probability = 0.2
equality_lower_bound_probability = [0,0.1]
extra_values_for_in = 3

[operator_weights]
operator_in = 1
operator_range = 3
operator_equal = 3
Added params_config/snowflake/tpcds.toml.

dataset = "TPCDS"
max_hops = 3
max_queries_per_fact_table = 100
max_queries_per_signature = 1
keep_edge_probability = 0.2



[predicate_parameters]
row_retention_probability = 0.2
extra_predicates = 3
equality_lower_bound_probability = 0.00
extra_values_for_in = 3

[predicate_parameters.operator_weights]
operator_in = 1
operator_range = 3
operator_equal = 3
Changes to pyproject.toml.

Old (lines 26-39):

typer = ">=0.15.2,<0.16"
rich  = ">=14.0.0,<15"
pypika = ">=0.48.9,<0.49"
numpy = ">=2.2.5,<3"
duckdb = ">=1.2.2,<2"
polars = ">=1.27.1,<2"
tqdm = "*"



[tool.pixi.feature.test.dependencies]
pytest = ">=8.3.5,<9"

[tool.pixi.feature.lint.dependencies]
ruff = ">=0.11.7,<0.12"

New (lines 26-40):

typer = ">=0.15.2,<0.16"
rich  = ">=14.0.0,<15"
pypika = ">=0.48.9,<0.49"
numpy = ">=2.2.5,<3"
duckdb = ">=1.2.2,<2"
polars = ">=1.27.1,<2"
tqdm = "*"
cattrs = ">=24.1.2,<25"


[tool.pixi.feature.test.dependencies]
pytest = ">=8.3.5,<9"

[tool.pixi.feature.lint.dependencies]
ruff = ">=0.11.7,<0.12"
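
The only dependency change is the new cattrs pin, which pairs with the TOML configuration files added above. Below is a minimal sketch of how such a file could be structured into typed objects with tomllib plus cattrs; the dataclasses are hypothetical stand-ins, since the project's actual SnowflakeEndpoint/SearchParametersEndpoint types and its read_and_parse_toml helper are not part of this diff.

import tomllib  # stdlib TOML parser (Python 3.11+)
from dataclasses import dataclass
from pathlib import Path

import cattrs


@dataclass
class OperatorWeights:  # hypothetical mirror of the [operator_weights] table
  operator_in: int
  operator_range: int
  operator_equal: int


@dataclass
class SearchConfig:  # hypothetical subset of the fields in tpcds.toml
  dataset: str
  max_hops: list[int]
  extra_predicates: list[int]
  operator_weights: OperatorWeights


def load_config(path: Path) -> SearchConfig:
  with path.open("rb") as f:
    raw = tomllib.load(f)
  # cattrs structures the nested dict into dataclasses; extra keys are ignored
  return cattrs.structure(raw, SearchConfig)


cfg = load_config(Path("params_config/search_params/tpcds.toml"))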
Changes to src/query_generator/database_schemas/schemas.py.

Old (lines 1-29):

from typing import Any

from query_generator.database_schemas.tpcds import get_tpcds_table_info
from query_generator.database_schemas.tpch import get_tpch_table_info
from query_generator.utils.definitions import Dataset
from query_generator.utils.exceptions import (
  PartiallySupportedDatasetError,
  UnkwonDatasetError,
)


def get_schema(dataset: Dataset) -> tuple[dict[str, dict[str, Any]], list[str]]:
  """Get the schema of the database based on the dataset.

  Args:
      dataset (Dataset): The dataset to get the schema for.

  Returns:
      Tuple[Dict[str, Dict[str, Any]], List[str]]: A tuple containing the schema
      as a dictionary and a list of fact tables

  """
  if dataset == Dataset.TPCDS:
    return get_tpcds_table_info()
  if dataset == Dataset.TPCH:
    return get_tpch_table_info()
  if dataset == Dataset.JOB:
    raise PartiallySupportedDatasetError(dataset.value)
  raise UnkwonDatasetError(dataset)

New (lines 1-29):

from typing import Any

from query_generator.database_schemas.tpcds import get_tpcds_table_info
from query_generator.database_schemas.tpch import get_tpch_table_info
from query_generator.utils.definitions import Dataset
from query_generator.utils.exceptions import (
  PartiallySupportedDatasetError,
  UnkownDatasetError,
)


def get_schema(dataset: Dataset) -> tuple[dict[str, dict[str, Any]], list[str]]:
  """Get the schema of the database based on the dataset.

  Args:
      dataset (Dataset): The dataset to get the schema for.

  Returns:
      Tuple[Dict[str, Dict[str, Any]], List[str]]: A tuple containing the schema
      as a dictionary and a list of fact tables

  """
  if dataset == Dataset.TPCDS:
    return get_tpcds_table_info()
  if dataset == Dataset.TPCH:
    return get_tpch_table_info()
  if dataset == Dataset.JOB:
    raise PartiallySupportedDatasetError(dataset.value)
  raise UnkownDatasetError(dataset)
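
QueryGenerator.__init__ (shown later in this diff) consumes this function as a (schema, fact_tables) pair. A minimal usage sketch based on the signature above:

from query_generator.database_schemas.schemas import get_schema
from query_generator.utils.definitions import Dataset

tables_schema, fact_tables = get_schema(Dataset.TPCDS)
# tables_schema maps table name -> {"alias", "columns", "foreign_keys"} entries,
# as seen in the tpcds.py schema below; fact_tables lists the starting tables.
print(fact_tables)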
Changes to src/query_generator/database_schemas/tpcds.py.

Old (lines 436-450):

        "s_floor_space": {"max": 9917607, "min": 5010719},
        "s_gmt_offset": {"max": -5.0, "min": -6.0},
        "s_market_id": {"max": 10, "min": 1},
        "s_number_employees": {"max": 300, "min": 200},
        "s_rec_end_date": {"max": "2001-03-12", "min": "1999-03-13"},
        "s_rec_start_date": {"max": "2001-03-13", "min": "1997-03-13"},
        "s_store_sk": {"max": 402, "min": 1},
        "s_tax_precentage": {"max": 0.11, "min": 0.0},
      },
      "foreign_keys": [],
    },
    "store_returns": {
      "alias": "sr",
      "columns": {
        "sr_addr_sk": {"max": 1000000, "min": 1},

New (lines 436-449):

        "s_floor_space": {"max": 9917607, "min": 5010719},
        "s_gmt_offset": {"max": -5.0, "min": -6.0},
        "s_market_id": {"max": 10, "min": 1},
        "s_number_employees": {"max": 300, "min": 200},
        "s_rec_end_date": {"max": "2001-03-12", "min": "1999-03-13"},
        "s_rec_start_date": {"max": "2001-03-13", "min": "1997-03-13"},
        "s_store_sk": {"max": 402, "min": 1},

      },
      "foreign_keys": [],
    },
    "store_returns": {
      "alias": "sr",
      "columns": {
        "sr_addr_sk": {"max": 1000000, "min": 1},
Changes to src/query_generator/duckdb_connection/binning.py.

Old (lines 9-104):

  QueryGenerator,
)
from query_generator.join_based_query_generator.utils.query_writer import (
  Writer,
)
from query_generator.utils.definitions import (
  BatchGeneratedQueryFeatures,
  Dataset,
  Extension,

  QueryGenerationParameters,
)



@dataclass
class SearchParameters:
  dataset: Dataset
  scale_factor: int | float
  con: duckdb.DuckDBPyConnection
  max_hops: list[int]
  extra_predicates: list[int]
  row_retention_probability: list[float]
  unique_joins: bool


def get_result_from_duckdb(query: str, con: duckdb.DuckDBPyConnection) -> int:
  try:
    result = int(con.sql(query).fetchall()[0][0])
  except duckdb.BinderException as e:
    print(f"Invalid query, exception: {e},\n{query}")
    return -1
  return result


def get_total_iterations(search_params: SearchParameters) -> int:
  """Get the total number of iterations for the Snowflake binning process.

  Args:
    search_params (SearchParameters): The parameters for the Snowflake
    binning process.

  Returns:
    int: The total number of iterations.

  """
  return (
    len(search_params.max_hops)
    * len(search_params.extra_predicates)
    * len(search_params.row_retention_probability)

  )


def run_snowflake_param_seach(
  search_params: SearchParameters,
) -> None:
  """Run the Snowflake binning process. Binning is equiwidth binning.

  Args:
    parameters (BinningSnowflakeParameters): The parameters for
    the Snowflake binning process.

  """
  query_writer = Writer(
    search_params.dataset,
    Extension.SNOWFLAKE_SEARCH_PARAMS,
  )
  rows: list[dict[str, str | int | float]] = []
  total_iterations = get_total_iterations(search_params)
  batch_number = 0
  seen_subgraphs: dict[int, bool] = {}



  for max_hops, extra_predicates, row_retention_probability in tqdm(


    product(
      search_params.max_hops,
      search_params.extra_predicates,
      search_params.row_retention_probability,

    ),
    total=total_iterations,
    desc="Progress",
  ):
    batch_number += 1
    query_generator = QueryGenerator(
      QueryGenerationParameters(
        dataset=search_params.dataset,
        max_hops=max_hops,
        max_queries_per_fact_table=10,
        max_queries_per_signature=2,
        keep_edge_prob=0.2,


        extra_predicates=extra_predicates,
        row_retention_probability=float(row_retention_probability),
        seen_subgraphs=seen_subgraphs,




      )
    )
    for query in query_generator.generate_queries():
      selected_rows = get_result_from_duckdb(query.query, search_params.con)
      if selected_rows == -1:
        continue  # invalid query

New (lines 9-113):

  QueryGenerator,
)
from query_generator.join_based_query_generator.utils.query_writer import (
  Writer,
)
from query_generator.utils.definitions import (
  BatchGeneratedQueryFeatures,

  Extension,
  PredicateParameters,
  QueryGenerationParameters,
)
from query_generator.utils.params import SearchParametersEndpoint


@dataclass
class SearchParameters:
  user_input: SearchParametersEndpoint
  scale_factor: int | float
  con: duckdb.DuckDBPyConnection






def get_result_from_duckdb(query: str, con: duckdb.DuckDBPyConnection) -> int:
  try:
    result = int(con.sql(query).fetchall()[0][0])
  except duckdb.BinderException as e:
    print(f"Invalid query, exception: {e},\n{query}")
    return -1
  return result


def get_total_iterations(search_params: SearchParametersEndpoint) -> int:
  """Get the total number of iterations for the Snowflake binning process.

  Args:
    search_params (SearchParameters): The parameters for the Snowflake
    binning process.

  Returns:
    int: The total number of iterations.

  """
  return (
    len(search_params.max_hops)
    * len(search_params.extra_predicates)
    * len(search_params.row_retention_probability)
    * len(search_params.equality_lower_bound_probability)
  )


def run_snowflake_param_seach(
  search_params: SearchParameters,
) -> None:
  """Run the Snowflake binning process. Binning is equiwidth binning.

  Args:
    parameters (BinningSnowflakeParameters): The parameters for
    the Snowflake binning process.

  """
  query_writer = Writer(
    search_params.user_input.dataset,
    Extension.SNOWFLAKE_SEARCH_PARAMS,
  )
  rows: list[dict[str, str | int | float]] = []
  total_iterations = get_total_iterations(search_params.user_input)
  batch_number = 0
  seen_subgraphs: dict[int, bool] = {}
  for (
    max_hops,
    extra_predicates,
    row_retention_probability,
    equality_lower_bound_probability,
  ) in tqdm(
    product(
      search_params.user_input.max_hops,
      search_params.user_input.extra_predicates,
      search_params.user_input.row_retention_probability,
      search_params.user_input.equality_lower_bound_probability,
    ),
    total=total_iterations,
    desc="Progress",
  ):
    batch_number += 1
    query_generator = QueryGenerator(
      QueryGenerationParameters(
        dataset=search_params.user_input.dataset,
        max_hops=max_hops,
        max_queries_per_fact_table=search_params.user_input.max_queries_per_fact_table,
        max_queries_per_signature=search_params.user_input.max_queries_per_signature,
        keep_edge_probability=search_params.user_input.keep_edge_probability,
        seen_subgraphs=seen_subgraphs,
        predicate_parameters=PredicateParameters(
          extra_predicates=extra_predicates,
          row_retention_probability=row_retention_probability,

          operator_weights=search_params.user_input.operator_weights,
          equality_lower_bound_probability=equality_lower_bound_probability,
          extra_values_for_in=search_params.user_input.extra_values_for_in,
        ),
      )
    )
    for query in query_generator.generate_queries():
      selected_rows = get_result_from_duckdb(query.query, search_params.con)
      if selected_rows == -1:
        continue  # invalid query


Old (lines 124-134):

          "predicate_number": query.predicate_number,
          "fact_table": query.fact_table,
          "max_hops": max_hops,
          "row_retention_probability": row_retention_probability,
        },
      )
    # Update the seen subgraphs with the new ones
    if search_params.unique_joins:
      seen_subgraphs = query_generator.subgraph_generator.seen_subgraphs
  df_queries = pl.DataFrame(rows)
  query_writer.write_dataframe(df_queries)


New (lines 133-143):

          "predicate_number": query.predicate_number,
          "fact_table": query.fact_table,
          "max_hops": max_hops,
          "row_retention_probability": row_retention_probability,
        },
      )
    # Update the seen subgraphs with the new ones
    if search_params.user_input.unique_joins:
      seen_subgraphs = query_generator.subgraph_generator.seen_subgraphs
  df_queries = pl.DataFrame(rows)
  query_writer.write_dataframe(df_queries)
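
get_total_iterations above simply multiplies the lengths of the four configured lists, and the loop then walks their cross-product. A self-contained sketch of the same arithmetic, using the values from params_config/search_params/tpcds.toml added in this branch:

from itertools import product

max_hops = [1, 2, 4]
extra_predicates = [1, 3, 5]
row_retention_probability = [0.2, 0.3, 0.4, 0.6, 0.8, 0.85, 0.9, 1.0]
equality_lower_bound_probability = [0, 0.1]

combos = list(
  product(
    max_hops,
    extra_predicates,
    row_retention_probability,
    equality_lower_bound_probability,
  )
)
# Same count as get_total_iterations: 3 * 3 * 8 * 2 = 144 parameter batches
assert len(combos) == 144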
Changes to src/query_generator/duckdb_connection/setup.py.

Old (lines 1-16):

import os

import duckdb

from query_generator.utils.definitions import Dataset
from query_generator.utils.exceptions import (
  MissingScaleFactorError,
  PartiallySupportedDatasetError,
  UnkwonDatasetError,
)


def load_and_install_libraries() -> None:
  duckdb.install_extension("TPCDS")
  duckdb.install_extension("TPCH")
  duckdb.load_extension("TPCDS")

New (lines 1-16):

import os

import duckdb

from query_generator.utils.definitions import Dataset
from query_generator.utils.exceptions import (
  MissingScaleFactorError,
  PartiallySupportedDatasetError,
  UnkownDatasetError,
)


def load_and_install_libraries() -> None:
  duckdb.install_extension("TPCDS")
  duckdb.install_extension("TPCH")
  duckdb.load_extension("TPCDS")

Old (lines 25-50):

  if dataset == Dataset.TPCDS:
    con.execute(f"CALL dsdgen(sf = {scale_factor})")
  elif dataset == Dataset.TPCH:
    con.execute(f"CALL dbgen(sf = {scale_factor})")
  elif dataset == Dataset.JOB:
    raise PartiallySupportedDatasetError(dataset.value)
  else:
    raise UnkwonDatasetError(dataset)


def get_path(
  dataset: Dataset,
  scale_factor: float | int | None,
) -> str:
  if dataset in [Dataset.TPCDS, Dataset.TPCH]:
    return f"data/duckdb/{dataset.value}/{scale_factor}.db"
  if dataset == Dataset.JOB:
    return f"data/duckdb/{dataset.value}/job.db"
  raise UnkwonDatasetError(dataset.value)


def setup_duckdb(
  dataset: Dataset,
  scale_factor: int | float | None = None,
) -> duckdb.DuckDBPyConnection:
  """Installs TPCDS and TPCH datasets in DuckDB.

New (lines 25-50):

  if dataset == Dataset.TPCDS:
    con.execute(f"CALL dsdgen(sf = {scale_factor})")
  elif dataset == Dataset.TPCH:
    con.execute(f"CALL dbgen(sf = {scale_factor})")
  elif dataset == Dataset.JOB:
    raise PartiallySupportedDatasetError(dataset.value)
  else:
    raise UnkownDatasetError(dataset)


def get_path(
  dataset: Dataset,
  scale_factor: float | int | None,
) -> str:
  if dataset in [Dataset.TPCDS, Dataset.TPCH]:
    return f"data/duckdb/{dataset.value}/{scale_factor}.db"
  if dataset == Dataset.JOB:
    return f"data/duckdb/{dataset.value}/job.db"
  raise UnkownDatasetError(dataset.value)


def setup_duckdb(
  dataset: Dataset,
  scale_factor: int | float | None = None,
) -> duckdb.DuckDBPyConnection:
  """Installs TPCDS and TPCH datasets in DuckDB.
Changes to src/query_generator/join_based_query_generator/snowflake.py.

Old (lines 16-54):

# fmt: on
from query_generator.join_based_query_generator.utils.query_writer import (
  Writer,
)
from query_generator.predicate_generator.predicate_generator import (
  HistogramDataType,

  PredicateGenerator,



)
from query_generator.utils.definitions import (
  Dataset,
  Extension,
  GeneratedQueryFeatures,

  QueryGenerationParameters,
)
from query_generator.utils.exceptions import InvalidHistogramTypeError
from query_generator.utils.utils import set_seed


class QueryBuilder:
  def __init__(
    self,
    subgraph_generator: SubGraphGenerator,
    # TODO(Gabriel): http://localhost:8080/tktview/b9400c203a38f3aef46ec250d98563638ba7988b
    tables_schema: Any,
    dataset: Dataset,

  ) -> None:
    self.sub_graph_gen = subgraph_generator
    self.table_to_pypika_table = {
      i: Table(i, alias=tables_schema[i]["alias"]) for i in tables_schema
    }
    self.predicate_gen = PredicateGenerator(dataset)
    self.tables_schema = tables_schema

  def get_subgraph_tables(
    self,
    subgraph: list[ForeignKeyGraph.Edge],
  ) -> list[str]:
    return list(

New (lines 16-60):

# fmt: on
from query_generator.join_based_query_generator.utils.query_writer import (
  Writer,
)
from query_generator.predicate_generator.predicate_generator import (
  HistogramDataType,
  PredicateEquality,
  PredicateGenerator,
  PredicateIn,
  PredicateRange,
  SupportedHistogramType,
)
from query_generator.utils.definitions import (
  Dataset,
  Extension,
  GeneratedQueryFeatures,
  PredicateParameters,
  QueryGenerationParameters,
)
from query_generator.utils.exceptions import InvalidHistogramTypeError
from query_generator.utils.utils import set_seed


class QueryBuilder:
  def __init__(
    self,
    subgraph_generator: SubGraphGenerator,
    # TODO(Gabriel): http://localhost:8080/tktview/b9400c203a38f3aef46ec250d98563638ba7988b
    tables_schema: Any,
    dataset: Dataset,
    predicate_params: PredicateParameters,
  ) -> None:
    self.sub_graph_gen = subgraph_generator
    self.table_to_pypika_table = {
      i: Table(i, alias=tables_schema[i]["alias"]) for i in tables_schema
    }
    self.predicate_gen = PredicateGenerator(dataset, predicate_params)
    self.tables_schema = tables_schema

  def get_subgraph_tables(
    self,
    subgraph: list[ForeignKeyGraph.Edge],
  ) -> list[str]:
    return list(

Old (lines 82-185):

      )
    return query

  def add_predicates(
    self,
    subgraph: list[ForeignKeyGraph.Edge],
    query: OracleQuery,
    extra_predicates: int,
    row_retention_probability: float,
  ) -> OracleQuery:
    subgraph_tables = self.get_subgraph_tables(subgraph)
    for predicate in self.predicate_gen.get_random_predicates(
      subgraph_tables,
      extra_predicates,
      row_retention_probability,
    ):

      query = self._add_range(query, predicate)





    return query

  def _add_range(
    self, query: OracleQuery, predicate: PredicateGenerator.Predicate
  ) -> OracleQuery:
    if predicate.dtype in [HistogramDataType.INT, HistogramDataType.FLOAT]:
      return self._add_range_number(query, predicate)

    if predicate.dtype in [HistogramDataType.DATE]:
      return self._add_range_date(query, predicate)
    if predicate.dtype in [HistogramDataType.STRING]:
      return self._add_range_string(query, predicate)
    raise InvalidHistogramTypeError(str(predicate.dtype))

  def _add_range_number(
    self, query: OracleQuery, predicate: PredicateGenerator.Predicate
  ) -> OracleQuery:
    return query.where(
      self.table_to_pypika_table[predicate.table][predicate.column]
      >= predicate.min_value,
    ).where(
      self.table_to_pypika_table[predicate.table][predicate.column]



      <= predicate.max_value,
    )

  def _add_range_date(
    self, query: OracleQuery, predicate: PredicateGenerator.Predicate
  ) -> OracleQuery:
    return query.where(
      self.table_to_pypika_table[predicate.table][predicate.column]
      >= fn.Cast(predicate.min_value, "date"),
    ).where(
      self.table_to_pypika_table[predicate.table][predicate.column]
      <= fn.Cast(predicate.max_value, "date"),
    )

  def _add_range_string(
    self, query: OracleQuery, predicate: PredicateGenerator.Predicate
  ) -> OracleQuery:
    return query.where(
      self.table_to_pypika_table[predicate.table][predicate.column]
      >= predicate.min_value,
    ).where(
      self.table_to_pypika_table[predicate.table][predicate.column]
      <= predicate.max_value

    )


class QueryGenerator:
  def __init__(self, params: QueryGenerationParameters) -> None:
    set_seed()
    self.params = params
    self.tables_schema, self.fact_tables = get_schema(params.dataset)
    self.foreign_key_graph = ForeignKeyGraph(self.tables_schema)
    self.subgraph_generator = SubGraphGenerator(
      self.foreign_key_graph,
      params.keep_edge_prob,
      params.max_hops,
      params.seen_subgraphs,
    )
    self.query_builder = QueryBuilder(
      self.subgraph_generator,
      self.tables_schema,
      params.dataset,

    )

  def generate_queries(self) -> Iterator[GeneratedQueryFeatures]:
    for fact_table in self.fact_tables:
      for cnt, subgraph in enumerate(
        self.subgraph_generator.generate_subgraph(
          fact_table,
          self.params.max_queries_per_fact_table,
        ),
      ):
        query = self.query_builder.generate_query_from_subgraph(subgraph)
        for idx in range(1, self.params.max_queries_per_signature + 1):
          query = self.query_builder.add_predicates(
            subgraph,
            query,
            self.params.extra_predicates,
            self.params.row_retention_probability,
          )

          yield GeneratedQueryFeatures(
            query=query.get_sql(),
            template_number=cnt,
            predicate_number=idx,
            fact_table=fact_table,

New (lines 88-182):

      )
    return query

  def add_predicates(
    self,
    subgraph: list[ForeignKeyGraph.Edge],
    query: OracleQuery,


  ) -> OracleQuery:
    subgraph_tables = self.get_subgraph_tables(subgraph)
    for predicate in self.predicate_gen.get_random_predicates(
      subgraph_tables,


    ):
      if isinstance(predicate, PredicateRange):
        return self._add_range(query, predicate)
      if isinstance(predicate, PredicateEquality):
        return self._add_equality(query, predicate)
      if isinstance(predicate, PredicateIn):
        return self._add_in(query, predicate)
      raise InvalidHistogramTypeError(str(predicate.dtype))
    return query

  def _cast_if_needed(
    self, value: SupportedHistogramType, dtype: HistogramDataType
  ) -> Any:


    """Cast the value to the appropriate type if needed."""
    if dtype == HistogramDataType.DATE:
      return fn.Cast(value, "date")

    return value


  def _add_range(
    self, query: OracleQuery, predicate: PredicateRange
  ) -> OracleQuery:
    return query.where(



      self.table_to_pypika_table[predicate.table][predicate.column]
      >= self._cast_if_needed(predicate.min_value, predicate.dtype),
    ).where(
      self.table_to_pypika_table[predicate.table][predicate.column]
      <= self._cast_if_needed(predicate.max_value, predicate.dtype)
    )

  def _add_equality(
    self, query: OracleQuery, predicate: PredicateEquality
  ) -> OracleQuery:
    return query.where(
      self.table_to_pypika_table[predicate.table][predicate.column]
      == predicate.equality_value



    )


  def _add_in(self, query: OracleQuery, predicate: PredicateIn) -> OracleQuery:

    return query.where(
      self.table_to_pypika_table[predicate.table][predicate.column].isin(
        [self._cast_if_needed(i, predicate.dtype) for i in predicate.in_values]



      )
    )


class QueryGenerator:
  def __init__(self, params: QueryGenerationParameters) -> None:
    set_seed()
    self.params = params
    self.tables_schema, self.fact_tables = get_schema(params.dataset)
    self.foreign_key_graph = ForeignKeyGraph(self.tables_schema)
    self.subgraph_generator = SubGraphGenerator(
      self.foreign_key_graph,
      params.keep_edge_probability,
      params.max_hops,
      params.seen_subgraphs,
    )
    self.query_builder = QueryBuilder(
      self.subgraph_generator,
      self.tables_schema,
      params.dataset,
      params.predicate_parameters,
    )

  def generate_queries(self) -> Iterator[GeneratedQueryFeatures]:
    for fact_table in self.fact_tables:
      for cnt, subgraph in enumerate(
        self.subgraph_generator.generate_subgraph(
          fact_table,
          self.params.max_queries_per_fact_table,
        ),
      ):
        query = self.query_builder.generate_query_from_subgraph(subgraph)
        for idx in range(1, self.params.max_queries_per_signature + 1):
          query = self.query_builder.add_predicates(
            subgraph,
            query,


          )

          yield GeneratedQueryFeatures(
            query=query.get_sql(),
            template_number=cnt,
            predicate_number=idx,
            fact_table=fact_table,
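
The _add_range/_add_equality/_add_in helpers above build their WHERE clauses with pypika. A standalone sketch of the same constructs; the table, columns and values are illustrative, taken from the TPC-DS store table shown earlier in this diff:

from pypika import OracleQuery, Table
from pypika import functions as fn

store = Table("store", alias="s")
q = OracleQuery.from_(store).select("*")

# Range predicate on a DATE column, with the Cast applied by _cast_if_needed
q = q.where(store.s_rec_start_date >= fn.Cast("1997-03-13", "date")).where(
  store.s_rec_start_date <= fn.Cast("2001-03-13", "date")
)
# Equality and IN predicates, as produced by _add_equality and _add_in
q = q.where(store.s_market_id == 7)
q = q.where(store.s_number_employees.isin([200, 250, 300]))

print(q.get_sql())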
Changes to src/query_generator/join_based_query_generator/utils/subgraph_generator.py.

Old (lines 9-55):

MAX_ATTEMPTS_FOR_NEW_SUBGRAPH = 1000


class SubGraphGenerator:
  def __init__(
    self,
    graph: ForeignKeyGraph,
    keep_edge_prob: float,
    max_hops: int,
    seen_subgraphs: dict[int, bool],
  ) -> None:
    self.hops = max_hops
    self.keep_edge_prob = keep_edge_prob
    self.graph = graph
    self.seen_subgraphs: dict[int, bool] = seen_subgraphs.copy()

  def get_random_subgraph(self, fact_table: str) -> list[ForeignKeyGraph.Edge]:
    """Starting from the fact table, for each edge of the current table we
    decide based on the keep_edge_probability whether to keep the edge or not.


    We repeat this process up until the maximum number of hops.
    """

    @dataclass
    class JoinDepthNode:
      table: str
      depth: int

    queue: deque[JoinDepthNode] = deque()
    queue.append(JoinDepthNode(fact_table, 0))
    edges_subgraph = []

    while queue:
      current_node = queue.popleft()
      if current_node.depth >= self.hops:
        continue

      current_edges = self.graph.get_edges(current_node.table)
      for current_edge in current_edges:
        if random.random() < self.keep_edge_prob:
          edges_subgraph.append(current_edge)
          queue.append(
            JoinDepthNode(
              current_edge.reference_table.name,
              current_node.depth + 1,
            ),
          )

New (lines 9-56):

MAX_ATTEMPTS_FOR_NEW_SUBGRAPH = 1000


class SubGraphGenerator:
  def __init__(
    self,
    graph: ForeignKeyGraph,
    keep_edge_probability: float,
    max_hops: int,
    seen_subgraphs: dict[int, bool],
  ) -> None:
    self.hops = max_hops
    self.keep_edge_probability = keep_edge_probability
    self.graph = graph
    self.seen_subgraphs: dict[int, bool] = seen_subgraphs.copy()

  def get_random_subgraph(self, fact_table: str) -> list[ForeignKeyGraph.Edge]:
    """Starting from the fact table, for each edge of the current table we
    decide based on the keep_edge_probability whether to keep the
    edge or not.

    We repeat this process up until the maximum number of hops.
    """

    @dataclass
    class JoinDepthNode:
      table: str
      depth: int

    queue: deque[JoinDepthNode] = deque()
    queue.append(JoinDepthNode(fact_table, 0))
    edges_subgraph = []

    while queue:
      current_node = queue.popleft()
      if current_node.depth >= self.hops:
        continue

      current_edges = self.graph.get_edges(current_node.table)
      for current_edge in current_edges:
        if random.random() < self.keep_edge_probability:
          edges_subgraph.append(current_edge)
          queue.append(
            JoinDepthNode(
              current_edge.reference_table.name,
              current_node.depth + 1,
            ),
          )
Changes to src/query_generator/main.py.

Old (lines 26-199):

  make_redundant_histograms,
  query_histograms,
)
from query_generator.utils.definitions import (
  Dataset,
  Extension,
  QueryGenerationParameters,





)
from query_generator.utils.show_messages import show_dev_warning
from query_generator.utils.utils import validate_file_path

app = typer.Typer(name="Query Generation")


@app.command()
def snowflake(
  dataset: Annotated[
    Dataset,
    typer.Option("--dataset", "-d", help="The dataset used"),
  ],
  max_hops: Annotated[
    int,
    typer.Option(
      "--max-hops",
      "-h",
      help="The maximum number of hops",
      min=1,
      max=5,
    ),
  ] = 3,
  max_queries_per_fact_table: Annotated[
    int,
    typer.Option(
      "--fact",
      "-f",
      help="The maximum number of queries per fact table",
      min=1,
    ),
  ] = 100,
  max_queries_per_signature: Annotated[
    int,
    typer.Option(
      "--signature",
      "-s",
      help="The maximum number of queries per signature/template",
      min=1,
    ),
  ] = 1,
  keep_edge_prob: Annotated[
    float,
    typer.Option(
      "--edge-prob",
      "-p",
      help="The probability of keeping an edge in the subgraph",
      min=0.0,
      max=1.0,
    ),
  ] = 0.2,
  row_retention_probability: Annotated[
    float,
    typer.Option(
      "--row-retention",
      "-r",
      help="The probability of keeping a row in each predicate",
      min=0.0,
      max=1.0,
    ),
  ] = 0.2,
  extra_predicates: Annotated[
    int,
    typer.Option(
      "--extra-predicates",
      "-e",
      help="The number of extra predicates to add to the query",
      min=0,
    ),
  ] = 3,
) -> None:
  """Generate queries using a random subgraph."""

  params = QueryGenerationParameters(
    dataset=dataset,
    max_hops=max_hops,
    max_queries_per_fact_table=max_queries_per_fact_table,
    max_queries_per_signature=max_queries_per_signature,
    keep_edge_prob=keep_edge_prob,
    extra_predicates=extra_predicates,
    row_retention_probability=row_retention_probability,
    seen_subgraphs={},

  )
  generate_and_write_queries(params)


@app.command()
def param_search(
  dataset: Annotated[
    Dataset,
    typer.Option("--dataset", "-d", help="The dataset used"),
  ],
  *,
  dev: Annotated[
    bool,
    typer.Option(
      "--dev",
      help="Development testing. If true then uses scale factor 0.1 to check.",
    ),
  ] = False,
  unique_joins: Annotated[
    bool,
    typer.Option(
      "--unique-joins",
      "-u",
      help="If true all queries will have a unique join structure "
      "(not recommended for TPC-H)",
    ),
  ] = False,
  max_hops_range: Annotated[
    list[int] | None,
    typer.Option(
      "--max-hops-range",
      "-h",
      help="The range of hops to use for the query generation",
      show_default="1, 2, 4",
    ),
  ] = None,
  extra_predicates_range: Annotated[
    list[int] | None,
    typer.Option(
      "--extra-predicates-range",
      "-e",
      help="The range of extra predicates to use for the query generation",
      show_default="1, 2, 3, 5",
    ),
  ] = None,
  row_retention_probability_range: Annotated[
    list[float] | None,
    typer.Option(
      "--row-retention-probability-range",
      "-r",
      help="The range of row retention probabilities to use "
      "for the query generation",
      show_default="0.2, 0.3, 0.4, 0.6, 0.8, 0.85, 0.9, 1.0",

    ),
  ] = None,
) -> None:
  """This is an extension of the Snowflake algorithm.

  It runs multiple batches with different configurations of the algorithm.
  This allows us to get multiple results.
  """
  if max_hops_range is None:
    max_hops_range = [1, 2, 4]
  if extra_predicates_range is None:
    extra_predicates_range = [1, 2, 3, 5]
  if row_retention_probability_range is None:
    row_retention_probability_range = [0.2, 0.3, 0.4, 0.6, 0.8, 0.85, 0.9, 1.0]

  show_dev_warning(dev=dev)
  scale_factor = 0.1 if dev else 100
  con = setup_duckdb(dataset, scale_factor)
  run_snowflake_param_seach(
    SearchParameters(
      scale_factor=scale_factor,
      con=con,
      dataset=dataset,
      max_hops=max_hops_range,
      extra_predicates=extra_predicates_range,
      row_retention_probability=row_retention_probability_range,
      unique_joins=unique_joins,
    ),
  )


@app.command()
def cherry_pick(
  dataset: Annotated[


New (lines 26-106):

  make_redundant_histograms,
  query_histograms,
)
from query_generator.utils.definitions import (
  Dataset,
  Extension,
  QueryGenerationParameters,
)
from query_generator.utils.params import (
  SearchParametersEndpoint,
  SnowflakeEndpoint,
  read_and_parse_toml,
)
from query_generator.utils.show_messages import show_dev_warning
from query_generator.utils.utils import validate_file_path

app = typer.Typer(name="Query Generation")


@app.command()
def snowflake(
  config_path: Annotated[
    str,
    typer.Option(
      "-c",
      "--config",
      help="The path to the configuration file"
      "They can be found in the params_config/query_generation/ folder",
    ),
  ],
) -> None:
  """Generate queries using a random subgraph."""
  params_endpoint = read_and_parse_toml(Path(config_path), SnowflakeEndpoint)
  params = QueryGenerationParameters(
    dataset=params_endpoint.dataset,
    max_hops=params_endpoint.max_hops,
    max_queries_per_fact_table=params_endpoint.max_queries_per_fact_table,
    max_queries_per_signature=params_endpoint.max_queries_per_signature,
    keep_edge_probability=params_endpoint.keep_edge_probability,


    seen_subgraphs={},
    predicate_parameters=params_endpoint.predicate_parameters,
  )
  generate_and_write_queries(params)


@app.command()
def param_search(
  config_path: Annotated[
    str,
    typer.Option(
      "-c",
      "--config",
      help="The path to the configuration file"
      "They can be found in the params_config/search_params/ folder",
    ),
  ],
) -> None:
  """This is an extension of the Snowflake algorithm.

  It runs multiple batches with different configurations of the algorithm.
  This allows us to get multiple results.
  """
  params = read_and_parse_toml(
    Path(config_path),
    SearchParametersEndpoint,



  )
  show_dev_warning(dev=params.dev)
  scale_factor = 0.1 if params.dev else 100
  con = setup_duckdb(params.dataset, scale_factor)
  run_snowflake_param_seach(
    SearchParameters(
      scale_factor=scale_factor,
      con=con,
      user_input=params,




    ),
  )


@app.command()
def cherry_pick(
  dataset: Annotated[
Changes to src/query_generator/predicate_generator/predicate_generator.py.

Old (lines 1-160):

import math
import random

from collections.abc import Iterator
from dataclasses import dataclass
from enum import Enum


import polars as pl

from query_generator.tools.histograms import HistogramColumns



from query_generator.utils.definitions import Dataset




from query_generator.utils.exceptions import (
  InvalidHistogramTypeError,
  UnkwonDatasetError,
)

SupportedHistogramType = float | int | str
SuportedHistogramArrayType = list[float] | list[int] | list[str]













class HistogramDataType(Enum):
  INT = "int"
  FLOAT = "float"
  DATE = "date"
  STRING = "string"


class PredicateGenerator:
  @dataclass
  class Predicate:
    table: str
    column: str





    min_value: SupportedHistogramType
    max_value: SupportedHistogramType
    dtype: HistogramDataType












  def __init__(self, dataset: Dataset):
    self.dataset = dataset
    self.histogram: pl.DataFrame = self.read_histogram()


  def _parse_bin(
    self, hist_array: list[str], dtype: HistogramDataType
  ) -> SuportedHistogramArrayType:
    """Parse the bin string representation to a list of values.

    Args:
        bin_str (str): String representation of bins.
        dtype (str): Data type of the values.

    Returns:
        list: List of parsed values.

    """
    if dtype == HistogramDataType.INT:
      return [int(float(x)) for x in hist_array]
    if dtype == HistogramDataType.FLOAT:
      return [float(x) for x in hist_array]
    if dtype == HistogramDataType.DATE:













      return hist_array
    if dtype == HistogramDataType.STRING:
      return hist_array
    raise InvalidHistogramTypeError(dtype)

  def read_histogram(self) -> pl.DataFrame:
    """Read the histogram data for the specified dataset.

    Args:
        dataset: The dataset type (TPCH or TPCDS).

    Returns:
        pd.DataFrame: DataFrame containing the histogram data.

    """
    if self.dataset == Dataset.TPCH:
      path = "data/histograms/histogram_tpch.parquet"
    elif self.dataset == Dataset.TPCDS:
      path = "data/histograms/histogram_tpcds.parquet"
    elif self.dataset == Dataset.JOB:
      path = "data/histograms/histogram_job.parquet"
    else:
      raise UnkwonDatasetError(self.dataset.value)
    return pl.read_parquet(path).filter(pl.col("histogram") != [])

  def _get_histogram_type(self, dtype: str) -> HistogramDataType:
    if dtype in ["INTEGER", "BIGINT"]:
      return HistogramDataType.INT
    if dtype.startswith("DECIMAL"):
      return HistogramDataType.FLOAT
    if dtype == "DATE":
      return HistogramDataType.DATE
    if dtype == "VARCHAR":
      return HistogramDataType.STRING
    raise InvalidHistogramTypeError(dtype)


















  def get_random_predicates(
    self,
    tables: list[str],
    num_predicates: int,
    row_retention_probability: float,
  ) -> Iterator["PredicateGenerator.Predicate"]:
    """Generate random predicates based on the histogram data.

    Args:
        tables (str): List of tables to select predicates from.
        num_predicates (int): Number of predicates to generate.
        row_retention_probability (float): Probability of retaining rows.

    Returns:
        List[PredicateGenerator.Predicate]: List of generated predicates.

    """
    selected_tables_histogram = self.histogram.filter(
      pl.col(HistogramColumns.TABLE.value).is_in(tables)
    )

    for row in selected_tables_histogram.sample(n=num_predicates).iter_rows(

      named=True
    ):
      table = row[HistogramColumns.TABLE.value]
      column = row[HistogramColumns.COLUMN.value]







      bins = row[HistogramColumns.HISTOGRAM.value]




      dtype = self._get_histogram_type(row[HistogramColumns.DTYPE.value])







      min_value, max_value = self._get_min_max_from_bins(



















































        bins, row_retention_probability, dtype








      )








      predicate = PredicateGenerator.Predicate(
        table=table,
        column=column,
        min_value=min_value,
        max_value=max_value,
        dtype=dtype,
      )
      yield predicate

  def _get_min_max_from_bins(
    self,
    bins: list[str],
    row_retention_probability: float,
    dtype: HistogramDataType,
  ) -> tuple[SupportedHistogramType, SupportedHistogramType]:
    """Convert the bins string representation to a tuple of min and max values.

    Args:
        bins (str): String representation of bins.
        row_retention_probability (float): Probability of retaining rows.

    Returns:
        tuple: Tuple containing min and max values.

    """
    histogram_array: SuportedHistogramArrayType = self._parse_bin(bins, dtype)
    subrange_length = math.ceil(
      row_retention_probability * len(histogram_array)
    )
    start_index = random.randint(0, len(histogram_array) - subrange_length)

    min_value = histogram_array[start_index]
    max_value = histogram_array[
      min(start_index + subrange_length, len(histogram_array) - 1)
    ]
    return min_value, max_value

New (lines 1-307):

import math
import random
from abc import ABC
from collections.abc import Iterator
from dataclasses import dataclass
from enum import Enum

import numpy as np
import polars as pl

from query_generator.tools.histograms import (
  HistogramColumns,
  MostCommonValuesColumns,
)
from query_generator.utils.definitions import (
  Dataset,
  PredicateOperatorProbability,
  PredicateParameters,
)
from query_generator.utils.exceptions import (
  InvalidHistogramTypeError,
  UnkownDatasetError,
)

SupportedHistogramType = float | int | str
SuportedHistogramArrayType = list[float] | list[int] | list[str]


MAX_DISTINCT_COUNT_FOR_RANGE = 500
PROBABILITY_TO_CHOOSE_EQUALITY = 0.8
PREDICATE_IN_SIZE = 5


class PredicateTypes(Enum):
  IN = "in"
  RANGE = "range"
  EQUALITY = "equality"


class HistogramDataType(Enum):
  INT = "int"
  FLOAT = "float"
  DATE = "date"
  STRING = "string"



@dataclass
class Predicate(ABC):
  table: str
  column: str
  dtype: HistogramDataType


@dataclass
class PredicateRange(Predicate):
  min_value: SupportedHistogramType
  max_value: SupportedHistogramType


@dataclass
class PredicateEquality(Predicate):
  equality_value: SupportedHistogramType


@dataclass
class PredicateIn(Predicate):
  in_values: SuportedHistogramArrayType


class PredicateGenerator:
  def __init__(self, dataset: Dataset, predicate_params: PredicateParameters):
    self.dataset = dataset
    self.histogram: pl.DataFrame = self.read_histogram()
    self.predicate_params = predicate_params

  def _cast_array(
    self, str_array: list[str], dtype: HistogramDataType
  ) -> SuportedHistogramArrayType:
    """Parse the bin string representation to a list of values.

    Args:
        bin_str (str): String representation of bins.
        dtype (str): Data type of the values.

    Returns:
        list: List of parsed values.

    """
    if dtype == HistogramDataType.INT:
      return [int(float(x)) for x in str_array]
    if dtype == HistogramDataType.FLOAT:
      return [float(x) for x in str_array]
    if dtype == HistogramDataType.DATE:
      return str_array
    if dtype == HistogramDataType.STRING:
      return str_array
    raise InvalidHistogramTypeError(dtype)

  def _cast_element(
    self, value: str, dtype: HistogramDataType
  ) -> SupportedHistogramType:
    if dtype == HistogramDataType.INT:
      return int(float(value))
    if dtype == HistogramDataType.FLOAT:
      return float(value)
    if dtype == HistogramDataType.DATE:
      return value
    if dtype == HistogramDataType.STRING:
      return value
    raise InvalidHistogramTypeError(dtype)

  def read_histogram(self) -> pl.DataFrame:
    """Read the histogram data for the specified dataset.

    Args:
        dataset: The dataset type (TPCH or TPCDS).

    Returns:
        pd.DataFrame: DataFrame containing the histogram data.

    """
    if self.dataset == Dataset.TPCH:
      path = "data/histograms/histogram_tpch.parquet"
    elif self.dataset == Dataset.TPCDS:
      path = "data/histograms/histogram_tpcds.parquet"
    elif self.dataset == Dataset.JOB:
      path = "data/histograms/histogram_job.parquet"
    else:
      raise UnkownDatasetError(self.dataset.value)
    return pl.read_parquet(path).filter(pl.col("histogram") != [])

  def _get_histogram_type(self, dtype: str) -> HistogramDataType:
    if dtype in ["INTEGER", "BIGINT"]:
      return HistogramDataType.INT
    if dtype.startswith("DECIMAL"):
      return HistogramDataType.FLOAT
    if dtype == "DATE":
      return HistogramDataType.DATE
    if dtype == "VARCHAR":
      return HistogramDataType.STRING
    raise InvalidHistogramTypeError(dtype)

  def _choose_predicate_type(
    self, operator_weights: PredicateOperatorProbability
  ) -> PredicateTypes:
    weights = [
      operator_weights.operator_equal,
      operator_weights.operator_in,
      operator_weights.operator_range,
    ]
    return random.choices(
      [
        PredicateTypes.EQUALITY,
        PredicateTypes.IN,
        PredicateTypes.RANGE,
      ],
      weights=weights,
    )[0]

  def get_random_predicates(
    self,
    tables: list[str],


  ) -> Iterator[Predicate]:
    """Generate random predicates based on the histogram data.

    Args:
        tables (str): List of tables to select predicates from.
        num_predicates (int): Number of predicates to generate.
        row_retention_probability (float): Probability of retaining rows.

    Returns:
        List[Predicate]: List of generated predicates.

    """
    selected_tables_histogram = self.histogram.filter(
      pl.col(HistogramColumns.TABLE.value).is_in(tables)
    )

    for row in selected_tables_histogram.sample(
      n=self.predicate_params.extra_predicates
    ).iter_rows(named=True):

      table = row[HistogramColumns.TABLE.value]
      column = row[HistogramColumns.COLUMN.value]
      dtype = self._get_histogram_type(row[HistogramColumns.DTYPE.value])
      predicate_type = self._choose_predicate_type(
        self.predicate_params.operator_weights
      )

      if predicate_type == PredicateTypes.RANGE:
        yield self._get_range_predicate(
          table, column, row[HistogramColumns.HISTOGRAM.value], dtype
        )
      elif predicate_type == PredicateTypes.IN:
        array = self._get_in_array(
          row[HistogramColumns.MOST_COMMON_VALUES.value],
          row[HistogramColumns.TABLE_SIZE.value],
          row[HistogramColumns.HISTOGRAM_MCV.value],
        )
        if array is not None:
          yield self._get_in_predicate(array, table, column, dtype)
        else:
          continue
      elif predicate_type == PredicateTypes.EQUALITY:
        value = self._get_equality_value(
          row[HistogramColumns.MOST_COMMON_VALUES.value],
          row[HistogramColumns.TABLE_SIZE.value],
        )
        if value is not None:
          yield self._get_equality_predicate(value, table, column, dtype)
        else:
          continue

  def _get_in_predicate(
    self, array: list[str], table: str, column: str, dtype: HistogramDataType
  ) -> PredicateIn:
    cast_array = self._cast_array(array, dtype)
    return PredicateIn(table, column, dtype, cast_array)

  def _get_in_array(
    self,
    most_common_values: list[dict[str, int | str]],
    table_size: int,
    histogram: list[str],
  ) -> list[str] | None:
    """
    Gets the array for the IN operator
    """
    value = self._get_equality_value(most_common_values, table_size)
    if value is None:
      return None
    noise_values = random.sample(
      histogram,
      k=min(self.predicate_params.extra_values_for_in, len(histogram)),
    )
    return [value] + noise_values

  def _get_equality_predicate(
    self, value: str, table: str, column: str, dtype: HistogramDataType
  ) -> PredicateEquality:
    cast_value = self._cast_element(value, dtype)
    return PredicateEquality(
      table=table, column=column, dtype=dtype, equality_value=cast_value
    )

  def _get_equality_value(
    self,
    most_common_values: list[dict[str, int | str]],
    table_size: int,
  ) -> str | None:
    mcv_probabilities: list[float] = [
      float(table_size) / float(v[MostCommonValuesColumns.COUNT.value])
      for v in most_common_values
    ]
    mcv_probabilities_np = np.array(mcv_probabilities)
    filtered_indices = np.where(
      mcv_probabilities_np
      > self.predicate_params.equality_lower_bound_probability
    )[0]
    if len(filtered_indices) == 0:
      return None
    idx = random.choice(filtered_indices)
    value = most_common_values[idx][MostCommonValuesColumns.VALUE.value]
    assert isinstance(value, str)
    return value

  def _get_range_predicate(
    self,
    table: str,
    column: str,
    bins: list[str],
    dtype: HistogramDataType,
  ) -> PredicateRange:
    min_value, max_value = self._get_min_max_from_bins(bins, dtype)
    return PredicateRange(
      table=table,
      column=column,
      min_value=min_value,
      max_value=max_value,
      dtype=dtype,
    )


  def _get_min_max_from_bins(
    self,
    bins: list[str],

    dtype: HistogramDataType,
  ) -> tuple[SupportedHistogramType, SupportedHistogramType]:
    """Convert the bins string representation to a tuple of min and max values.

    Args:
        bins (str): String representation of bins.
        row_retention_probability (float): Probability of retaining rows.

    Returns:
        tuple: Tuple containing min and max values.

    """
    histogram_array: SuportedHistogramArrayType = self._cast_array(bins, dtype)
    subrange_length = math.ceil(
      self.predicate_params.row_retention_probability * len(histogram_array)
    )
    start_index = random.randint(0, len(histogram_array) - subrange_length)

    min_value = histogram_array[start_index]
    max_value = histogram_array[
      min(start_index + subrange_length, len(histogram_array) - 1)
    ]
    return min_value, max_value
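
_choose_predicate_type above turns the [operator_weights] table into a single weighted draw via random.choices. A small self-contained sketch of the resulting distribution, with the weights used in the configs added in this branch (operator_in = 1, operator_range = 3, operator_equal = 3):

import random
from collections import Counter

weights = {"equality": 3, "in": 1, "range": 3}


def choose_predicate_type() -> str:
  # One weighted draw, mirroring PredicateGenerator._choose_predicate_type
  return random.choices(list(weights), weights=list(weights.values()))[0]


print(Counter(choose_predicate_type() for _ in range(10_000)))
# Expect roughly 3/7 equality, 3/7 range, 1/7 in draws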
Changes to src/query_generator/tools/histograms.py.

Old (lines 16-30):

  get_equi_height_histogram,
  get_frequent_non_null_values,
  get_histogram_excluding_common_values,
  get_tables,
)
from query_generator.utils.exceptions import InvalidHistogramTypeError

LIMIT_FOR_DISTINCT_VALUES = 1000





class RedundantHistogramsDataType(Enum):
  """
  This class was made for compatibility with old code that
  generated this histogram:
  https://github.com/udao-moo/udao-spark-optimizer-dev/blob/main

New (lines 16-33):

  get_equi_height_histogram,
  get_frequent_non_null_values,
  get_histogram_excluding_common_values,
  get_tables,
)
from query_generator.utils.exceptions import InvalidHistogramTypeError


class MostCommonValuesColumns(Enum):
  VALUE = "value"
  COUNT = "count"


class RedundantHistogramsDataType(Enum):
  """
  This class was made for compatibility with old code that
  generated this histogram:
  https://github.com/udao-moo/udao-spark-optimizer-dev/blob/main

Old (lines 87-106):



def get_most_common_values(
  con: duckdb.DuckDBPyConnection,
  table: str,
  column: str,
  common_value_size: int,
  distinct_count: int,
) -> list[RawDuckDBMostCommonValues]:
  result: list[RawDuckDBMostCommonValues] = []
  if distinct_count < LIMIT_FOR_DISTINCT_VALUES:
    result = get_frequent_non_null_values(con, table, column, common_value_size)
  return result


def get_histogram_array(histogram_params: HistogramParams) -> list[str]:
  histogram_raw = get_equi_height_histogram(
    histogram_params.con,
    histogram_params.table,
    histogram_params.column.column_name,

New (lines 90-105):



def get_most_common_values(
  con: duckdb.DuckDBPyConnection,
  table: str,
  column: str,
  common_value_size: int,

) -> list[RawDuckDBMostCommonValues]:


  return get_frequent_non_null_values(con, table, column, common_value_size)



def get_histogram_array(histogram_params: HistogramParams) -> list[str]:
  histogram_raw = get_equi_height_histogram(
    histogram_params.con,
    histogram_params.table,
    histogram_params.column.column_name,

Old (lines 114-131):


def get_histogram_array_excluding_common_values(
  histogram_params: HistogramParams,
  common_values_size: int,
  distinct_count: int,
) -> list[str]:
  histogram_array: list[RawDuckDBHistograms] = []
  if (
    distinct_count < LIMIT_FOR_DISTINCT_VALUES
    and distinct_count > common_values_size
  ):
    histogram_array = get_histogram_excluding_common_values(
      histogram_params.con,
      histogram_params.table,
      histogram_params.column.column_name,
      histogram_params.histogram_size,
      common_values_size,
    )







New (lines 113-127):

def get_histogram_array_excluding_common_values(
  histogram_params: HistogramParams,
  common_values_size: int,
  distinct_count: int,
) -> list[str]:
  histogram_array: list[RawDuckDBHistograms] = []


  if distinct_count > common_values_size:

    histogram_array = get_histogram_excluding_common_values(
      histogram_params.con,
      histogram_params.table,
      histogram_params.column.column_name,
      histogram_params.histogram_size,
      common_values_size,
    )

Old (lines 180-208):
      if include_mvc:
        # Get most common values
        most_common_values = get_most_common_values(
          con,
          table,
          column.column_name,
          common_values_size,
          distinct_count,
        )

        # Get histogram array excluding common values
        histogram_array_excluding_mcv = (
          get_histogram_array_excluding_common_values(
            histogram_params,
            common_values_size,
            distinct_count,
          )
        )

        row_dict |= {
          HistogramColumns.MOST_COMMON_VALUES.value: [
            {"value": value.value, "count": value.count}
            for value in most_common_values
          ],
          HistogramColumns.HISTOGRAM_MCV.value: histogram_array_excluding_mcv,
        }

      rows.append(row_dict)
  return pl.DataFrame(rows)







New (lines 176-206):
      if include_mvc:
        # Get most common values
        most_common_values = get_most_common_values(
          con,
          table,
          column.column_name,
          common_values_size,

        )

        # Get histogram array excluding common values
        histogram_array_excluding_mcv = (
          get_histogram_array_excluding_common_values(
            histogram_params,
            common_values_size,
            distinct_count,
          )
        )

        row_dict |= {
          HistogramColumns.MOST_COMMON_VALUES.value: [
            {
              MostCommonValuesColumns.VALUE.value: value.value,
              MostCommonValuesColumns.COUNT.value: value.count,
            }
            for value in most_common_values
          ],
          HistogramColumns.HISTOGRAM_MCV.value: histogram_array_excluding_mcv,
        }

      rows.append(row_dict)
  return pl.DataFrame(rows)
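
For reference, each row assembled above ends up with a nested layout when include_mvc is set: the most-common-values column holds a list of {value, count} structs keyed by the new MostCommonValuesColumns enum, and the companion column holds the histogram computed without those values. A small sketch of that shape with made-up table, column names and numbers (only the two struct field names mirror the enum introduced here):

import polars as pl

# Hypothetical row, mimicking what the loop above appends to `rows`.
row = {
  "table": "store_sales",    # made-up table name
  "column": "ss_quantity",   # made-up column name
  "most_common_values": [
    {"value": "1", "count": 420},   # MostCommonValuesColumns.VALUE / .COUNT
    {"value": "2", "count": 310},
  ],
  "histogram_mcv": ["3", "10", "25", "60"],   # histogram excluding the common values
}

# Polars infers a list-of-struct column for most_common_values.
df = pl.DataFrame([row])
print(df.schema)
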
Changes to src/query_generator/utils/definitions.py.

Old (lines 15-37):
class Dataset(Enum):
  TPCDS = "TPCDS"
  TPCH = "TPCH"
  JOB = "JOB"


@dataclass
class QueryGenerationParameters:

  max_hops: int
  max_queries_per_signature: int
  max_queries_per_fact_table: int
  keep_edge_prob: float
  dataset: Dataset
  extra_predicates: int
  row_retention_probability: float
  seen_subgraphs: dict[int, bool]



@dataclass
class GeneratedQueryFeatures:
  query: str
  template_number: int
  predicate_number: int







New (lines 15-58):
class Dataset(Enum):
  TPCDS = "TPCDS"
  TPCH = "TPCH"
  JOB = "JOB"


@dataclass
class PredicateOperatorProbability:
  """Probability of using a specific predicate operator.

  They are based on choice with weights for each operator.
  """

  operator_in: float
  operator_equal: float
  operator_range: float


@dataclass
class PredicateParameters:
  extra_predicates: int
  row_retention_probability: float
  operator_weights: PredicateOperatorProbability
  equality_lower_bound_probability: float
  extra_values_for_in: int


# TODO(Gabriel): http://localhost:8080/tktview/205e90a1fa
@dataclass
class QueryGenerationParameters:
  dataset: Dataset
  max_hops: int
  max_queries_per_signature: int
  max_queries_per_fact_table: int
  keep_edge_probability: float
  seen_subgraphs: dict[int, bool]
  predicate_parameters: PredicateParameters


@dataclass
class GeneratedQueryFeatures:
  query: str
  template_number: int
  predicate_number: int
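
The operator weights above are plain choice weights, as the PredicateOperatorProbability docstring notes. A minimal sketch of how such a weighted draw can be made (the weight values are the same ones used in the test TOML further down; the random.choices call is illustrative only, not the generator's actual code path):

import random

from query_generator.utils.definitions import (
  PredicateOperatorProbability,
  PredicateParameters,
)

weights = PredicateOperatorProbability(
  operator_in=1,
  operator_range=3,
  operator_equal=3,
)
params = PredicateParameters(
  extra_predicates=3,
  row_retention_probability=0.2,
  operator_weights=weights,
  equality_lower_bound_probability=0.0,
  extra_values_for_in=3,
)

# Weighted choice between the three predicate kinds.
operator = random.choices(
  ["IN", "=", "RANGE"],
  weights=[weights.operator_in, weights.operator_equal, weights.operator_range],
  k=1,
)[0]
print(operator, params.extra_predicates)
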
Changes to src/query_generator/utils/exceptions.py.

Old (lines 23-37):


class DuplicateEdgesError(Exception):
  def __init__(self, table: str) -> None:
    super().__init__(f"Duplicate edges found for table {table}.")


class UnkwonDatasetError(Exception):
  def __init__(self, dataset: str) -> None:
    super().__init__(f"Unknown dataset: {dataset}")


class MissingScaleFactorError(Exception):
  def __init__(self, dataset: str) -> None:
    super().__init__(







New (lines 23-37):


class DuplicateEdgesError(Exception):
  def __init__(self, table: str) -> None:
    super().__init__(f"Duplicate edges found for table {table}.")


class UnkownDatasetError(Exception):
  def __init__(self, dataset: str) -> None:
    super().__init__(f"Unknown dataset: {dataset}")


class MissingScaleFactorError(Exception):
  def __init__(self, dataset: str) -> None:
    super().__init__(

Added src/query_generator/utils/params.py.

New file (lines 1-102):
import tomllib
from dataclasses import dataclass
from pathlib import Path
from typing import TypeVar

from cattrs import structure

from query_generator.utils.definitions import (
  Dataset,
  PredicateOperatorProbability,
  PredicateParameters,
)


@dataclass
class SearchParametersEndpoint:
  """
  Represents the parameters used for configuring search queries, including
  query builder, subgraph, and predicate options.

  This class is designed to support both the `IN` and `=` statements in
  query generation.

  Attributes:
    dataset (Dataset): The dataset to be queried.
    dev (bool): Flag indicating whether to use development settings.
    max_queries_per_fact_table (int): Maximum number of queries per fact
      table.
    max_queries_per_signature (int): Maximum number of queries per
      signature.
    unique_joins (bool): Whether to enforce unique joins in the subgraph.
    max_hops (list[int]): Maximum number of hops allowed in the subgraph.
    keep_edge_probability (float): Probability of retaining an edge in the
      subgraph.
    extra_predicates (list[int]): Number of additional predicates to include
      in the query.
    row_retention_probability (list[float]): Probability of retaining a row
      for range predicates
    operator_weights (PredicateOperatorProbability): Probability
      distribution for predicate operators.
    equality_lower_bound_probability (list[float]): Lower bound probabilities
      to try when using the `=` and `IN` operators.
    extra_values_for_in (int): Number of extra values to include in `IN`
      predicates.
  """

  # Query Builder
  dataset: Dataset
  dev: bool
  max_queries_per_fact_table: int
  max_queries_per_signature: int
  # Subgraph
  unique_joins: bool
  max_hops: list[int]
  keep_edge_probability: float
  # Predicates
  extra_predicates: list[int]
  row_retention_probability: list[float]
  operator_weights: PredicateOperatorProbability
  equality_lower_bound_probability: list[float]
  extra_values_for_in: int


@dataclass
class SnowflakeEndpoint:
  """
  Represents the parameters used for configuring query generation,
  including query builder, subgraph, and predicate options.

  Attributes:
    dataset (Dataset): The dataset to be used for query generation.
    max_queries_per_signature (int): Maximum number of queries to generate
      per signature.
    max_queries_per_fact_table (int): Maximum number of queries to generate
      per fact table.
    max_hops (int): Maximum number of hops allowed in the subgraph.
    keep_edge_probability (float): Probability of retaining an edge in the
      subgraph.
    predicate_parameters (PredicateParameters): Predicate settings (extra
      predicates, row retention probability, operator weights, equality
      lower bound probability and extra values for `IN`).
  """

  # Query builder
  dataset: Dataset
  max_queries_per_signature: int
  max_queries_per_fact_table: int
  # Subgraph
  max_hops: int
  keep_edge_probability: float
  # Predicates
  predicate_parameters: PredicateParameters


T = TypeVar("T")


def read_and_parse_toml(path: Path, cls: type[T]) -> T:
  toml_dict = tomllib.loads(path.read_text())
  return structure(toml_dict, cls)
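
As a usage sketch, the helper above reads a TOML document and lets cattrs structure it directly into one of the endpoint dataclasses, converting nested tables into nested dataclasses and enum strings such as "TPCDS" into Dataset members. The path below is hypothetical; any file matching the SnowflakeEndpoint layout would do:

from pathlib import Path

from query_generator.utils.params import SnowflakeEndpoint, read_and_parse_toml

# Hypothetical config path, assumed to follow the SnowflakeEndpoint layout.
params = read_and_parse_toml(Path("params_config/snowflake/my_run.toml"), SnowflakeEndpoint)

print(params.dataset)                                # e.g. Dataset.TPCDS
print(params.predicate_parameters.operator_weights)  # nested PredicateOperatorProbability
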
Changes to tests/duckdb/test_binning.py.

Old (lines 1-18):

from unittest import mock


import polars as pl
import pytest

from query_generator.duckdb_connection.binning import (
  SearchParameters,
  run_snowflake_param_seach,
)
from query_generator.tools.cherry_pick_binning import make_bins_in_csv
from query_generator.utils.definitions import Dataset


@pytest.mark.parametrize(
  "count_star, upper_bound, total_bins, expected_bin",
  [
    (5, 10, 5, 3),
    (0, 10, 5, 0),

New (lines 1-20):
import tomllib
from unittest import mock

from cattrs import structure
import polars as pl
import pytest

from query_generator.duckdb_connection.binning import (
  SearchParameters,
  run_snowflake_param_seach,
)
from query_generator.tools.cherry_pick_binning import make_bins_in_csv
from query_generator.utils.params import SearchParametersEndpoint


@pytest.mark.parametrize(
  "count_star, upper_bound, total_bins, expected_bin",
  [
    (5, 10, 5, 3),
    (0, 10, 5, 0),

Old (lines 40-76):
    " but got {computed_bin}"
  )


@pytest.mark.parametrize(
  "extra_predicates, expected_call_count, unique_joins",
  [
    ([1], 120 * 1 + 14, False),
    ([1], 120 * 1 + 14, True),
    # Inventory is small and prooduces 14 queries total
    ([1, 2], 120 * 2 + 14, True),
    ([1, 2], 120 * 2 + 14 * 2, False),
  ],
)
def test_binning_calls(extra_predicates, expected_call_count, unique_joins):
  with mock.patch(
    "query_generator.duckdb_connection.binning.Writer.write_query_to_batch",
  ) as mock_writer:
    with mock.patch(
      "query_generator.duckdb_connection.binning.get_result_from_duckdb",
    ) as mock_connect:
      mock_connect.return_value = 0
      run_snowflake_param_seach(
        search_params=SearchParameters(
          scale_factor=0,
          dataset=Dataset.TPCDS,
          max_hops=[1],
          extra_predicates=extra_predicates,
          row_retention_probability=[0.2],
          con=None,
          unique_joins=unique_joins,

        ),
      )
    assert mock_writer.call_count == expected_call_count, (
      f"Expected {expected_call_count} calls to write_query, "
      f"but got {mock_writer.call_count}"
    )







New (lines 42-93):
    " but got {computed_bin}"
  )


@pytest.mark.parametrize(
  "extra_predicates, expected_call_count, unique_joins",
  [
    ("[1]", 120 * 1 + 14, "false"),
    ("[1]", 120 * 1 + 14, "true"),
    # Inventory is small and produces 14 queries total
    ("[1, 2]", 120 * 2 + 14, "true"),
    ("[1, 2]", 120 * 2 + 14 * 2, "false"),
  ],
)
def test_binning_calls(extra_predicates, expected_call_count, unique_joins):
  with mock.patch(
    "query_generator.duckdb_connection.binning.Writer.write_query_to_batch",
  ) as mock_writer:
    with mock.patch(
      "query_generator.duckdb_connection.binning.get_result_from_duckdb",
    ) as mock_connect:
      mock_connect.return_value = 0
      data_toml = f"""
        dataset = "TPCDS"
        dev = true
        max_hops = [1]
        extra_predicates = {extra_predicates}
        row_retention_probability = [0.2]
        unique_joins = {unique_joins}
        max_queries_per_fact_table = 10
        max_queries_per_signature = 2
        keep_edge_probability = 0.2
        equality_lower_bound_probability = [0]
        extra_values_for_in = 3

        [operator_weights]
        operator_in = 1
        operator_range = 3
        operator_equal = 3
        """
      user_input = structure(tomllib.loads(data_toml), SearchParametersEndpoint)
      run_snowflake_param_seach(
        search_params=SearchParameters(
          scale_factor=0,
          con=None,
          user_input=user_input,
        ),
      )
    assert mock_writer.call_count == expected_call_count, (
      f"Expected {expected_call_count} calls to write_query, "
      f"but got {mock_writer.call_count}"
    )
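
Note that the parametrized values are now strings because they get interpolated into TOML source text rather than passed as Python objects; tomllib restores the native types when the document is parsed, and cattrs then structures the result. A tiny standalone illustration of that round trip:

import tomllib

# The f-string in the test splices values such as "false" and "[1, 2]" into TOML text.
doc = tomllib.loads("unique_joins = false\nextra_predicates = [1, 2]\n")

print(doc["unique_joins"])      # False (a Python bool)
print(doc["extra_predicates"])  # [1, 2] (a list of ints)
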
Changes to tests/duckdb/test_duckdb_utils.py.

Old (lines 1-17):


from query_generator.duckdb_connection.setup import setup_duckdb
from query_generator.duckdb_connection.utils import (
  get_distinct_count,
  get_equi_height_histogram,
  get_frequent_non_null_values,
)
from query_generator.tools.histograms import DuckDBHistogramParser
from query_generator.utils.definitions import Dataset
from tests.utils import is_date, is_float
import datetime


def test_distinct_values():
  """Test the setup of DuckDB."""
  # Setup DuckDB
  con = setup_duckdb(Dataset.TPCDS, 0.1)
  assert get_distinct_count(con, "call_center", "cc_call_center_sk") == 1

New (lines 1-18):
import datetime

from query_generator.duckdb_connection.setup import setup_duckdb
from query_generator.duckdb_connection.utils import (
  get_distinct_count,
  get_equi_height_histogram,
  get_frequent_non_null_values,
)
from query_generator.tools.histograms import DuckDBHistogramParser
from query_generator.utils.definitions import Dataset
from tests.utils import is_float



def test_distinct_values():
  """Test the setup of DuckDB."""
  # Setup DuckDB
  con = setup_duckdb(Dataset.TPCDS, 0.1)
  assert get_distinct_count(con, "call_center", "cc_call_center_sk") == 1
Changes to tests/file_management/test_read_histograms.py.

Old (lines 1-24):
from unittest import mock


import pytest

import polars as pl
from query_generator.predicate_generator.predicate_generator import (
  HistogramDataType,
  PredicateGenerator,
)
from query_generator.tools.histograms import HistogramColumns
from query_generator.utils.definitions import Dataset
from query_generator.utils.exceptions import InvalidHistogramTypeError


def test_read_histograms():
  for dataset in Dataset:
    predicate_generator = PredicateGenerator(dataset)
    histogram = predicate_generator.read_histogram()
    assert not histogram.is_empty()

    assert histogram[HistogramColumns.DTYPE.value].dtype == pl.Utf8
    assert histogram[HistogramColumns.COLUMN.value].dtype == pl.Utf8
    assert histogram[HistogramColumns.DTYPE.value].dtype == pl.Utf8
    assert histogram[HistogramColumns.HISTOGRAM.value].dtype == pl.List(pl.Utf8)


New (lines 1-24):
from unittest import mock

import polars as pl
import pytest


from query_generator.predicate_generator.predicate_generator import (
  HistogramDataType,
  PredicateGenerator,
)
from query_generator.tools.histograms import HistogramColumns
from query_generator.utils.definitions import Dataset, PredicateParameters
from query_generator.utils.exceptions import InvalidHistogramTypeError


def test_read_histograms():
  for dataset in Dataset:
    predicate_generator = PredicateGenerator(dataset, None)
    histogram = predicate_generator.read_histogram()
    assert not histogram.is_empty()

    assert histogram[HistogramColumns.DTYPE.value].dtype == pl.Utf8
    assert histogram[HistogramColumns.COLUMN.value].dtype == pl.Utf8
    assert histogram[HistogramColumns.DTYPE.value].dtype == pl.Utf8
    assert histogram[HistogramColumns.HISTOGRAM.value].dtype == pl.List(pl.Utf8)

Old (lines 74-108):
  max_index,
  dtype,
):
  with mock.patch(
    "query_generator.predicate_generator.predicate_generator.random.randint",
    return_value=mock_rand,
  ):
    predicate_generator = PredicateGenerator(Dataset.TPCH)
    min_value, max_value = predicate_generator._get_min_max_from_bins(
      bins_array, row_retention_probability, dtype
    )
  assert min_value == bins_array[min_index]
  assert max_value == bins_array[max_index]


def test_get_invalid_histogram_type():
  predicate_generator = PredicateGenerator(Dataset.TPCH)
  with pytest.raises(InvalidHistogramTypeError):
    predicate_generator._get_histogram_type("not_supported_type")


@pytest.mark.parametrize(
  "input_type, expected_type",
  [
    ("INTEGER", HistogramDataType.INT),
    ("BIGINT", HistogramDataType.INT),
    ("DECIMAL(10,2)", HistogramDataType.FLOAT),
    ("DECIMAL(7,4)", HistogramDataType.FLOAT),
    ("DATE", HistogramDataType.DATE),
    ("VARCHAR", HistogramDataType.STRING),
  ],
)
def test_get_valid_histogram_type(input_type, expected_type):
  predicate_generator = PredicateGenerator(Dataset.TPCH)
  assert predicate_generator._get_histogram_type(input_type) == expected_type







New (lines 74-117):
  max_index,
  dtype,
):
  with mock.patch(
    "query_generator.predicate_generator.predicate_generator.random.randint",
    return_value=mock_rand,
  ):
    predicate_generator = PredicateGenerator(
      Dataset.TPCH,
      PredicateParameters(
        extra_predicates=None,
        row_retention_probability=row_retention_probability,
        operator_weights=None,
        equality_lower_bound_probability=None,
        extra_values_for_in=None,
      ),
    )
    min_value, max_value = predicate_generator._get_min_max_from_bins(
      bins_array, dtype
    )
  assert min_value == bins_array[min_index]
  assert max_value == bins_array[max_index]


def test_get_invalid_histogram_type():
  predicate_generator = PredicateGenerator(Dataset.TPCH, None)
  with pytest.raises(InvalidHistogramTypeError):
    predicate_generator._get_histogram_type("not_supported_type")


@pytest.mark.parametrize(
  "input_type, expected_type",
  [
    ("INTEGER", HistogramDataType.INT),
    ("BIGINT", HistogramDataType.INT),
    ("DECIMAL(10,2)", HistogramDataType.FLOAT),
    ("DECIMAL(7,4)", HistogramDataType.FLOAT),
    ("DATE", HistogramDataType.DATE),
    ("VARCHAR", HistogramDataType.STRING),
  ],
)
def test_get_valid_histogram_type(input_type, expected_type):
  predicate_generator = PredicateGenerator(Dataset.TPCH, None)
  assert predicate_generator._get_histogram_type(input_type) == expected_type
Changes to tests/query_generation/test_make_queries.py.

Old (lines 1-99):
from unittest import mock

import pytest
from pypika import functions as fn


from query_generator.database_schemas.schemas import get_schema
from query_generator.join_based_query_generator.snowflake import (
  QueryBuilder,
  generate_and_write_queries,
)
from query_generator.predicate_generator.predicate_generator import (
  HistogramDataType,
  PredicateGenerator,
)
from query_generator.utils.definitions import (
  Dataset,
  QueryGenerationParameters,
)
from query_generator.utils.exceptions import UnkwonDatasetError
from pypika import OracleQuery



def test_tpch_query_generation():
  with mock.patch(
    "query_generator.join_based_query_generator.snowflake.Writer.write_query",
  ) as mock_writer:
    generate_and_write_queries(
      QueryGenerationParameters(
        dataset=Dataset.TPCDS,
        max_hops=1,
        max_queries_per_fact_table=1,
        max_queries_per_signature=1,
        keep_edge_prob=0.2,
        row_retention_probability=0.2,
        extra_predicates=1,
        seen_subgraphs={},
      ),

    )

    assert mock_writer.call_count > 5


def test_tpcds_query_generation():
  with mock.patch(
    "query_generator.join_based_query_generator.snowflake.Writer.write_query",
  ) as mock_writer:
    generate_and_write_queries(
      QueryGenerationParameters(
        dataset=Dataset.TPCDS,
        max_hops=1,
        max_queries_per_fact_table=1,
        max_queries_per_signature=1,
        keep_edge_prob=0.2,
        row_retention_probability=0.2,
        extra_predicates=1,
        seen_subgraphs={},

      ),
    )

    assert mock_writer.call_count > 5


def test_non_implemented_dataset():
  with mock.patch(
    "query_generator.join_based_query_generator.snowflake.Writer.write_query",
  ) as mock_writer:
    with pytest.raises(UnkwonDatasetError):
      generate_and_write_queries(
        QueryGenerationParameters(
          dataset="non_implemented_dataset",
          max_hops=1,
          max_queries_per_fact_table=1,
          max_queries_per_signature=1,
          keep_edge_prob=0.2,
          row_retention_probability=0.2,
          extra_predicates=1,
          seen_subgraphs={},

        ),
      )
    assert mock_writer.call_count == 0


def test_add_rage_supports_all_histogram_types():
  tables_schema, _ = get_schema(Dataset.TPCH)
  query_builder = QueryBuilder(None, tables_schema, Dataset.TPCH)
  for dtype in HistogramDataType:
    query_builder._add_range(
      OracleQuery()
      .from_(query_builder.table_to_pypika_table["lineitem"])
      .select(fn.Count("*")),
      PredicateGenerator.Predicate(
        table="lineitem",
        column="foo",
        min_value=2020,
        max_value=2020,
        dtype=dtype,
      ),
    )



New (lines 1-139):
from unittest import mock

import pytest



from query_generator.database_schemas.schemas import get_schema
from query_generator.join_based_query_generator.snowflake import (
  QueryBuilder,
  generate_and_write_queries,
)
from query_generator.predicate_generator.predicate_generator import (
  HistogramDataType,
  PredicateRange,
)
from query_generator.utils.definitions import (
  Dataset,
  PredicateOperatorProbability,
  PredicateParameters,
  QueryGenerationParameters,
)
from query_generator.utils.exceptions import UnkownDatasetError
from pypika import OracleQuery
from pypika import functions as fn


def test_tpch_query_generation():
  with mock.patch(
    "query_generator.join_based_query_generator.snowflake.Writer.write_query",
  ) as mock_writer:
    generate_and_write_queries(
      QueryGenerationParameters(
        dataset=Dataset.TPCDS,
        max_hops=1,
        max_queries_per_fact_table=1,
        max_queries_per_signature=1,
        keep_edge_probability=0.2,
        seen_subgraphs={},
        predicate_parameters=PredicateParameters(
          operator_weights=PredicateOperatorProbability(
            operator_in=0.4,
            operator_equal=0.4,
            operator_range=0.2,
          ),
          extra_predicates=1,
          row_retention_probability=0.2,
          equality_lower_bound_probability=0,
          extra_values_for_in=3,
        ),
      )
    )

    assert mock_writer.call_count > 5


def test_tpcds_query_generation():
  with mock.patch(
    "query_generator.join_based_query_generator.snowflake.Writer.write_query",
  ) as mock_writer:
    generate_and_write_queries(
      QueryGenerationParameters(
        dataset=Dataset.TPCDS,
        max_hops=1,
        max_queries_per_fact_table=1,
        max_queries_per_signature=1,
        keep_edge_probability=0.2,
        seen_subgraphs={},
        predicate_parameters=PredicateParameters(
          operator_weights=PredicateOperatorProbability(
            operator_in=0.4,
            operator_equal=0.4,
            operator_range=0.2,
          ),
          extra_predicates=1,
          row_retention_probability=0.2,
          equality_lower_bound_probability=0,
          extra_values_for_in=3,
        ),
      ),
    )

    assert mock_writer.call_count > 5


def test_non_implemented_dataset():
  with mock.patch(
    "query_generator.join_based_query_generator.snowflake.Writer.write_query",
  ) as mock_writer:
    with pytest.raises(UnkownDatasetError):
      generate_and_write_queries(
        QueryGenerationParameters(
          dataset="non_implemented_dataset",
          max_hops=1,
          max_queries_per_fact_table=1,
          max_queries_per_signature=1,
          keep_edge_probability=0.2,
          seen_subgraphs={},
          predicate_parameters=PredicateParameters(
            operator_weights=PredicateOperatorProbability(
              operator_in=0.4,
              operator_equal=0.4,
              operator_range=0.2,
            ),
            extra_predicates=1,
            row_retention_probability=0.2,
            equality_lower_bound_probability=0,
            extra_values_for_in=3,
          ),
        ),
      )
    assert mock_writer.call_count == 0


def test_add_rage_supports_all_histogram_types():
  tables_schema, _ = get_schema(Dataset.TPCH)
  query_builder = QueryBuilder(
    None,
    tables_schema,
    Dataset.TPCH,
    PredicateParameters(
      extra_predicates=None,
      row_retention_probability=0.2,
      operator_weights=None,
      equality_lower_bound_probability=None,
      extra_values_for_in=None,
    ),
  )
  for dtype in HistogramDataType:
    query_builder._add_range(
      OracleQuery()
      .from_(query_builder.table_to_pypika_table["lineitem"])
      .select(fn.Count("*")),
      PredicateRange(
        table="lineitem",
        column="foo",
        min_value=2020,
        max_value=2020,
        dtype=dtype,
      ),
    )