Query-generation

Changes On Branch new-predicates

Changes In Branch new-predicates Excluding Merge-Ins

This is equivalent to a diff from 0f0856db5a to 4162607284

2025-05-28
09:27 Merges IN and = predicate check-in: 0a0518ab14 user: mathos tags: trunk
09:24 Adds equality lower bound as an array. Leaf check-in: 4162607284 user: mathos tags: new-predicates
00:05 Finally. A stable version check-in: 48e17f1cde user: mathos tags: new-predicates
2025-05-26
11:22 Minor refactor to predicate class. Ticket [1e726428f6e719fb] check-in: c93b2b766c user: mathos tags: new-predicates
11:12 Fix to be able to save the CSV check-in: 0f0856db5a user: mathos tags: trunk
09:57 adds table_size to histogram check-in: 288ba9b582 user: mathos tags: trunk

Changes to data/histograms/histogram_tpcds.parquet.

cannot compute difference between binary files

Added params_config/search_params/tpcds.toml.
dataset = "TPCDS"
dev = true
max_hops = [1,2,4]
extra_predicates = [1,3,5]
row_retention_probability = [0.2, 0.3, 0.4, 0.6, 0.8, 0.85, 0.9, 1.0]
unique_joins = true
max_queries_per_fact_table = 10
max_queries_per_signature = 2
keep_edge_probability = 0.2
equality_lower_bound_probability = [0,0.1]
extra_values_for_in = 3

[operator_weights]
operator_in = 1
operator_range = 3
operator_equal = 3
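
The list-valued entries above define a parameter sweep: the search runs one batch per combination of max_hops, extra_predicates, row_retention_probability and equality_lower_bound_probability. As a quick sanity check, a minimal standalone sketch of the resulting batch count (values copied from this file; the variable names are only illustrative):

from itertools import product

max_hops = [1, 2, 4]
extra_predicates = [1, 3, 5]
row_retention_probability = [0.2, 0.3, 0.4, 0.6, 0.8, 0.85, 0.9, 1.0]
equality_lower_bound_probability = [0, 0.1]

# One batch per combination, mirroring get_total_iterations() in
# src/query_generator/duckdb_connection/binning.py further down this diff.
combinations = list(
  product(
    max_hops,
    extra_predicates,
    row_retention_probability,
    equality_lower_bound_probability,
  )
)
assert len(combinations) == 3 * 3 * 8 * 2  # 144 configurations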
Added params_config/search_params/tpcds_dev.toml.
dataset = "TPCDS"
dev = true
max_hops = [1]
extra_predicates = [5]
row_retention_probability = [0.2, 0.9]
unique_joins = true
max_queries_per_fact_table = 1
max_queries_per_signature = 2
keep_edge_probability = 0.2
equality_lower_bound_probability = [0,0.1]
extra_values_for_in = 3

[operator_weights]
operator_in = 1
operator_range = 3
operator_equal = 3
Added params_config/snowflake/tpcds.toml.
dataset = "TPCDS"
max_hops = 3
max_queries_per_fact_table = 100
max_queries_per_signature = 1
keep_edge_probability = 0.2



[predicate_parameters]
row_retention_probability = 0.2
extra_predicates = 3
equality_lower_bound_probability = 0.00
extra_values_for_in = 3

[predicate_parameters.operator_weights]
operator_in = 1
operator_range = 3
operator_equal = 3
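
For context, a hedged sketch of how a file like this is consumed: the reworked snowflake command (see main.py further down) passes the path to read_and_parse_toml, which structures the TOML into the SnowflakeEndpoint dataclass added in utils/params.py. The path below is just this file; running the sketch assumes the package is importable.

from pathlib import Path

from query_generator.utils.params import SnowflakeEndpoint, read_and_parse_toml

endpoint = read_and_parse_toml(
  Path("params_config/snowflake/tpcds.toml"),
  SnowflakeEndpoint,
)
print(endpoint.max_hops)                               # 3
print(endpoint.predicate_parameters.extra_predicates)  # 3
print(endpoint.predicate_parameters.operator_weights)  # weights for IN / range / =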
Changes to pyproject.toml.
typer = ">=0.15.2,<0.16"
rich  = ">=14.0.0,<15"
pypika = ">=0.48.9,<0.49"
numpy = ">=2.2.5,<3"
duckdb = ">=1.2.2,<2"
polars = ">=1.27.1,<2"
tqdm = "*"
cattrs = ">=24.1.2,<25"


[tool.pixi.feature.test.dependencies]
pytest = ">=8.3.5,<9"

[tool.pixi.feature.lint.dependencies]
ruff = ">=0.11.7,<0.12"
Changes to src/query_generator/database_schemas/schemas.py.
from typing import Any

from query_generator.database_schemas.tpcds import get_tpcds_table_info
from query_generator.database_schemas.tpch import get_tpch_table_info
from query_generator.utils.definitions import Dataset
from query_generator.utils.exceptions import (
  PartiallySupportedDatasetError,
  UnkwonDatasetError,
  UnkownDatasetError,
)


def get_schema(dataset: Dataset) -> tuple[dict[str, dict[str, Any]], list[str]]:
  """Get the schema of the database based on the dataset.

  Args:
      dataset (Dataset): The dataset to get the schema for.

  Returns:
      Tuple[Dict[str, Dict[str, Any]], List[str]]: A tuple containing the schema
      as a dictionary and a list of fact tables

  """
  if dataset == Dataset.TPCDS:
    return get_tpcds_table_info()
  if dataset == Dataset.TPCH:
    return get_tpch_table_info()
  if dataset == Dataset.JOB:
    raise PartiallySupportedDatasetError(dataset.value)
  raise UnkwonDatasetError(dataset)
  raise UnkownDatasetError(dataset)
Changes to src/query_generator/database_schemas/tpcds.py.
        "s_floor_space": {"max": 9917607, "min": 5010719},
        "s_gmt_offset": {"max": -5.0, "min": -6.0},
        "s_market_id": {"max": 10, "min": 1},
        "s_number_employees": {"max": 300, "min": 200},
        "s_rec_end_date": {"max": "2001-03-12", "min": "1999-03-13"},
        "s_rec_start_date": {"max": "2001-03-13", "min": "1997-03-13"},
        "s_store_sk": {"max": 402, "min": 1},
        "s_tax_precentage": {"max": 0.11, "min": 0.0},
      },
      "foreign_keys": [],
    },
    "store_returns": {
      "alias": "sr",
      "columns": {
        "sr_addr_sk": {"max": 1000000, "min": 1},
Changes to src/query_generator/duckdb_connection/binning.py.
  QueryGenerator,
)
from query_generator.join_based_query_generator.utils.query_writer import (
  Writer,
)
from query_generator.utils.definitions import (
  BatchGeneratedQueryFeatures,
  Dataset,
  Extension,
  PredicateParameters,
  QueryGenerationParameters,
)
from query_generator.utils.params import SearchParametersEndpoint


@dataclass
class SearchParameters:
  dataset: Dataset
  user_input: SearchParametersEndpoint
  scale_factor: int | float
  con: duckdb.DuckDBPyConnection
  max_hops: list[int]
  extra_predicates: list[int]
  row_retention_probability: list[float]
  unique_joins: bool


def get_result_from_duckdb(query: str, con: duckdb.DuckDBPyConnection) -> int:
  try:
    result = int(con.sql(query).fetchall()[0][0])
  except duckdb.BinderException as e:
    print(f"Invalid query, exception: {e},\n{query}")
    return -1
  return result


def get_total_iterations(search_params: SearchParameters) -> int:
def get_total_iterations(search_params: SearchParametersEndpoint) -> int:
  """Get the total number of iterations for the Snowflake binning process.

  Args:
    search_params (SearchParameters): The parameters for the Snowflake
    binning process.

  Returns:
    int: The total number of iterations.

  """
  return (
    len(search_params.max_hops)
    * len(search_params.extra_predicates)
    * len(search_params.row_retention_probability)
    * len(search_params.equality_lower_bound_probability)
  )


def run_snowflake_param_seach(
  search_params: SearchParameters,
) -> None:
  """Run the Snowflake binning process. Binning is equiwidth binning.

  Args:
    parameters (BinningSnowflakeParameters): The parameters for
    the Snowflake binning process.

  """
  query_writer = Writer(
    search_params.dataset,
    search_params.user_input.dataset,
    Extension.SNOWFLAKE_SEARCH_PARAMS,
  )
  rows: list[dict[str, str | int | float]] = []
  total_iterations = get_total_iterations(search_params)
  total_iterations = get_total_iterations(search_params.user_input)
  batch_number = 0
  seen_subgraphs: dict[int, bool] = {}
  for (
    max_hops,
    extra_predicates,
  for max_hops, extra_predicates, row_retention_probability in tqdm(
    row_retention_probability,
    equality_lower_bound_probability,
  ) in tqdm(
    product(
      search_params.max_hops,
      search_params.extra_predicates,
      search_params.row_retention_probability,
      search_params.user_input.max_hops,
      search_params.user_input.extra_predicates,
      search_params.user_input.row_retention_probability,
      search_params.user_input.equality_lower_bound_probability,
    ),
    total=total_iterations,
    desc="Progress",
  ):
    batch_number += 1
    query_generator = QueryGenerator(
      QueryGenerationParameters(
        dataset=search_params.dataset,
        dataset=search_params.user_input.dataset,
        max_hops=max_hops,
        max_queries_per_fact_table=10,
        max_queries_per_signature=2,
        keep_edge_prob=0.2,
        extra_predicates=extra_predicates,
        row_retention_probability=float(row_retention_probability),
        max_queries_per_fact_table=search_params.user_input.max_queries_per_fact_table,
        max_queries_per_signature=search_params.user_input.max_queries_per_signature,
        keep_edge_probability=search_params.user_input.keep_edge_probability,
        seen_subgraphs=seen_subgraphs,
        predicate_parameters=PredicateParameters(
          extra_predicates=extra_predicates,
          row_retention_probability=row_retention_probability,
        seen_subgraphs=seen_subgraphs,
          operator_weights=search_params.user_input.operator_weights,
          equality_lower_bound_probability=equality_lower_bound_probability,
          extra_values_for_in=search_params.user_input.extra_values_for_in,
        ),
      )
    )
    for query in query_generator.generate_queries():
      selected_rows = get_result_from_duckdb(query.query, search_params.con)
      if selected_rows == -1:
        continue  # invalid query

…
          "predicate_number": query.predicate_number,
          "fact_table": query.fact_table,
          "max_hops": max_hops,
          "row_retention_probability": row_retention_probability,
        },
      )
    # Update the seen subgraphs with the new ones
    if search_params.unique_joins:
    if search_params.user_input.unique_joins:
      seen_subgraphs = query_generator.subgraph_generator.seen_subgraphs
  df_queries = pl.DataFrame(rows)
  query_writer.write_dataframe(df_queries)
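
As an aside, the row-count probe that get_result_from_duckdb runs for every generated query boils down to reading the first cell of a COUNT(*) result and treating binder errors as invalid queries. A self-contained toy sketch of that pattern (the table and queries here are made up):

import duckdb

con = duckdb.connect()
con.execute("CREATE TABLE t AS SELECT range AS x FROM range(100)")

def count_rows(query: str, con: duckdb.DuckDBPyConnection) -> int:
  # First cell of the first row; a BinderException marks the query as invalid.
  try:
    return int(con.sql(query).fetchall()[0][0])
  except duckdb.BinderException as e:
    print(f"Invalid query, exception: {e}")
    return -1

print(count_rows("SELECT COUNT(*) FROM t WHERE x < 10", con))     # 10
print(count_rows("SELECT COUNT(*) FROM t WHERE nope < 10", con))  # -1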
Changes to src/query_generator/duckdb_connection/setup.py.
import os

import duckdb

from query_generator.utils.definitions import Dataset
from query_generator.utils.exceptions import (
  MissingScaleFactorError,
  PartiallySupportedDatasetError,
  UnkwonDatasetError,
  UnkownDatasetError,
)


def load_and_install_libraries() -> None:
  duckdb.install_extension("TPCDS")
  duckdb.install_extension("TPCH")
  duckdb.load_extension("TPCDS")
…
  if dataset == Dataset.TPCDS:
    con.execute(f"CALL dsdgen(sf = {scale_factor})")
  elif dataset == Dataset.TPCH:
    con.execute(f"CALL dbgen(sf = {scale_factor})")
  elif dataset == Dataset.JOB:
    raise PartiallySupportedDatasetError(dataset.value)
  else:
    raise UnkwonDatasetError(dataset)
    raise UnkownDatasetError(dataset)


def get_path(
  dataset: Dataset,
  scale_factor: float | int | None,
) -> str:
  if dataset in [Dataset.TPCDS, Dataset.TPCH]:
    return f"data/duckdb/{dataset.value}/{scale_factor}.db"
  if dataset == Dataset.JOB:
    return f"data/duckdb/{dataset.value}/job.db"
  raise UnkwonDatasetError(dataset.value)
  raise UnkownDatasetError(dataset.value)


def setup_duckdb(
  dataset: Dataset,
  scale_factor: int | float | None = None,
) -> duckdb.DuckDBPyConnection:
  """Installs TPCDS and TPCH datasets in DuckDB.
Changes to src/query_generator/join_based_query_generator/snowflake.py.
# fmt: on
from query_generator.join_based_query_generator.utils.query_writer import (
  Writer,
)
from query_generator.predicate_generator.predicate_generator import (
  HistogramDataType,
  PredicateEquality,
  PredicateGenerator,
  PredicateIn,
  PredicateRange,
  SupportedHistogramType,
)
from query_generator.utils.definitions import (
  Dataset,
  Extension,
  GeneratedQueryFeatures,
  PredicateParameters,
  QueryGenerationParameters,
)
from query_generator.utils.exceptions import InvalidHistogramTypeError
from query_generator.utils.utils import set_seed


class QueryBuilder:
  def __init__(
    self,
    subgraph_generator: SubGraphGenerator,
    # TODO(Gabriel): http://localhost:8080/tktview/b9400c203a38f3aef46ec250d98563638ba7988b
    tables_schema: Any,
    dataset: Dataset,
    predicate_params: PredicateParameters,
  ) -> None:
    self.sub_graph_gen = subgraph_generator
    self.table_to_pypika_table = {
      i: Table(i, alias=tables_schema[i]["alias"]) for i in tables_schema
    }
    self.predicate_gen = PredicateGenerator(dataset)
    self.predicate_gen = PredicateGenerator(dataset, predicate_params)
    self.tables_schema = tables_schema

  def get_subgraph_tables(
    self,
    subgraph: list[ForeignKeyGraph.Edge],
  ) -> list[str]:
    return list(
…
      )
    return query

  def add_predicates(
    self,
    subgraph: list[ForeignKeyGraph.Edge],
    query: OracleQuery,
    extra_predicates: int,
    row_retention_probability: float,
  ) -> OracleQuery:
    subgraph_tables = self.get_subgraph_tables(subgraph)
    for predicate in self.predicate_gen.get_random_predicates(
      subgraph_tables,
      extra_predicates,
      row_retention_probability,
    ):
      if isinstance(predicate, PredicateRange):
      query = self._add_range(query, predicate)
        return self._add_range(query, predicate)
      if isinstance(predicate, PredicateEquality):
        return self._add_equality(query, predicate)
      if isinstance(predicate, PredicateIn):
        return self._add_in(query, predicate)
      raise InvalidHistogramTypeError(str(predicate.dtype))
    return query

  def _add_range(
    self, query: OracleQuery, predicate: PredicateGenerator.Predicate
  ) -> OracleQuery:
  def _cast_if_needed(
    self, value: SupportedHistogramType, dtype: HistogramDataType
  ) -> Any:
    if predicate.dtype in [HistogramDataType.INT, HistogramDataType.FLOAT]:
      return self._add_range_number(query, predicate)
    if predicate.dtype in [HistogramDataType.DATE]:
      return self._add_range_date(query, predicate)
    """Cast the value to the appropriate type if needed."""
    if dtype == HistogramDataType.DATE:
      return fn.Cast(value, "date")
    if predicate.dtype in [HistogramDataType.STRING]:
      return self._add_range_string(query, predicate)
    return value
    raise InvalidHistogramTypeError(str(predicate.dtype))

  def _add_range_number(
    self, query: OracleQuery, predicate: PredicateGenerator.Predicate
  def _add_range(
    self, query: OracleQuery, predicate: PredicateRange
  ) -> OracleQuery:
    return query.where(
      self.table_to_pypika_table[predicate.table][predicate.column]
      >= predicate.min_value,
    ).where(
      self.table_to_pypika_table[predicate.table][predicate.column]
      >= self._cast_if_needed(predicate.min_value, predicate.dtype),
    ).where(
      self.table_to_pypika_table[predicate.table][predicate.column]
      <= predicate.max_value,
      <= self._cast_if_needed(predicate.max_value, predicate.dtype)
    )

  def _add_range_date(
    self, query: OracleQuery, predicate: PredicateGenerator.Predicate
  def _add_equality(
    self, query: OracleQuery, predicate: PredicateEquality
  ) -> OracleQuery:
    return query.where(
      self.table_to_pypika_table[predicate.table][predicate.column]
      >= fn.Cast(predicate.min_value, "date"),
      == predicate.equality_value
    ).where(
      self.table_to_pypika_table[predicate.table][predicate.column]
      <= fn.Cast(predicate.max_value, "date"),
    )

  def _add_range_string(
    self, query: OracleQuery, predicate: PredicateGenerator.Predicate
  def _add_in(self, query: OracleQuery, predicate: PredicateIn) -> OracleQuery:
  ) -> OracleQuery:
    return query.where(
      self.table_to_pypika_table[predicate.table][predicate.column]
      >= predicate.min_value,
      self.table_to_pypika_table[predicate.table][predicate.column].isin(
        [self._cast_if_needed(i, predicate.dtype) for i in predicate.in_values]
    ).where(
      self.table_to_pypika_table[predicate.table][predicate.column]
      <= predicate.max_value
      )
    )


class QueryGenerator:
  def __init__(self, params: QueryGenerationParameters) -> None:
    set_seed()
    self.params = params
    self.tables_schema, self.fact_tables = get_schema(params.dataset)
    self.foreign_key_graph = ForeignKeyGraph(self.tables_schema)
    self.subgraph_generator = SubGraphGenerator(
      self.foreign_key_graph,
      params.keep_edge_prob,
      params.keep_edge_probability,
      params.max_hops,
      params.seen_subgraphs,
    )
    self.query_builder = QueryBuilder(
      self.subgraph_generator,
      self.tables_schema,
      params.dataset,
      params.predicate_parameters,
    )

  def generate_queries(self) -> Iterator[GeneratedQueryFeatures]:
    for fact_table in self.fact_tables:
      for cnt, subgraph in enumerate(
        self.subgraph_generator.generate_subgraph(
          fact_table,
          self.params.max_queries_per_fact_table,
        ),
      ):
        query = self.query_builder.generate_query_from_subgraph(subgraph)
        for idx in range(1, self.params.max_queries_per_signature + 1):
          query = self.query_builder.add_predicates(
            subgraph,
            query,
            self.params.extra_predicates,
            self.params.row_retention_probability,
          )

          yield GeneratedQueryFeatures(
            query=query.get_sql(),
            template_number=cnt,
            predicate_number=idx,
            fact_table=fact_table,
Changes to src/query_generator/join_based_query_generator/utils/subgraph_generator.py.
MAX_ATTEMPTS_FOR_NEW_SUBGRAPH = 1000


class SubGraphGenerator:
  def __init__(
    self,
    graph: ForeignKeyGraph,
    keep_edge_prob: float,
    keep_edge_probability: float,
    max_hops: int,
    seen_subgraphs: dict[int, bool],
  ) -> None:
    self.hops = max_hops
    self.keep_edge_prob = keep_edge_prob
    self.keep_edge_probability = keep_edge_probability
    self.graph = graph
    self.seen_subgraphs: dict[int, bool] = seen_subgraphs.copy()

  def get_random_subgraph(self, fact_table: str) -> list[ForeignKeyGraph.Edge]:
    """Starting from the fact table, for each edge of the current table we
    decide based on the keep_edge_probability whether to keep the edge or not.
    decide based on the keep_edge_probabilityability whether to keep the
    edge or not.

    We repeat this process up until the maximum number of hops.
    """

    @dataclass
    class JoinDepthNode:
      table: str
      depth: int

    queue: deque[JoinDepthNode] = deque()
    queue.append(JoinDepthNode(fact_table, 0))
    edges_subgraph = []

    while queue:
      current_node = queue.popleft()
      if current_node.depth >= self.hops:
        continue

      current_edges = self.graph.get_edges(current_node.table)
      for current_edge in current_edges:
        if random.random() < self.keep_edge_prob:
        if random.random() < self.keep_edge_probability:
          edges_subgraph.append(current_edge)
          queue.append(
            JoinDepthNode(
              current_edge.reference_table.name,
              current_node.depth + 1,
            ),
          )
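
To make the keep_edge_probability semantics concrete, a toy, self-contained sketch of the same random walk over a hypothetical foreign-key adjacency map (the table names are invented; the real logic is get_random_subgraph above):

import random
from collections import deque

# Hypothetical FK adjacency: table -> tables it references.
fk_edges = {
  "store_sales": ["date_dim", "store", "item"],
  "store": ["date_dim"],
}

def random_subgraph(fact_table, keep_edge_probability=0.2, max_hops=2):
  queue, kept = deque([(fact_table, 0)]), []
  while queue:
    table, depth = queue.popleft()
    if depth >= max_hops:
      continue
    for referenced in fk_edges.get(table, []):
      # Each outgoing edge survives with probability keep_edge_probability.
      if random.random() < keep_edge_probability:
        kept.append((table, referenced))
        queue.append((referenced, depth + 1))
  return kept

print(random_subgraph("store_sales", keep_edge_probability=0.5))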
Changes to src/query_generator/main.py.
  make_redundant_histograms,
  query_histograms,
)
from query_generator.utils.definitions import (
  Dataset,
  Extension,
  QueryGenerationParameters,
)
from query_generator.utils.params import (
  SearchParametersEndpoint,
  SnowflakeEndpoint,
  read_and_parse_toml,
)
from query_generator.utils.show_messages import show_dev_warning
from query_generator.utils.utils import validate_file_path

app = typer.Typer(name="Query Generation")


@app.command()
def snowflake(
  dataset: Annotated[
  config_path: Annotated[
    Dataset,
    typer.Option("--dataset", "-d", help="The dataset used"),
  ],
  max_hops: Annotated[
    int,
    str,
    typer.Option(
      "--max-hops",
      "-h",
      "-c",
      help="The maximum number of hops",
      min=1,
      max=5,
    ),
  ] = 3,
  max_queries_per_fact_table: Annotated[
    int,
    typer.Option(
      "--fact",
      "--config",
      "-f",
      help="The maximum number of queries per fact table",
      help="The path to the configuration file"
      min=1,
    ),
  ] = 100,
  max_queries_per_signature: Annotated[
    int,
    typer.Option(
      "--signature",
      "-s",
      help="The maximum number of queries per signature/template",
      min=1,
    ),
  ] = 1,
  keep_edge_prob: Annotated[
    float,
    typer.Option(
      "--edge-prob",
      "-p",
      help="The probability of keeping an edge in the subgraph",
      "They can be found in the params_config/query_generation/ folder",
      min=0.0,
      max=1.0,
    ),
  ] = 0.2,
  row_retention_probability: Annotated[
    float,
    typer.Option(
      "--row-retention",
      "-r",
      help="The probability of keeping a row in each predicate",
      min=0.0,
      max=1.0,
    ),
  ] = 0.2,
  extra_predicates: Annotated[
    int,
    typer.Option(
      "--extra-predicates",
      "-e",
      help="The number of extra predicates to add to the query",
      min=0,
    ),
  ] = 3,
  ],
) -> None:
  """Generate queries using a random subgraph."""
  params_endpoint = read_and_parse_toml(Path(config_path), SnowflakeEndpoint)
  params = QueryGenerationParameters(
    dataset=dataset,
    max_hops=max_hops,
    max_queries_per_fact_table=max_queries_per_fact_table,
    max_queries_per_signature=max_queries_per_signature,
    keep_edge_prob=keep_edge_prob,
    dataset=params_endpoint.dataset,
    max_hops=params_endpoint.max_hops,
    max_queries_per_fact_table=params_endpoint.max_queries_per_fact_table,
    max_queries_per_signature=params_endpoint.max_queries_per_signature,
    keep_edge_probability=params_endpoint.keep_edge_probability,
    extra_predicates=extra_predicates,
    row_retention_probability=row_retention_probability,
    seen_subgraphs={},
    predicate_parameters=params_endpoint.predicate_parameters,
  )
  generate_and_write_queries(params)


@app.command()
def param_search(
  dataset: Annotated[
    Dataset,
    typer.Option("--dataset", "-d", help="The dataset used"),
  ],
  *,
  dev: Annotated[
  config_path: Annotated[
    bool,
    typer.Option(
      "--dev",
      help="Development testing. If true then uses scale factor 0.1 to check.",
    ),
  ] = False,
  unique_joins: Annotated[
    bool,
    typer.Option(
      "--unique-joins",
      "-u",
      help="If true all queries will have a unique join structure "
      "(not recommended for TPC-H)",
    ),
  ] = False,
  max_hops_range: Annotated[
    list[int] | None,
    str,
    typer.Option(
      "--max-hops-range",
      "-h",
      "-c",
      help="The range of hops to use for the query generation",
      show_default="1, 2, 4",
    ),
  ] = None,
  extra_predicates_range: Annotated[
    list[int] | None,
    typer.Option(
      "--extra-predicates-range",
      "-e",
      "--config",
      help="The range of extra predicates to use for the query generation",
      show_default="1, 2, 3, 5",
    ),
  ] = None,
  row_retention_probability_range: Annotated[
    list[float] | None,
    typer.Option(
      "--row-retention-probability-range",
      "-r",
      help="The range of row retention probabilities to use "
      help="The path to the configuration file"
      "for the query generation",
      show_default="0.2, 0.3, 0.4, 0.6, 0.8, 0.85, 0.9, 1.0",
      "They can be found in the params_config/search_params/ folder",
    ),
  ] = None,
  ],
) -> None:
  """This is an extension of the Snowflake algorithm.

  It runs multiple batches with different configurations of the algorithm.
  This allows us to get multiple results.
  """
  if max_hops_range is None:
    max_hops_range = [1, 2, 4]
  if extra_predicates_range is None:
  params = read_and_parse_toml(
    Path(config_path),
    SearchParametersEndpoint,
    extra_predicates_range = [1, 2, 3, 5]
  if row_retention_probability_range is None:
    row_retention_probability_range = [0.2, 0.3, 0.4, 0.6, 0.8, 0.85, 0.9, 1.0]
  show_dev_warning(dev=dev)
  scale_factor = 0.1 if dev else 100
  con = setup_duckdb(dataset, scale_factor)
  )
  show_dev_warning(dev=params.dev)
  scale_factor = 0.1 if params.dev else 100
  con = setup_duckdb(params.dataset, scale_factor)
  run_snowflake_param_seach(
    SearchParameters(
      scale_factor=scale_factor,
      con=con,
      dataset=dataset,
      user_input=params,
      max_hops=max_hops_range,
      extra_predicates=extra_predicates_range,
      row_retention_probability=row_retention_probability_range,
      unique_joins=unique_joins,
    ),
  )


@app.command()
def cherry_pick(
  dataset: Annotated[
Changes to src/query_generator/predicate_generator/predicate_generator.py.
import math
import random
from abc import ABC
from collections.abc import Iterator
from dataclasses import dataclass
from enum import Enum

import numpy as np
import polars as pl

from query_generator.tools.histograms import HistogramColumns
from query_generator.utils.definitions import Dataset
from query_generator.tools.histograms import (
  HistogramColumns,
  MostCommonValuesColumns,
)
from query_generator.utils.definitions import (
  Dataset,
  PredicateOperatorProbability,
  PredicateParameters,
)
from query_generator.utils.exceptions import (
  InvalidHistogramTypeError,
  UnkwonDatasetError,
  UnkownDatasetError,
)

SupportedHistogramType = float | int | str
SuportedHistogramArrayType = list[float] | list[int] | list[str]


MAX_DISTINCT_COUNT_FOR_RANGE = 500
PROBABILITY_TO_CHOOSE_EQUALITY = 0.8
PREDICATE_IN_SIZE = 5


class PredicateTypes(Enum):
  IN = "in"
  RANGE = "range"
  EQUALITY = "equality"


class HistogramDataType(Enum):
  INT = "int"
  FLOAT = "float"
  DATE = "date"
  STRING = "string"


class PredicateGenerator:
  @dataclass
  class Predicate:
    table: str
    column: str
    min_value: SupportedHistogramType
    max_value: SupportedHistogramType
    dtype: HistogramDataType
@dataclass
class Predicate(ABC):
  table: str
  column: str
  dtype: HistogramDataType


@dataclass
class PredicateRange(Predicate):
  min_value: SupportedHistogramType
  max_value: SupportedHistogramType


@dataclass
class PredicateEquality(Predicate):
  equality_value: SupportedHistogramType


@dataclass
class PredicateIn(Predicate):
  in_values: SuportedHistogramArrayType


class PredicateGenerator:
  def __init__(self, dataset: Dataset):
  def __init__(self, dataset: Dataset, predicate_params: PredicateParameters):
    self.dataset = dataset
    self.histogram: pl.DataFrame = self.read_histogram()
    self.predicate_params = predicate_params

  def _parse_bin(
    self, hist_array: list[str], dtype: HistogramDataType
  def _cast_array(
    self, str_array: list[str], dtype: HistogramDataType
  ) -> SuportedHistogramArrayType:
    """Parse the bin string representation to a list of values.

    Args:
        bin_str (str): String representation of bins.
        dtype (str): Data type of the values.

    Returns:
        list: List of parsed values.

    """
    if dtype == HistogramDataType.INT:
      return [int(float(x)) for x in hist_array]
      return [int(float(x)) for x in str_array]
    if dtype == HistogramDataType.FLOAT:
      return [float(x) for x in hist_array]
      return [float(x) for x in str_array]
    if dtype == HistogramDataType.DATE:
      return str_array
    if dtype == HistogramDataType.STRING:
      return str_array
    raise InvalidHistogramTypeError(dtype)

  def _cast_element(
    self, value: str, dtype: HistogramDataType
  ) -> SupportedHistogramType:
    if dtype == HistogramDataType.INT:
      return int(float(value))
    if dtype == HistogramDataType.FLOAT:
      return float(value)
    if dtype == HistogramDataType.DATE:
      return hist_array
      return value
    if dtype == HistogramDataType.STRING:
      return hist_array
      return value
    raise InvalidHistogramTypeError(dtype)

  def read_histogram(self) -> pl.DataFrame:
    """Read the histogram data for the specified dataset.

    Args:
        dataset: The dataset type (TPCH or TPCDS).

    Returns:
        pd.DataFrame: DataFrame containing the histogram data.

    """
    if self.dataset == Dataset.TPCH:
      path = "data/histograms/histogram_tpch.parquet"
    elif self.dataset == Dataset.TPCDS:
      path = "data/histograms/histogram_tpcds.parquet"
    elif self.dataset == Dataset.JOB:
      path = "data/histograms/histogram_job.parquet"
    else:
      raise UnkwonDatasetError(self.dataset.value)
      raise UnkownDatasetError(self.dataset.value)
    return pl.read_parquet(path).filter(pl.col("histogram") != [])

  def _get_histogram_type(self, dtype: str) -> HistogramDataType:
    if dtype in ["INTEGER", "BIGINT"]:
      return HistogramDataType.INT
    if dtype.startswith("DECIMAL"):
      return HistogramDataType.FLOAT
    if dtype == "DATE":
      return HistogramDataType.DATE
    if dtype == "VARCHAR":
      return HistogramDataType.STRING
    raise InvalidHistogramTypeError(dtype)

  def _choose_predicate_type(
    self, operator_weights: PredicateOperatorProbability
  ) -> PredicateTypes:
    weights = [
      operator_weights.operator_equal,
      operator_weights.operator_in,
      operator_weights.operator_range,
    ]
    return random.choices(
      [
        PredicateTypes.EQUALITY,
        PredicateTypes.IN,
        PredicateTypes.RANGE,
      ],
      weights=weights,
    )[0]

  def get_random_predicates(
    self,
    tables: list[str],
    num_predicates: int,
    row_retention_probability: float,
  ) -> Iterator["PredicateGenerator.Predicate"]:
  ) -> Iterator[Predicate]:
    """Generate random predicates based on the histogram data.

    Args:
        tables (str): List of tables to select predicates from.
        num_predicates (int): Number of predicates to generate.
        row_retention_probability (float): Probability of retaining rows.

    Returns:
        List[PredicateGenerator.Predicate]: List of generated predicates.
        List[Predicate]: List of generated predicates.

    """
    selected_tables_histogram = self.histogram.filter(
      pl.col(HistogramColumns.TABLE.value).is_in(tables)
    )

    for row in selected_tables_histogram.sample(n=num_predicates).iter_rows(
      named=True
    for row in selected_tables_histogram.sample(
      n=self.predicate_params.extra_predicates
    ).iter_rows(named=True):
    ):
      table = row[HistogramColumns.TABLE.value]
      column = row[HistogramColumns.COLUMN.value]
      dtype = self._get_histogram_type(row[HistogramColumns.DTYPE.value])
      predicate_type = self._choose_predicate_type(
        self.predicate_params.operator_weights
      )

      if predicate_type == PredicateTypes.RANGE:
        yield self._get_range_predicate(
      bins = row[HistogramColumns.HISTOGRAM.value]
      dtype = self._get_histogram_type(row[HistogramColumns.DTYPE.value])
      min_value, max_value = self._get_min_max_from_bins(
        bins, row_retention_probability, dtype
      )
      predicate = PredicateGenerator.Predicate(
        table=table,
        column=column,
        min_value=min_value,
        max_value=max_value,
        dtype=dtype,
      )
          table, column, row[HistogramColumns.HISTOGRAM.value], dtype
        )
      elif predicate_type == PredicateTypes.IN:
        array = self._get_in_array(
          row[HistogramColumns.MOST_COMMON_VALUES.value],
          row[HistogramColumns.TABLE_SIZE.value],
          row[HistogramColumns.HISTOGRAM_MCV.value],
        )
        if array is not None:
          yield self._get_in_predicate(array, table, column, dtype)
        else:
          continue
      elif predicate_type == PredicateTypes.EQUALITY:
        value = self._get_equality_value(
          row[HistogramColumns.MOST_COMMON_VALUES.value],
          row[HistogramColumns.TABLE_SIZE.value],
        )
        if value is not None:
          yield self._get_equality_predicate(value, table, column, dtype)
        else:
          continue

  def _get_in_predicate(
    self, array: list[str], table: str, column: str, dtype: HistogramDataType
  ) -> PredicateIn:
    cast_array = self._cast_array(array, dtype)
    return PredicateIn(table, column, dtype, cast_array)

  def _get_in_array(
    self,
    most_common_values: list[dict[str, int | str]],
    table_size: int,
    histogram: list[str],
  ) -> list[str] | None:
    """
    Gets the array for the IN operator
    """
    value = self._get_equality_value(most_common_values, table_size)
    if value is None:
      return None
    noise_values = random.sample(
      histogram,
      k=min(self.predicate_params.extra_values_for_in, len(histogram)),
    )
    return [value] + noise_values

  def _get_equality_predicate(
    self, value: str, table: str, column: str, dtype: HistogramDataType
  ) -> PredicateEquality:
    cast_value = self._cast_element(value, dtype)
    return PredicateEquality(
      table=table, column=column, dtype=dtype, equality_value=cast_value
    )

  def _get_equality_value(
    self,
    most_common_values: list[dict[str, int | str]],
    table_size: int,
  ) -> str | None:
    mcv_probabilities: list[float] = [
      float(table_size) / float(v[MostCommonValuesColumns.COUNT.value])
      for v in most_common_values
    ]
    mcv_probabilities_np = np.array(mcv_probabilities)
    filtered_indices = np.where(
      mcv_probabilities_np
      > self.predicate_params.equality_lower_bound_probability
    )[0]
    if len(filtered_indices) == 0:
      return None
    idx = random.choice(filtered_indices)
    value = most_common_values[idx][MostCommonValuesColumns.VALUE.value]
    assert isinstance(value, str)
    return value

  def _get_range_predicate(
    self,
    table: str,
    column: str,
    bins: list[str],
    dtype: HistogramDataType,
  ) -> PredicateRange:
    min_value, max_value = self._get_min_max_from_bins(bins, dtype)
    return PredicateRange(
      table=table,
      column=column,
      min_value=min_value,
      max_value=max_value,
      dtype=dtype,
    )
      yield predicate

  def _get_min_max_from_bins(
    self,
    bins: list[str],
    row_retention_probability: float,
    dtype: HistogramDataType,
  ) -> tuple[SupportedHistogramType, SupportedHistogramType]:
    """Convert the bins string representation to a tuple of min and max values.

    Args:
        bins (str): String representation of bins.
        row_retention_probability (float): Probability of retaining rows.

    Returns:
        tuple: Tuple containing min and max values.

    """
    histogram_array: SuportedHistogramArrayType = self._parse_bin(bins, dtype)
    histogram_array: SuportedHistogramArrayType = self._cast_array(bins, dtype)
    subrange_length = math.ceil(
      row_retention_probability * len(histogram_array)
      self.predicate_params.row_retention_probability * len(histogram_array)
    )
    start_index = random.randint(0, len(histogram_array) - subrange_length)

    min_value = histogram_array[start_index]
    max_value = histogram_array[
      min(start_index + subrange_length, len(histogram_array) - 1)
    ]
    return min_value, max_value
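
For intuition on the new operator_weights, a small standalone sketch of the weighted draw that _choose_predicate_type performs with random.choices (weights taken from the TOML configs above; the counts are only approximate by construction):

import random
from collections import Counter

operator_weights = {"operator_equal": 3, "operator_in": 1, "operator_range": 3}

draws = Counter(
  random.choices(
    list(operator_weights.keys()),
    weights=list(operator_weights.values()),
  )[0]
  for _ in range(7_000)
)
# Roughly 3000 equality, 1000 IN, and 3000 range predicates are expected.
print(draws)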
Changes to src/query_generator/tools/histograms.py.
  get_equi_height_histogram,
  get_frequent_non_null_values,
  get_histogram_excluding_common_values,
  get_tables,
)
from query_generator.utils.exceptions import InvalidHistogramTypeError

LIMIT_FOR_DISTINCT_VALUES = 1000

class MostCommonValuesColumns(Enum):
  VALUE = "value"
  COUNT = "count"


class RedundantHistogramsDataType(Enum):
  """
  This class was made for compatibility with old code that
  generated this histogram:
  https://github.com/udao-moo/udao-spark-optimizer-dev/blob/main
…
def get_most_common_values(
  con: duckdb.DuckDBPyConnection,
  table: str,
  column: str,
  common_value_size: int,
  distinct_count: int,
) -> list[RawDuckDBMostCommonValues]:
  result: list[RawDuckDBMostCommonValues] = []
  if distinct_count < LIMIT_FOR_DISTINCT_VALUES:
    result = get_frequent_non_null_values(con, table, column, common_value_size)
  return get_frequent_non_null_values(con, table, column, common_value_size)
  return result


def get_histogram_array(histogram_params: HistogramParams) -> list[str]:
  histogram_raw = get_equi_height_histogram(
    histogram_params.con,
    histogram_params.table,
    histogram_params.column.column_name,
…
def get_histogram_array_excluding_common_values(
  histogram_params: HistogramParams,
  common_values_size: int,
  distinct_count: int,
) -> list[str]:
  histogram_array: list[RawDuckDBHistograms] = []
  if (
    distinct_count < LIMIT_FOR_DISTINCT_VALUES
    and distinct_count > common_values_size
  if distinct_count > common_values_size:
  ):
    histogram_array = get_histogram_excluding_common_values(
      histogram_params.con,
      histogram_params.table,
      histogram_params.column.column_name,
      histogram_params.histogram_size,
      common_values_size,
    )
…
      if include_mvc:
        # Get most common values
        most_common_values = get_most_common_values(
          con,
          table,
          column.column_name,
          common_values_size,
          distinct_count,
        )

        # Get histogram array excluding common values
        histogram_array_excluding_mcv = (
          get_histogram_array_excluding_common_values(
            histogram_params,
            common_values_size,
            distinct_count,
          )
        )

        row_dict |= {
          HistogramColumns.MOST_COMMON_VALUES.value: [
            {
            {"value": value.value, "count": value.count}
              MostCommonValuesColumns.VALUE.value: value.value,
              MostCommonValuesColumns.COUNT.value: value.count,
            }
            for value in most_common_values
          ],
          HistogramColumns.HISTOGRAM_MCV.value: histogram_array_excluding_mcv,
        }

      rows.append(row_dict)
  return pl.DataFrame(rows)
Changes to src/query_generator/utils/definitions.py.
class Dataset(Enum):
  TPCDS = "TPCDS"
  TPCH = "TPCH"
  JOB = "JOB"


@dataclass
class PredicateOperatorProbability:
  """Probability of using a specific predicate operator.

  They are based on choice with weights for each operator.
  """

  operator_in: float
  operator_equal: float
  operator_range: float


@dataclass
class PredicateParameters:
  extra_predicates: int
  row_retention_probability: float
  operator_weights: PredicateOperatorProbability
  equality_lower_bound_probability: float
  extra_values_for_in: int


# TODO(Gabriel): http://localhost:8080/tktview/205e90a1fa
@dataclass
class QueryGenerationParameters:
  dataset: Dataset
  max_hops: int
  max_queries_per_signature: int
  max_queries_per_fact_table: int
  keep_edge_prob: float
  keep_edge_probability: float
  dataset: Dataset
  extra_predicates: int
  row_retention_probability: float
  seen_subgraphs: dict[int, bool]
  predicate_parameters: PredicateParameters


@dataclass
class GeneratedQueryFeatures:
  query: str
  template_number: int
  predicate_number: int
Changes to src/query_generator/utils/exceptions.py.
class DuplicateEdgesError(Exception):
  def __init__(self, table: str) -> None:
    super().__init__(f"Duplicate edges found for table {table}.")


class UnkwonDatasetError(Exception):
class UnkownDatasetError(Exception):
  def __init__(self, dataset: str) -> None:
    super().__init__(f"Unknown dataset: {dataset}")


class MissingScaleFactorError(Exception):
  def __init__(self, dataset: str) -> None:
    super().__init__(
Added src/query_generator/utils/params.py.
import tomllib
from dataclasses import dataclass
from pathlib import Path
from typing import TypeVar

from cattrs import structure

from query_generator.utils.definitions import (
  Dataset,
  PredicateOperatorProbability,
  PredicateParameters,
)


@dataclass
class SearchParametersEndpoint:
  """
  Represents the parameters used for configuring search queries, including
  query builder, subgraph, and predicate options.

  This class is designed to support both the `IN` and `=` statements in
  query generation.

  Attributes:
    dataset (Dataset): The dataset to be queried.
    dev (bool): Flag indicating whether to use development settings.
    max_queries_per_fact_table (int): Maximum number of queries per fact
      table.
    max_queries_per_signature (int): Maximum number of queries per
      signature.
    unique_joins (bool): Whether to enforce unique joins in the subgraph.
    max_hops (list[int]): Maximum number of hops allowed in the subgraph.
    keep_edge_probability (float): Probability of retaining an edge in the
      subgraph.
    extra_predicates (list[int]): Number of additional predicates to include
      in the query.
    row_retention_probability (list[float]): Probability of retaining a row
      for range predicates
    operator_weights (PredicateOperatorProbability): Probability
      distribution for predicate operators.
    equality_lower_bound_probability (float): Lower bound probability when
      using the `=` and the `IN` operators
  """

  # Query Builder
  dataset: Dataset
  dev: bool
  max_queries_per_fact_table: int
  max_queries_per_signature: int
  # Subgraph
  unique_joins: bool
  max_hops: list[int]
  keep_edge_probability: float
  # Predicates
  extra_predicates: list[int]
  row_retention_probability: list[float]
  operator_weights: PredicateOperatorProbability
  equality_lower_bound_probability: list[float]
  extra_values_for_in: int


@dataclass
class SnowflakeEndpoint:
  """
  Represents the parameters used for configuring query generation,
  including query builder, subgraph, and predicate options.

  Attributes:
    dataset (Dataset): The dataset to be used for query generation.
    max_queries_per_signature (int): Maximum number of queries to generate
      per signature.
    max_queries_per_fact_table (int): Maximum number of queries to generate
      per fact table.
    max_hops (int): Maximum number of hops allowed in the subgraph.
    keep_edge_probability (float): Probability of retaining an edge in the
      subgraph.
    extra_predicates (int): Number of extra predicates to add to the query.
    row_retention_probability (float): Probability of retaining a row after
      applying predicates.
    operator_weights (PredicateOperatorProbability): Probability
      distribution for predicate operators.
    equality_lower_bound_probability (float): Probability of using a lower
      bound for equality predicates.
  """

  # Query builder
  dataset: Dataset
  max_queries_per_signature: int
  max_queries_per_fact_table: int
  # Subgraph
  max_hops: int
  keep_edge_probability: float
  # Predicates
  predicate_parameters: PredicateParameters


T = TypeVar("T")


def read_and_parse_toml(path: Path, cls: type[T]) -> T:
  toml_dict = tomllib.loads(path.read_text())
  return structure(toml_dict, cls)
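
A minimal, self-contained illustration of what read_and_parse_toml does; the Mini* dataclasses below are stand-ins for the real endpoints: tomllib parses the text into plain dicts and cattrs.structure builds the typed object, including nested tables such as [operator_weights].

import tomllib
from dataclasses import dataclass

from cattrs import structure


@dataclass
class MiniWeights:
  operator_in: float
  operator_range: float
  operator_equal: float


@dataclass
class MiniConfig:
  dataset: str
  max_hops: list[int]
  operator_weights: MiniWeights


toml_text = """
dataset = "TPCDS"
max_hops = [1, 2, 4]

[operator_weights]
operator_in = 1
operator_range = 3
operator_equal = 3
"""

cfg = structure(tomllib.loads(toml_text), MiniConfig)
print(cfg.max_hops)                         # [1, 2, 4]
print(cfg.operator_weights.operator_range)  # 3.0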
Changes to tests/duckdb/test_binning.py.
import tomllib
from unittest import mock

from cattrs import structure
import polars as pl
import pytest

from query_generator.duckdb_connection.binning import (
  SearchParameters,
  run_snowflake_param_seach,
)
from query_generator.tools.cherry_pick_binning import make_bins_in_csv
from query_generator.utils.definitions import Dataset
from query_generator.utils.params import SearchParametersEndpoint


@pytest.mark.parametrize(
  "count_star, upper_bound, total_bins, expected_bin",
  [
    (5, 10, 5, 3),
    (0, 10, 5, 0),
…
    " but got {computed_bin}"
  )


@pytest.mark.parametrize(
  "extra_predicates, expected_call_count, unique_joins",
  [
    ([1], 120 * 1 + 14, False),
    ([1], 120 * 1 + 14, True),
    ("[1]", 120 * 1 + 14, "false"),
    ("[1]", 120 * 1 + 14, "true"),
    # Inventory is small and produces 14 queries total
    ([1, 2], 120 * 2 + 14, True),
    ([1, 2], 120 * 2 + 14 * 2, False),
    ("[1, 2]", 120 * 2 + 14, "true"),
    ("[1, 2]", 120 * 2 + 14 * 2, "false"),
  ],
)
def test_binning_calls(extra_predicates, expected_call_count, unique_joins):
  with mock.patch(
    "query_generator.duckdb_connection.binning.Writer.write_query_to_batch",
  ) as mock_writer:
    with mock.patch(
      "query_generator.duckdb_connection.binning.get_result_from_duckdb",
    ) as mock_connect:
      mock_connect.return_value = 0
      data_toml = f"""
        dataset = "TPCDS"
        dev = true
        max_hops = [1]
        extra_predicates = {extra_predicates}
        row_retention_probability = [0.2]
        unique_joins = {unique_joins}
        max_queries_per_fact_table = 10
        max_queries_per_signature = 2
        keep_edge_probability = 0.2
        equality_lower_bound_probability = [0]
        extra_values_for_in = 3

        [operator_weights]
        operator_in = 1
        operator_range = 3
        operator_equal = 3
        """
      user_input = structure(tomllib.loads(data_toml), SearchParametersEndpoint)
      run_snowflake_param_seach(
        search_params=SearchParameters(
          scale_factor=0,
          dataset=Dataset.TPCDS,
          max_hops=[1],
          extra_predicates=extra_predicates,
          row_retention_probability=[0.2],
          con=None,
          unique_joins=unique_joins,
          user_input=user_input,
        ),
      )
    assert mock_writer.call_count == expected_call_count, (
      f"Expected {expected_call_count} calls to write_query, "
      f"but got {mock_writer.call_count}"
    )
Changes to tests/duckdb/test_duckdb_utils.py.
import datetime

from query_generator.duckdb_connection.setup import setup_duckdb
from query_generator.duckdb_connection.utils import (
  get_distinct_count,
  get_equi_height_histogram,
  get_frequent_non_null_values,
)
from query_generator.tools.histograms import DuckDBHistogramParser
from query_generator.utils.definitions import Dataset
from tests.utils import is_date, is_float
from tests.utils import is_float
import datetime


def test_distinct_values():
  """Test the setup of DuckDB."""
  # Setup DuckDB
  con = setup_duckdb(Dataset.TPCDS, 0.1)
  assert get_distinct_count(con, "call_center", "cc_call_center_sk") == 1
Changes to tests/file_management/test_read_histograms.py.
from unittest import mock

import polars as pl
import pytest

import polars as pl
from query_generator.predicate_generator.predicate_generator import (
  HistogramDataType,
  PredicateGenerator,
)
from query_generator.tools.histograms import HistogramColumns
from query_generator.utils.definitions import Dataset
from query_generator.utils.definitions import Dataset, PredicateParameters
from query_generator.utils.exceptions import InvalidHistogramTypeError


def test_read_histograms():
  for dataset in Dataset:
    predicate_generator = PredicateGenerator(dataset)
    predicate_generator = PredicateGenerator(dataset, None)
    histogram = predicate_generator.read_histogram()
    assert not histogram.is_empty()

    assert histogram[HistogramColumns.DTYPE.value].dtype == pl.Utf8
    assert histogram[HistogramColumns.COLUMN.value].dtype == pl.Utf8
    assert histogram[HistogramColumns.DTYPE.value].dtype == pl.Utf8
    assert histogram[HistogramColumns.HISTOGRAM.value].dtype == pl.List(pl.Utf8)
…
  max_index,
  dtype,
):
  with mock.patch(
    "query_generator.predicate_generator.predicate_generator.random.randint",
    return_value=mock_rand,
  ):
    predicate_generator = PredicateGenerator(Dataset.TPCH)
    predicate_generator = PredicateGenerator(
      Dataset.TPCH,
      PredicateParameters(
        extra_predicates=None,
        row_retention_probability=row_retention_probability,
        operator_weights=None,
        equality_lower_bound_probability=None,
        extra_values_for_in=None,
      ),
    )
    min_value, max_value = predicate_generator._get_min_max_from_bins(
      bins_array, row_retention_probability, dtype
      bins_array, dtype
    )
  assert min_value == bins_array[min_index]
  assert max_value == bins_array[max_index]


def test_get_invalid_histogram_type():
  predicate_generator = PredicateGenerator(Dataset.TPCH)
  predicate_generator = PredicateGenerator(Dataset.TPCH, None)
  with pytest.raises(InvalidHistogramTypeError):
    predicate_generator._get_histogram_type("not_supported_type")


@pytest.mark.parametrize(
  "input_type, expected_type",
  [
    ("INTEGER", HistogramDataType.INT),
    ("BIGINT", HistogramDataType.INT),
    ("DECIMAL(10,2)", HistogramDataType.FLOAT),
    ("DECIMAL(7,4)", HistogramDataType.FLOAT),
    ("DATE", HistogramDataType.DATE),
    ("VARCHAR", HistogramDataType.STRING),
  ],
)
def test_get_valid_histogram_type(input_type, expected_type):
  predicate_generator = PredicateGenerator(Dataset.TPCH)
  predicate_generator = PredicateGenerator(Dataset.TPCH, None)
  assert predicate_generator._get_histogram_type(input_type) == expected_type
Changes to tests/query_generation/test_make_queries.py.
from unittest import mock

import pytest
from pypika import functions as fn


from query_generator.database_schemas.schemas import get_schema
from query_generator.join_based_query_generator.snowflake import (
  QueryBuilder,
  generate_and_write_queries,
)
from query_generator.predicate_generator.predicate_generator import (
  HistogramDataType,
  PredicateGenerator,
  PredicateRange,
)
from query_generator.utils.definitions import (
  Dataset,
  PredicateOperatorProbability,
  PredicateParameters,
  QueryGenerationParameters,
)
from query_generator.utils.exceptions import UnkwonDatasetError
from query_generator.utils.exceptions import UnkownDatasetError
from pypika import OracleQuery
from pypika import functions as fn


def test_tpch_query_generation():
  with mock.patch(
    "query_generator.join_based_query_generator.snowflake.Writer.write_query",
  ) as mock_writer:
    generate_and_write_queries(
      QueryGenerationParameters(
        dataset=Dataset.TPCDS,
        max_hops=1,
        max_queries_per_fact_table=1,
        max_queries_per_signature=1,
        keep_edge_prob=0.2,
        row_retention_probability=0.2,
        extra_predicates=1,
        seen_subgraphs={},
      ),
        keep_edge_probability=0.2,
        seen_subgraphs={},
        predicate_parameters=PredicateParameters(
          operator_weights=PredicateOperatorProbability(
            operator_in=0.4,
            operator_equal=0.4,
            operator_range=0.2,
          ),
          extra_predicates=1,
          row_retention_probability=0.2,
          equality_lower_bound_probability=0,
          extra_values_for_in=3,
        ),
      )
    )

    assert mock_writer.call_count > 5


def test_tpcds_query_generation():
  with mock.patch(
    "query_generator.join_based_query_generator.snowflake.Writer.write_query",
  ) as mock_writer:
    generate_and_write_queries(
      QueryGenerationParameters(
        dataset=Dataset.TPCDS,
        max_hops=1,
        max_queries_per_fact_table=1,
        max_queries_per_signature=1,
        keep_edge_prob=0.2,
        row_retention_probability=0.2,
        extra_predicates=1,
        seen_subgraphs={},
        keep_edge_probability=0.2,
        seen_subgraphs={},
        predicate_parameters=PredicateParameters(
          operator_weights=PredicateOperatorProbability(
            operator_in=0.4,
            operator_equal=0.4,
            operator_range=0.2,
          ),
          extra_predicates=1,
          row_retention_probability=0.2,
          equality_lower_bound_probability=0,
          extra_values_for_in=3,
        ),
      ),
    )

    assert mock_writer.call_count > 5


def test_non_implemented_dataset():
  with mock.patch(
    "query_generator.join_based_query_generator.snowflake.Writer.write_query",
  ) as mock_writer:
    with pytest.raises(UnkwonDatasetError):
    with pytest.raises(UnkownDatasetError):
      generate_and_write_queries(
        QueryGenerationParameters(
          dataset="non_implemented_dataset",
          max_hops=1,
          max_queries_per_fact_table=1,
          max_queries_per_signature=1,
          keep_edge_prob=0.2,
          row_retention_probability=0.2,
          extra_predicates=1,
          seen_subgraphs={},
          keep_edge_probability=0.2,
          seen_subgraphs={},
          predicate_parameters=PredicateParameters(
            operator_weights=PredicateOperatorProbability(
              operator_in=0.4,
              operator_equal=0.4,
              operator_range=0.2,
            ),
            extra_predicates=1,
            row_retention_probability=0.2,
            equality_lower_bound_probability=0,
            extra_values_for_in=3,
          ),
        ),
      )
    assert mock_writer.call_count == 0


def test_add_rage_supports_all_histogram_types():
  tables_schema, _ = get_schema(Dataset.TPCH)
  query_builder = QueryBuilder(None, tables_schema, Dataset.TPCH)
  query_builder = QueryBuilder(
    None,
    tables_schema,
    Dataset.TPCH,
    PredicateParameters(
      extra_predicates=None,
      row_retention_probability=0.2,
      operator_weights=None,
      equality_lower_bound_probability=None,
      extra_values_for_in=None,
    ),
  )
  for dtype in HistogramDataType:
    query_builder._add_range(
      OracleQuery()
      .from_(query_builder.table_to_pypika_table["lineitem"])
      .select(fn.Count("*")),
      PredicateGenerator.Predicate(
      PredicateRange(
        table="lineitem",
        column="foo",
        min_value=2020,
        max_value=2020,
        dtype=dtype,
      ),
    )