From c6f208df6675e73ad5394781db5492d9d54d9f12 Mon Sep 17 00:00:00 2001 From: Yifan Zhao <yifanz16@illinois.edu> Date: Tue, 4 Jul 2023 23:02:29 -0500 Subject: [PATCH] Added feature extraction --- include/tvm/arith/egg_simpl.h | 3 +- include/tvm/arith/var_context.h | 44 +- python/tvm/felix/features.py | 349 ++++++++++ python/tvm/felix/ffi.py | 39 ++ python/tvm/felix/sketch.py | 28 +- src/arith/egg_simpl.cc | 15 +- src/arith/egg_simpl/src/lang.rs | 112 +++- src/arith/egg_simpl/src/lib.rs | 8 +- src/arith/var_context.cc | 102 ++- src/felix/constraints.cc | 207 ++++++ src/felix/feat_transform.cc | 1003 +++++++++++++++++++++++++++++ src/felix/features.cc | 1060 +++++++++++++++++++++++++++++++ src/felix/features.h | 29 + src/felix/rangeinfer.h | 262 ++++++++ src/felix/utils.h | 49 ++ src/tir/ir/expr.cc | 49 +- src/tir/op/op.cc | 14 +- 17 files changed, 3208 insertions(+), 165 deletions(-) create mode 100644 python/tvm/felix/features.py create mode 100644 src/felix/constraints.cc create mode 100644 src/felix/feat_transform.cc create mode 100644 src/felix/features.cc create mode 100644 src/felix/features.h create mode 100644 src/felix/rangeinfer.h diff --git a/include/tvm/arith/egg_simpl.h b/include/tvm/arith/egg_simpl.h index 28eaf3e2f..ce2a891ea 100644 --- a/include/tvm/arith/egg_simpl.h +++ b/include/tvm/arith/egg_simpl.h @@ -22,8 +22,7 @@ PrimExpr SimplifyExpr(const PrimExpr& expr, size_t max_n_iters = 30, size_t max_ PrimExpr SubAndSimplify(const PrimExpr& expr, const std::unordered_map<std::string, PrimExpr>& subst, - bool simpl_only_on_change, size_t max_n_iters = 30, - size_t max_n_nodes = 10000); + size_t max_n_iters = 30, size_t max_n_nodes = 10000); bool IsExprEquivalent(const PrimExpr& lhs, const PrimExpr& rhs, size_t max_n_iters = 30, size_t max_n_nodes = 10000, bool diff_approx = false); diff --git a/include/tvm/arith/var_context.h b/include/tvm/arith/var_context.h index bd8139840..4a0b7d2f3 100644 --- a/include/tvm/arith/var_context.h +++ b/include/tvm/arith/var_context.h @@ -6,6 +6,8 @@ #include <tvm/tir/expr.h> #include <tvm/tir/stmt_functor.h> +#include <optional> + namespace tvm { namespace arith { @@ -35,14 +37,11 @@ class VarExprPair : public ObjectRef { public: VarExprPair(tir::SizeVar var, PrimExpr expr) : VarExprPair(make_object<VarExprPairNode>(var, expr)) {} - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(VarExprPair, ObjectRef, VarExprPairNode); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(VarExprPair, ObjectRef, VarExprPairNode); }; class VarDefStackNode : public Object { - using ContainerT = Array<VarExprPair>; - using ExprVisitor = std::function<PrimExpr(const PrimExpr& e)>; - using ThreadedExprVisitor = std::function<PrimExpr(const PrimExpr& e, size_t)>; - using VarExprVisitor = std::function<PrimExpr(const Var& v, const PrimExpr& e)>; + using ExprMutator = std::function<PrimExpr(const PrimExpr& e)>; public: void VisitAttrs(tvm::AttrVisitor* v) { @@ -53,32 +52,36 @@ class VarDefStackNode : public Object { tir::SizeVar Append(const std::string& vname, const PrimExpr& expr); void Append(const tir::SizeVar& var, const PrimExpr& expr); - tir::SizeVar FindOrAppend(const std::string& vname, const PrimExpr& expr); - - VarDefStackNode Prepend(const VarMapT& vmap) const; - VarMapT IntoVarMap() const; + PrimExpr DefineConstShorthand(PrimExpr expr); bool Contains(const std::string& vname) const { return this->var2idx.find(vname) != this->var2idx.end(); } size_t Size() const { return this->exprs.size(); } - const ContainerT& GetExprs() const { return this->exprs; } - PrimExpr& GetExprAt(const std::string& vname) { + + const Array<VarExprPair>& GetExprs() const { return this->exprs; } + const PrimExpr& GetExprAt(const std::string& vname) const { auto it = this->var2idx.find(vname); ICHECK(it != this->var2idx.end()) << "Var " << vname << " not found in VarDefStack"; return this->exprs[(*it).second]->expr; } - std::vector<Var> FreeVars() const; - bool HasUndefVars(const PrimExpr& expr) const; + VarMapT IntoUnwindedVarMap() const; + std::unordered_set<std::string> GetAllUsedVars(std::optional<tir::SizeVarKind> kind) const; + + void MapExprs(ExprMutator func); + void MapExprsParallel(ExprMutator func); static constexpr const char* _type_key = "arith.VarDefStack"; TVM_DECLARE_FINAL_OBJECT_INFO(VarDefStackNode, Object); private: - ContainerT exprs; + Array<VarExprPair> exprs; Map<String, Integer> var2idx; std::unordered_map<PrimExpr, size_t, StructuralHash, StructuralEqual> expr2idx; + + friend class VarDefStack; + friend class dmlc::json::Handler<VarDefStackNode>; }; class VarDefStack : public ObjectRef { @@ -125,20 +128,17 @@ class VarContextNode : public Object { Array<tir::SizeVar> GetSplitVars(const PrimExpr& extent, size_t n_splits, bool whole_div); std::pair<PrimExpr, PrimExpr> GetSplitSizes(const PrimExpr& extent, PrimExpr factor, bool no_tighten_factor); + PrimExpr DefineConstShorthand(PrimExpr expr) { + return this->var_defs->DefineConstShorthand(expr); + } - void DefineVar(const std::string& name, PrimExpr expr); - - PrimExpr DefineConstShorthand(PrimExpr expr); + static constexpr const char* _type_key = "arith.VarContext"; + TVM_DECLARE_FINAL_OBJECT_INFO(VarContextNode, Object); private: - Var AllocVarForExpr(PrimExpr expr); - std::pair<PrimExpr, PrimExpr> SymbolicDiv(PrimExpr numer, PrimExpr denom, bool no_tighten_factor); public: - static constexpr const char* _type_key = "arith.VarContext"; - TVM_DECLARE_FINAL_OBJECT_INFO(VarContextNode, Object); - Array<SplitGroup> split_groups{}; VarDefStack var_defs{}; diff --git a/python/tvm/felix/features.py b/python/tvm/felix/features.py new file mode 100644 index 000000000..4ad16937b --- /dev/null +++ b/python/tvm/felix/features.py @@ -0,0 +1,349 @@ +import logging +import typing as ty +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Sequence, Tuple, cast + +import numpy as np +import torch +from sympy.ntheory import factorint +from torch import Tensor, nn +from torch.fx._symbolic_trace import symbolic_trace +from torch.fx.graph_module import GraphModule +from tvm import tir + +from . import ffi +from .utils import transpose2 + +__all__ = ["TorchFeatures"] +logger = logging.getLogger(__name__) + +Number = ty.Union[int, float] +T = ty.TypeVar("T") + + +class TorchFeatures(nn.Module): + def __init__( + self, + features_f: GraphModule, + constraints_f: GraphModule, + lin_cons: list, + var_order: Dict[str, int], + var_decomp: Dict[str, Dict[int, tir.SizeVar]], + n_feats: int, + n_bufs: int, + ): + super().__init__() + self.features_f = features_f + self.constraints_f = constraints_f + self.lin_cons = lin_cons + self.var_order = var_order + self.var_decomp = { + k: {int(prime): v.name for prime, v in ds.items()} for k, ds in var_decomp.items() + } + self.n_feats, self.n_bufs = n_feats, n_bufs + self.n_consts = 0 + self._dummy_param = nn.Parameter(torch.empty(0)) # type: ignore + self._conf_maker: Optional[RandConfigMaker] = None + + @classmethod + def from_feat_pack(cls, feat: ffi.FeaturePack) -> "TorchFeatures": + var_order = {var: i for i, var in enumerate(sorted(feat.free_vars))} + features = defaultdict(list) + other_cons = [] + vdefs = {} + for vname, expr in feat.expressions: + if vname.startswith("BS"): + bufstore_idx = int(vname.split(".")[0][2:]) + features[bufstore_idx].append(expr) + elif vname.startswith("con_"): + other_cons.append(expr) + else: + vdefs[vname] = expr + features_ = np.array([features[i] for i in range(len(features))]) + n_bufs, n_feats = features_.shape + features_f = TorchExprRunner(features_, var_order, vdefs).get_traced() + lin_cons = list(feat.linear_cons) + all_cons = np.array([c.as_primexpr() for c in lin_cons] + other_cons) + cons_f = TorchExprRunner(all_cons, var_order, vdefs).get_traced() + return cls(features_f, cons_f, lin_cons, var_order, feat.var_decomp, n_feats, n_bufs) + + @property + def device(self): + return self._dummy_param.device + + def rand_configs(self, n_configs: int): + if self._conf_maker is None: + self._conf_maker = RandConfigMaker( + list(self.var_order.keys()), + self.lin_cons, + self.constraints_f, + self.device, + ) + return self._conf_maker.rand_configs(n_configs) + + def forward(self, params: Tensor) -> Tuple[Tensor, Tensor]: + # params: [batch_size, n_vars] + # features: [batch_size, n_bufs, n_feats] + features = self.features_f(params) + # leq_0s: [batch_size, n_constraints] + leq_0s = self.constraints_f(params) + return features, leq_0s + + def run_on_initial_configs(self, initial_configs: Sequence[dict]): + configs = [self.transform_config(c) for c in initial_configs] + return self(torch.stack(configs, dim=0)) + + def transform_config(self, config: Dict[str, Number]): + stage1: Dict[str, Number] = {} + for var_name, value in config.items(): + if var_name not in self.var_decomp: + stage1[var_name] = value + continue + decomposed = self.var_decomp[var_name] + if value == 0: + for vname in decomposed.values(): + stage1[vname] = -10 # HACK + continue + if len(decomposed) == 1: + ((b, vname),) = decomposed.items() + if b < 2: + raise ValueError() + stage1[vname] = np.log(value) / np.log(b) + continue + factors: Dict[int, int] = factorint(value) + uncovered_ps = set(factors.keys()) - set(decomposed.keys()) + if uncovered_ps: + bases = list(decomposed.keys()) + raise ValueError(f"Cannot factorize {value} into {bases}") + for basis, vname in decomposed.items(): + power = factors.get(basis, 0) + stage1[vname] = power + stage2: List[Optional[float]] = [None] * len(self.var_order) + for name, value in stage1.items(): + if name not in self.var_order: + continue + idx = self.var_order[name] + stage2[idx] = value + if any(x is None for x in stage2): + required = set(self.var_order.keys()) + got = set(stage1.keys()) + raise ValueError(f"Missing value for some symbols {required - got}") + return torch.tensor(cast(List[float], stage2)).float() + + def inv_transform_config(self, config: Tensor): + def expr_to_int(config, d: Dict[int, str]): + v = np.prod([prime ** config[vname] for prime, vname in d.items()]) + if not np.isclose(v, int(v)): + raise ValueError(f"Cannot convert {v} to int") + return int(v) + + stage1 = {k: config[idx].item() for k, idx in self.var_order.items()} + return {var: expr_to_int(stage1, d) for var, d in self.var_decomp.items()} + + def _concat_knobs(self, consts: Tensor, knobs: Tensor): + assert consts.shape[0] == knobs.shape[0] + assert consts.shape[1] == self.n_consts + assert consts.shape[1] + knobs.shape[1] == self.n_vars + return torch.cat([consts.float(), knobs.float()], dim=1) + + +class RandConfigMaker: + def __init__( + self, + variables: List[str], + lin_cons: List[ffi.LinearExpr], + constraints_f, + device, + ) -> None: + self.variables = variables + self.lin_cons = lin_cons + self.constraints_f = constraints_f + self.device = device + coefs, biases, self.bounds = self._estimate_var_bounds() + # Make configs in [0, 1]^n but zoom them before returning. + # Consider a single constraint: \sum A_ij x_j + b <= 0. + # Now x_i \in [0, bound_i]; make x'_i := x_i / bound_i, then x'_i \in [0, 1]. + # And the constraint becomes: \sum (A_ij * bound_j) x'_j + b <= 0. + # Therefore: + self.coefs, self.biases = coefs * self.bounds.unsqueeze(0), biases + + def rand_configs(self, n_configs: int): + # This point (0) should be in the polyhedron (on the boundary). + current = torch.zeros((1, self.coefs.shape[1]), device=self.device) + while len(current) < n_configs + 1: + next = self.rand_next_points(current) + # Restore configs into the original scale (multiplying with `bounds`; see above) + # Still needs to check as we have non-linear constraints too. + valid_mask = torch.all(self.constraints_f(next * self.bounds) <= 0, dim=1) + next = next[valid_mask] + # Don't insert scale-restored points into `configs` yet. + # We need to work with [0, 1]^n points here. + current = torch.cat([current, next], dim=0) + # ret: [n_configs, n_vars] (discard the first point which is (0)) + return current[1 : n_configs + 1] * self.bounds + + def rand_next_points(self, xs: Tensor): + # https://mathoverflow.net/questions/9854 + assert xs.shape[1] == self.coefs.shape[1] + alpha = torch.rand_like(xs, device=self.device) + alpha = alpha / torch.linalg.norm(alpha, dim=1, keepdim=True) + # Solve A_i . (x + t_i alpha) + b_i <= 0 (. is for dot prod) + # -> t_i ? (-b - A_i . x) / (A_i . alpha) + # (whether t_i is a lower or upper bound depends on the sign of A_i . alpha) + # alpha, xs: [n_configs, n_vars]; self.coefs: [n_constraints, n_vars] + denoms = torch.tensordot(alpha, self.coefs, dims=([1], [1])) # type: ignore + a_dot_x = torch.tensordot(xs, self.coefs, dims=([1], [1])) # type: ignore + # denoms, a_dot_x: [n_configs, n_constraints]; self.biases: [n_constraints] + ts: Tensor = (-self.biases - a_dot_x) / denoms + # ts: [n_configs, n_constraints] + # NOTE: for each row, we would have at least one column where denoms > 0 + # and at least one column where denoms < 0. + # Otherwise the problem would be unbounded. + # (If that actually happens, this code doesn't detect it.) + min_ubounds = torch.where(denoms > 0, ts, +1e9).min(dim=1).values + max_lbounds = torch.where(denoms < 0, ts, -1e9).max(dim=1).values + t_rands = [] + for i in range(len(xs)): + t_min, t_max = max_lbounds[i], min_ubounds[i] + assert t_min <= 0 <= t_max + # Pick a uniformly random value in [t_min, -d] \union [d, t_max] + # (where d == 0.05) if there is such space, to prevent too small + # of a step from previous configuration. + # (If space is not enough for this operation, t is then just uniformly + # sampled from [t_min, t_max]). + t_rands.append(self._rand_uniform_with_gap(t_min, t_max)) + return xs + torch.tensor(t_rands).unsqueeze(1) * alpha + + @staticmethod + def _rand_uniform_with_gap(left: float, right: float, half_gap: float = 0.05): + assert left <= 0 <= right + + def rand_range_f(l0, r0): + return np.random.rand() * (r0 - l0) + l0 + + if right <= half_gap and left >= -half_gap: + # Not enough space to evade [-half_gap, half_gap]. + return rand_range_f(left, right) + l_size = max(0, -left - half_gap) + r_size = max(0, right - half_gap) + x = rand_range_f(-l_size, r_size) + return -half_gap + x if x < 0 else half_gap + x + + def _estimate_var_bounds(self): + from scipy.optimize import linprog + + def _get_coefs(poly: ffi.LinearExpr): + lin_terms = {repr(k): v for k, v in poly.lin_terms.items()} + coefs = [float(lin_terms.get(s, 0)) for s in self.variables] + return coefs, float(poly.constant) + + coefs_biases = [_get_coefs(poly) for poly in self.lin_cons] + coefs, biases = transpose2(coefs_biases) + coefs, biases = np.array(coefs), np.array(biases) + # Run linear optimization with scipy. + vars = self.variables + var_bounds = [] + for i in range(len(vars)): + coef_c = np.zeros(len(vars)) + coef_c[i] = -1 # To maximize x[i] + result = linprog(coef_c, coefs, -biases) + if result.success: + var_bounds.append(result.x[i]) + else: + raise RuntimeError(f"Bound inference failed. Diagnostics: \n{result}") + + def np2torch(x): + return torch.tensor(x, device=self.device).float() + + return np2torch(coefs), np2torch(biases), np2torch(var_bounds) + + +def safe_log(x: torch.Tensor): + # To prevent Infs and NaNs in the result, use this: + # log(x) if x >= 1 + # -log(2 - x) otherwise (x <= 1, 2 - x >= 1) + return torch.clamp_min(x, 1).log() - torch.clamp_min(2 - x, 1).log() + + +class TorchExprRunner: + TIR_OP_TORCH_FN = { + tir.Add: (torch.add, ["a", "b"]), + tir.Sub: (torch.sub, ["a", "b"]), + tir.Mul: (torch.mul, ["a", "b"]), + tir.Div: (torch.div, ["a", "b"]), + tir.Min: (torch.minimum, ["a", "b"]), + tir.Max: (torch.maximum, ["a", "b"]), + tir.Cast: (lambda x: x, ["value"]), + } + TIR_FN_TORCH_FN = { + "tir.pow": torch.pow, + "tir.exp": torch.exp, + "tir.log": safe_log, + "tir.logk": lambda base, x: safe_log(x) / torch.log(base), + "tir.sigmoid": lambda x: x * torch.pow(x**2 + 1, -0.5) / 2 + 0.5, + "tir.hump": lambda x: torch.pow(x**2 + 1, -0.5), + } + + def __init__(self, exprs: np.ndarray, var2idx: dict, var2expr: dict) -> None: + self.exprs = exprs.ravel().tolist() + self.shape = exprs.shape + self.var2idx = var2idx + self.var2expr = var2expr + self.memo: Dict[str, Tensor] = {} + + def get_traced(self) -> GraphModule: + self.memo.clear() + # `self` is coded in the graph. + # Use a lambda to convince symbolic_trace that it's a function, not a member + ret = symbolic_trace(lambda input: self.run(input)) + self.memo.clear() + return ret + + def _run_expr_memoized(self, expr: tir.PrimExpr, inputs: Tensor, batch_size: int): + estr = repr(expr) + if (result := self.memo.get(estr)) is not None: + return result + self.memo[estr] = result = self._run_expr(expr, inputs, batch_size) + return result + + def _run_expr(self, expr: tir.PrimExpr, inputs: Tensor, nb: int) -> Tensor: + if isinstance(expr, tir.Var): + var_idx = self.var2idx.get(expr.name) + if var_idx is not None: + return inputs[:, var_idx] + var_expr = self.var2expr.get(expr.name) + if var_expr is not None: + return self._run_expr_memoized(var_expr, inputs, nb) + raise KeyError(f"Stray variable {expr.name}") + if isinstance(expr, tir.IntImm): + dtype = torch.bool if expr.dtype == "bool" else torch.int64 + return inputs.new_full((nb,), expr.value, dtype=dtype) + if isinstance(expr, tir.FloatImm): + return inputs.new_full((nb,), expr.value, dtype=torch.float32) + if isinstance(expr, tir.Call): + if (func := self.TIR_FN_TORCH_FN.get(expr.op.name)) is None: + raise ValueError(f"Function call {expr.op.name} is unsupported") + args = expr.args + else: + if (func_fields := self.TIR_OP_TORCH_FN.get(type(expr))) is None: + raise ValueError(f"Operator {type(expr)} is unsupported") + func, fields = func_fields + args = [getattr(expr, f) for f in fields] + args = [self._run_expr_memoized(arg, inputs, nb) for arg in args] + return func(*args) + + def run(self, inputs: Tensor): + batch_size = inputs.shape[0] + if not self.exprs: + return inputs.new_zeros(batch_size, *self.shape) + # inputs: [batch_size, n_vars] + ret = [] + for expr in self.exprs: + result = self._run_expr_memoized(expr, inputs, batch_size) + if result.dtype is torch.float and (result.abs() > 1e4).any(): + raise ValueError(f"Result too large: {result} in {expr}") + ret.append(result) + ret_ = torch.stack(ret, dim=1) + # return [batch_size, *shape] (shape = [n_buffers, n_exprs]) + return ret_.view(-1, *self.shape) diff --git a/python/tvm/felix/ffi.py b/python/tvm/felix/ffi.py index ec876c9c4..591016302 100644 --- a/python/tvm/felix/ffi.py +++ b/python/tvm/felix/ffi.py @@ -17,6 +17,23 @@ class VarContext(tvm.Object): return dict(_arith.VarContextGetVarDefs(self)) +@tvm._ffi.register_object("ansor.LinearExpr") +class LinearExpr(tvm.Object): + lin_terms: Dict[tir.Var, float] + constant: float + + def as_primexpr(self) -> tir.PrimExpr: + return _ansor.LinearExprAsPrimExpr(self) + + +@tvm._ffi.register_object("ansor.FeaturePackPy") +class FeaturePack(tvm.Object): + expressions: List[Tuple[str, tir.PrimExpr]] + free_vars: List[str] + linear_cons: List[LinearExpr] + var_decomp: Dict[str, Dict[int, tir.SizeVar]] + + # PrimExpr Utils @@ -57,5 +74,27 @@ def print_state_tr_steps(state: StateObject) -> str: return "\n".join(_ansor.PrintTrStep(s) for s in state.transform_steps) +def get_feature_pack( + code: tir.Stmt, + context: VarContext, + hw_params: ansor.HardwareParams, + sizes: Dict[str, int], + cache_line_size: int, + max_n_buf: int, + factorize: bool, + save_load_path: str, +) -> FeaturePack: + return _ansor.GetFeaturePack( + code, + context, + hw_params, + sizes, + cache_line_size, + max_n_buf, + factorize, + save_load_path, + ) + + def get_loop_bounds(code: tir.Stmt) -> List[Tuple[str, tir.PrimExpr]]: return list(_ansor.GetLoopBounds(code)) diff --git a/python/tvm/felix/sketch.py b/python/tvm/felix/sketch.py index 225953779..570cd959f 100644 --- a/python/tvm/felix/sketch.py +++ b/python/tvm/felix/sketch.py @@ -1,14 +1,17 @@ import logging from pathlib import Path +from typing import Dict from tvm.auto_scheduler.loop_state import StateObject from tvm.tir import Stmt from . import ffi +from .features import TorchFeatures +from .utils import HW_PARAMS _logger = logging.getLogger(__name__) __all__ = ["Sketch"] -CACHE_PATH = (Path(__file__).parent / "../../lightning_logs/features").resolve() +FEATURE_CACHE_PATH = Path(Path("~").expanduser(), ".tvm", "felix", "features") class Sketch: @@ -39,7 +42,28 @@ class Sketch: return ffi.generate_code_for_state(task, state, False)[0] def save_path(self) -> Path: - return CACHE_PATH / f"{self.state_hash()}.json" + return FEATURE_CACHE_PATH / f"{self.state_hash()}.json" + + def fetch_features( + self, + sizes: Dict[str, int], + prime_factorize: bool = True, + max_n_buf: int = 5, + cache_line_size: int = 64, + ): + path = self.save_path() + path.parent.mkdir(exist_ok=True, parents=True) + features = ffi.get_feature_pack( + self.code, + self.context, + HW_PARAMS, + sizes, + cache_line_size, + max_n_buf, + prime_factorize, + path.as_posix(), + ) + return TorchFeatures.from_feat_pack(features) def __str__(self) -> str: return f"Sketch({self.backbone} from {self.parent_task})" diff --git a/src/arith/egg_simpl.cc b/src/arith/egg_simpl.cc index 71bce8435..8dad2ab7a 100644 --- a/src/arith/egg_simpl.cc +++ b/src/arith/egg_simpl.cc @@ -293,6 +293,8 @@ std::pair<PrimExpr, size_t> ParseExprPreorder(const std::string& str, return {MakeUnOp(op_str, exp, args), loc}; } else if (op_str == "sigmoid") { return {MakeUnOp(op_str, sigmoid, args), loc}; + } else if (op_str == "hump") { + return {MakeUnOp(op_str, hump, args), loc}; } else { throw std::runtime_error("Unknown operator " + op_str + " in " + str); } @@ -313,20 +315,15 @@ PrimExpr SimplifyExpr(const PrimExpr& expr, size_t max_n_iters, size_t max_n_nod char* simpl_str = simplify_expr(expr_str.c_str(), max_n_iters, max_n_nodes, diff_approx); PrimExpr simplified = ParseExprPreorder(simpl_str, printer.var_map).first; free_str(simpl_str); - // If not really simplified, don't return the simplified version - // because the order of the variables etc. may be different. - return CountOps(simplified) < CountOps(expr) ? simplified : expr; + return simplified; } PrimExpr SubAndSimplify(const PrimExpr& expr, - const std::unordered_map<std::string, PrimExpr>& subst, - bool simpl_only_on_change, size_t max_n_iters, size_t max_n_nodes) { + const std::unordered_map<std::string, PrimExpr>& subst, size_t max_n_iters, + size_t max_n_nodes) { bool changed = false; auto expr_ = SubstByName(expr, subst, &changed); - if (!changed && simpl_only_on_change) { - return expr; - } - return SimplifyExpr(expr_, max_n_iters, max_n_nodes, true); + return changed ? SimplifyExpr(expr_, max_n_iters, max_n_nodes, true) : expr; } bool IsExprEquivalent(const PrimExpr& lhs, const PrimExpr& rhs, size_t max_n_iters, diff --git a/src/arith/egg_simpl/src/lang.rs b/src/arith/egg_simpl/src/lang.rs index 576763586..83aff3e5e 100644 --- a/src/arith/egg_simpl/src/lang.rs +++ b/src/arith/egg_simpl/src/lang.rs @@ -86,6 +86,7 @@ define_language! { "!" = Not([Id; 1]), "select" = Select([Id; 3]), "sigmoid" = Sigmoid([Id; 1]), + "hump" = Hump([Id; 1]), } } @@ -264,7 +265,7 @@ impl egg::Analysis<Math> for ConstantFold { } } -fn is_const_with_pred( +fn _is_const_with_pred( var: &str, fop: impl Fn(f64) -> bool, ) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { @@ -279,11 +280,29 @@ fn is_const_with_pred( } } -fn is_geq_2(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { - is_const_with_pred(var, |x| x >= 2.0) +fn is_not_zero(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { + let var = var.parse().unwrap(); + move |egraph, _, subst| { + if let Some(n) = &egraph[subst[var]].data { + if let Const::Float(OrderedFloat(f)) = n.0 { + return f != 0.0; + } + } + true + } +} + +fn is_symbol(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { + let var = var.parse().unwrap(); + move |egraph, _, subst| { + egraph[subst[var]] + .nodes + .iter() + .any(|n| matches!(n, Math::Symbol(..))) + } } -fn is_pow_of_(egraph: &mut EGraph, subst: &Subst, x: Var, base: Var) -> Option<i64> { +fn _is_pow_of(egraph: &mut EGraph, subst: &Subst, x: Var, base: Var) -> Option<i64> { let (x, base) = ( egraph[subst[x]].data.as_ref()?.0, egraph[subst[base]].data.as_ref()?.0, @@ -302,10 +321,10 @@ fn is_pow_of_(egraph: &mut EGraph, subst: &Subst, x: Var, base: Var) -> Option<i fn is_pow_of(x: &str, base: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { let (x, base) = (x.parse().unwrap(), base.parse().unwrap()); - move |egraph, _, subst| is_pow_of_(egraph, subst, x, base).is_some() + move |egraph, _, subst| _is_pow_of(egraph, subst, x, base).is_some() } -pub static STABLE_RULES: Lazy<Vec<Rewrite>> = Lazy::new(|| { +pub static BASIC_RULES: Lazy<Vec<Rewrite>> = Lazy::new(|| { vec![ rewrite!("mul-comm"; "(* ?a ?b)" => "(* ?b ?a)"), rewrite!("mul-assoc"; "(* ?a (* ?b ?c))" => "(* (* ?a ?b) ?c)"), @@ -318,7 +337,7 @@ pub static STABLE_RULES: Lazy<Vec<Rewrite>> = Lazy::new(|| { rewrite!("sub-intro"; "(+ ?a (* -1 ?b))" => "(- ?a ?b)"), rewrite!("sub-cancel"; "(+ ?a (* -1 ?a))" => "0"), rewrite!("div-canon"; "(/ ?a ?b)" => "(* ?a (pow ?b -1))"), - rewrite!("div-intro"; "(* ?a (pow ?b -1))" => "(/ ?a ?b)"), + rewrite!("div-intro"; "(* ?a (pow ?b -1))" => "(/ ?a ?b)" if is_not_zero("?b")), rewrite!("div-cancel"; "(* ?a (pow ?a -1))" => "1"), rewrite!("add-mul-distrib"; "(* ?a (+ ?b ?c))" => "(+ (* ?a ?b) (* ?a ?c))"), rewrite!("add-mul-factor"; "(+ (* ?a ?b) (* ?a ?c))" => "(* ?a (+ ?b ?c))"), @@ -327,12 +346,16 @@ pub static STABLE_RULES: Lazy<Vec<Rewrite>> = Lazy::new(|| { rewrite!("pow-pow"; "(pow (pow ?a ?b) ?c)" => "(pow ?a (* ?b ?c))"), rewrite!("const-mul-pow"; "(* ?a (pow ?b ?c))" => "(pow ?b (+ (logk ?b ?a) ?c))" if is_pow_of("?a", "?b")), rewrite!("const-div-pow"; "(/ ?a (pow ?b ?c))" => "(pow ?b (- (logk ?b ?a) ?c))" if is_pow_of("?a", "?b")), + rewrite!("max-self"; "(max ?a ?a)" => "?a"), + rewrite!("max-a-add-b"; "(max ?a (+ ?a ?b))" => "(+ ?a ?b)" if _is_const_with_pred("?b", |x| x >= 0.0)), + rewrite!("min-self"; "(max ?a ?a)" => "?a"), rewrite!("min-pow"; "(min (pow ?a ?b) (pow ?a ?c))" => "(pow ?a (min ?b ?c))"), - rewrite!("logk-canon"; "(logk ?a ?b)" => "(/ (log ?b) (log ?a))"), - rewrite!("log-prod"; "(log (* ?a ?b))" => "(+ (log ?a) (log ?b))"), - rewrite!("log-pow"; "(log (pow ?a ?b))" => "(* ?b (log ?a))"), - rewrite!("floordiv-cancel"; "(// ?a ?a)" => "1"), + rewrite!("logk-canon"; "(logk ?a ?b)" => "(/ (log ?b) (log ?a))"), + rewrite!("log-prod"; "(log (* ?a ?b))" => "(+ (log ?a) (log ?b))"), + rewrite!("log-pow"; "(log (pow ?a ?b))" => "(* ?b (log ?a))"), + rewrite!("floordiv-cancel"; "(// (* ?a ?b) ?a)" => "?b"), rewrite!("floordiv-merge"; "(// (// ?a ?b) ?c)" => "(// (/ ?a ?b) ?c)"), + rewrite!("floordiv-neg-1"; "(// -1 ?a)" => "-1"), rewrite!("select-same"; "(select ?a ?b ?b)" => "?b"), rewrite!("select-true"; "(select true ?a ?b)" => "?a"), rewrite!("select-false"; "(select false ?a ?b)" => "?b"), @@ -357,32 +380,57 @@ pub static STABLE_RULES: Lazy<Vec<Rewrite>> = Lazy::new(|| { rewrite!("gt-canon"; "(> ?a ?b)" => "(< ?b ?a)"), rewrite!("le-canon"; "(<= ?a ?b)" => "(! (< ?b ?a))"), rewrite!("eq-comm"; "(== ?a ?b)" => "(== ?b ?a)"), + rewrite!("lt-min"; "(< ?a (min ?b ?c))" => "(&& (< ?a ?b) (< ?a ?c))"), + rewrite!("eq-min"; "(== ?a (min ?b ?c))" => "(|| (== ?a ?b) (== ?a ?c))"), ] }); -pub static ALL_RULES: Lazy<Vec<Rewrite>> = Lazy::new(|| { +pub static DIFF_APPROX_RULES: Lazy<Vec<Rewrite>> = Lazy::new(|| { let mut ret_rules = vec![]; - ret_rules.extend_from_slice(STABLE_RULES.as_slice()); + ret_rules.extend_from_slice(BASIC_RULES.as_slice()); ret_rules.extend(vec![ // Differentiability-approx-specific simplification rules - rewrite!("lt-pow-1"; "(< (pow ?a ?b) ?c)" => "(< ?b (logk ?a ?c))" if is_geq_2("?a")), - rewrite!("lt-pow-2"; "(< ?c (pow ?a ?b))" => "(< (logk ?a ?c) ?b)" if is_geq_2("?a")), - rewrite!("lt-add-1"; "(< (+ ?a ?b) ?c)" => "(< ?a (- ?c ?b))"), - rewrite!("lt-add-2"; "(< ?c (+ ?a ?b))" => "(< (- ?c ?b) ?a)"), - rewrite!("lt-sub-1"; "(< (- ?a ?b) ?c)" => "(< ?a (+ ?b ?c))"), - rewrite!("lt-sub-2"; "(< ?c (- ?a ?b))" => "(< (+ ?b ?c) ?a)"), - rewrite!("lt-div-pow"; "(< ?a (/ ?b (pow ?c ?d)))" => "(< (* ?a (pow ?c ?d)) ?b)"), - rewrite!("lt-mul-pow"; "(< ?a (* ?b (pow ?c ?d)))" => "(< (/ ?a (pow ?c ?d)) ?b)"), - rewrite!("lt-min"; "(< ?a (min ?b ?c))" => "(&& (< ?a ?b) (< ?a ?c))"), - rewrite!("eq-pow"; "(== (pow ?a ?b) ?c)" => "(== ?b (logk ?a ?c))" if is_geq_2("?a")), - rewrite!("eq-pow-0"; "(== (pow ?a ?b) 0)" => "false" if is_geq_2("?a")), - rewrite!("eq-sub"; "(== (- ?a ?b) ?c)" => "(== ?a (+ ?b ?c))"), - rewrite!("eq-to-sub"; "(== ?a ?b)" => "(== (- ?a ?b) 0)"), - rewrite!("eq-div"; "(== ?a (/ ?b ?c))" => "(== (* ?a ?c) ?b)"), + rewrite!("1-lt-prod"; "(< 1 (* ?a ?b))" => "(|| (< 1 ?a) (< 1 ?b))"), + rewrite!("lt-pow-1"; "(< (pow ?a ?b) ?c)" => "(< ?b (logk ?a ?c))" if _is_const_with_pred("?a", |x| x >= 2.0)), + rewrite!("lt-pow-2"; "(< ?c (pow ?a ?b))" => "(< (logk ?a ?c) ?b)" if _is_const_with_pred("?a", |x| x >= 2.0)), + rewrite!("eq-to-lt"; "(== 1 ?a)" => "(! (< 1 ?a))" if is_symbol("?a")), + rewrite!("1-eq-prod"; "(== 1 (* ?a ?b))" => "(&& (== 1 ?a) (== 1 ?b))"), + rewrite!("neg-1-never-0"; "(== 0 (* -1 ?a))" => "false"), + rewrite!("hump-0"; "(hump 0)" => "1"), + rewrite!("hump-const"; "(hump ?a)" => "0" if _is_const_with_pred("?a", |x| x != 0.0)), ]); ret_rules }); +pub struct DiffApproxSimplCostFn; +impl egg::CostFunction<Math> for DiffApproxSimplCostFn { + type Cost = usize; + fn cost<C>(&mut self, enode: &Math, mut costs: C) -> Self::Cost + where + C: FnMut(Id) -> Self::Cost, + { + // Prefer the form (|| (< 1 ?a) (< 1 ?b)) over (< 1 (* a b)). + let op_cost = match enode { + // Thus all logical and lt / eq operators are cheap, + // and everything else is expensive. + Math::Const(..) => 1, + Math::Symbol(..) => 1, + Math::And(..) => 1, + Math::Or(..) => 1, + Math::Not(..) => 1, + Math::Lt(..) => 1, + // Prefer Lt over any other comparators + Math::Eq(..) => 5, + Math::Ne(..) => 5, + Math::Le(..) => 5, + Math::Gt(..) => 5, + Math::Ge(..) => 5, + _ => 10, + }; + enode.fold(op_cost, |sum, i| sum + costs(i)) + } +} + #[cfg(feature = "count")] pub static GLOBAL_RULE_COUNTER: Lazy<Mutex<HashMap<Symbol, usize>>> = Lazy::new(|| Mutex::new(HashMap::new())); @@ -447,9 +495,9 @@ fn make_runner( explain: bool, ) -> (Runner, &'static Vec<Rewrite>) { let rules = if diff_approx { - &*ALL_RULES + &*DIFF_APPROX_RULES } else { - &*STABLE_RULES + &*BASIC_RULES }; let mut runner = Runner::default() .with_iter_limit(n_iters) @@ -465,7 +513,11 @@ pub fn simplify(expr: &str, n_iters: usize, n_nodes: usize, diff_approx: bool) - let (mut runner, rules) = make_runner(n_iters, n_nodes, diff_approx, cfg!(feature = "count")); runner = runner.with_expr(&expr).run(rules); let root = runner.roots[0]; - let (_, best) = Extractor::new(&runner.egraph, AstSize).find_best(root); + let (_, best) = if diff_approx { + Extractor::new(&runner.egraph, DiffApproxSimplCostFn).find_best(root) + } else { + Extractor::new(&runner.egraph, AstSize).find_best(root) + }; if cfg!(feature = "count") { add_rules_to_counter(&mut runner.explain_equivalence(&expr, &best)); } diff --git a/src/arith/egg_simpl/src/lib.rs b/src/arith/egg_simpl/src/lib.rs index c6a6e46a4..e0d674071 100644 --- a/src/arith/egg_simpl/src/lib.rs +++ b/src/arith/egg_simpl/src/lib.rs @@ -18,14 +18,14 @@ pub extern "C" fn simplify_expr( s_raw: *const c_char, n_iters: u64, n_nodes: u64, - simpl_rel: bool, + diff_approx: bool, ) -> *mut c_char { let s = str_from_ptr(s_raw); let simplified = lang::simplify( &s, n_iters.try_into().unwrap(), n_nodes.try_into().unwrap(), - simpl_rel, + diff_approx, ); let c_str = CString::new(simplified).unwrap(); c_str.into_raw() @@ -49,7 +49,7 @@ pub extern "C" fn is_equivalent( explain: bool, n_iters: u64, n_nodes: u64, - simpl_rel: bool, + diff_approx: bool, ) -> bool { let lhs = str_from_ptr(lhs_raw); let rhs = str_from_ptr(rhs_raw); @@ -59,7 +59,7 @@ pub extern "C" fn is_equivalent( explain, n_iters.try_into().unwrap(), n_nodes.try_into().unwrap(), - simpl_rel, + diff_approx, ) } diff --git a/src/arith/var_context.cc b/src/arith/var_context.cc index 406f5e1be..0f42861ce 100644 --- a/src/arith/var_context.cc +++ b/src/arith/var_context.cc @@ -95,63 +95,55 @@ void VarDefStackNode::Append(const SizeVar& var, const PrimExpr& expr) { this->exprs.push_back(VarExprPair(var, expr)); } -SizeVar VarDefStackNode::FindOrAppend(const std::string& vname, const PrimExpr& expr) { - auto it = this->expr2idx.find(expr); - if (it != this->expr2idx.end()) { - return this->exprs[it->second]->var; - } - return this->Append(vname, expr); -} - -VarDefStackNode VarDefStackNode::Prepend(const VarMapT& vmap) const { - VarDefStackNode ret; - for (auto& [vname, expr] : vmap) { - ret.Append(vname, expr); - } - for (auto& pair : this->exprs) { - ret.Append(pair->var->name_hint, pair->expr); +PrimExpr VarDefStackNode::DefineConstShorthand(PrimExpr expr) { + if (ExprIsConstant(expr) && CountOps(expr) >= 10) { + std::string name = "v" + std::to_string(this->exprs.size()); + expr = Analyzer().canonical_simplify(arith::SimplifyExpr(expr)); + auto it = this->expr2idx.find(expr); + if (it == this->expr2idx.end()) { + expr = this->Append(name, expr); + } else { + expr = this->exprs[it->second]->var; + } } - return ret; + return expr; } -VarMapT VarDefStackNode::IntoVarMap() const { +VarMapT VarDefStackNode::IntoUnwindedVarMap() const { VarMapT vmap; for (const auto& pair : this->exprs) { - vmap.emplace(pair->var->name_hint, tir::SubstByName(pair->expr, vmap)); + vmap.emplace(pair->var->name_hint, SubstByName(pair->expr, vmap)); } return vmap; } -std::vector<Var> VarDefStackNode::FreeVars() const { - std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> vars; - auto CollectVars = [&vars](const ObjectRef& node) { - if (const VarNode* op = node.as<VarNode>()) { - vars.insert(GetRef<Var>(op)); - } - }; - for (const auto& pair : this->GetExprs()) { - tir::PostOrderVisit(pair->expr, CollectVars); +std::unordered_set<std::string> VarDefStackNode::GetAllUsedVars( + std::optional<SizeVarKind> kind) const { + std::unordered_set<std::string> ret; + for (auto& pair : this->exprs) { + PostOrderVisit(pair->expr, [this, &ret, kind](const ObjectRef& node) { + if (auto* svnode = node.as<SizeVarNode>()) { + if (!kind || svnode->kind == kind) { + ret.insert(svnode->name_hint); + } + } + }); } - for (const auto& pair : this->GetExprs()) { - vars.erase(pair->var); + return ret; +} + +void VarDefStackNode::MapExprs(ExprMutator func) { + for (size_t i = 0; i < this->exprs.size(); i++) { + auto& pair = this->exprs[i]; + this->exprs.Set(i, VarExprPair(pair->var, func(pair->expr))); } - return std::vector<Var>(vars.begin(), vars.end()); } -bool VarDefStackNode::HasUndefVars(const PrimExpr& expr) const { - // SizeVar is seen as constant and doesn't count unless it has kOther type. - bool has_undef = false; - auto CheckUndef = [&has_undef, this](const ObjectRef& obj) { - if (auto* vnode = obj.as<VarNode>()) { - auto* svnode = obj.as<SizeVarNode>(); - if (this->var2idx.count(vnode->name_hint) == 0 && - (!svnode || svnode->kind == SizeVarKind::kOther)) { - has_undef = true; - } - } - }; - tir::PostOrderVisit(expr, CheckUndef); - return has_undef; +void VarDefStackNode::MapExprsParallel(ExprMutator func) { + support::parallel_for(0, this->exprs.size(), [this, &func](int i) { + auto& pair = this->exprs[i]; + this->exprs.Set(i, VarExprPair(pair->var, func(pair->expr))); + }); } inline std::pair<PrimExpr, PrimExpr> ConservativeDiv(PrimExpr extent, PrimExpr factor, @@ -174,8 +166,8 @@ Array<SizeVar> VarContextNode::GetSplitVars(const PrimExpr& extent, size_t n_spl var_names.push_back(name); product *= var; } - // Declare a quotient variable which would be equal to Extent_i / (sp_i_0*sp_i_1*...sp_i_j), as qi. - // For example q3 could be Cout / (sp_3_0 * sp_3_1 * sp_3_2...) for group 3. + // Declare a quotient variable which would be equal to Extent_i / (sp_i_0*sp_i_1*...sp_i_j), as + // qi. For example q3 could be Cout / (sp_3_0 * sp_3_1 * sp_3_2...) for group 3. // * Don't even define this variable in this->var_defs. We'll need to delay the expansion of q{i} // as much as possible, and when finally it's needed it can be derived from the SplitGroup. SizeVar quotient("q" + group_idx, SizeVarKind::kShorthand); @@ -213,29 +205,13 @@ std::pair<PrimExpr, PrimExpr> VarContextNode::GetSplitSizes(const PrimExpr& exte return this->SymbolicDiv(extent, factor, no_tighten_factor); } -void VarContextNode::DefineVar(const std::string& name, PrimExpr expr) { - this->var_defs->Append(name, expr); -} - -PrimExpr VarContextNode::DefineConstShorthand(PrimExpr expr) { - if (!this->var_defs->HasUndefVars(expr) && CountOps(expr) >= 10) { - expr = this->AllocVarForExpr(arith::SimplifyExpr(expr)); - } - return expr; -} - -Var VarContextNode::AllocVarForExpr(PrimExpr expr) { - std::string name = "v" + std::to_string(this->var_defs->Size()); - return this->var_defs->FindOrAppend(name, expr); -} - std::pair<PrimExpr, PrimExpr> VarContextNode::SymbolicDiv(PrimExpr numer, PrimExpr denom, bool no_tighten_factor) { PrimExpr simpl = SimplifyExpr(numer / denom); if (!HasDiv(simpl)) { return {denom, simpl}; } - for (auto &[extent, subst]: this->div_extents) { + for (auto& [extent, subst] : this->div_extents) { if (IsExprEquivalent(numer, extent)) { PrimExpr simpl = SimplifyExpr(subst / denom); if (!HasDiv(simpl)) { diff --git a/src/felix/constraints.cc b/src/felix/constraints.cc new file mode 100644 index 000000000..6fd5f2503 --- /dev/null +++ b/src/felix/constraints.cc @@ -0,0 +1,207 @@ +#include <tvm/auto_scheduler/search_task.h> +#include <tvm/tir/stmt_functor.h> + +#include <numeric> + +#include "../tir/transforms/ir_utils.h" +#include "features.h" + +namespace tvm { +namespace felix { + +using namespace tvm::tir; + +class GPUConstraintsMaker : public StmtExprVisitor { + // TODO(jcf94): Add support of detecting CUDA Misaligned Address error + public: + explicit GPUConstraintsMaker(size_t max_local_memory_per_block, + size_t max_shared_memory_per_block, size_t max_threads_per_block, + size_t max_vthread, size_t max_vector_size, size_t max_vector_bytes) + : max_local_memory_per_block(max_local_memory_per_block), + max_shared_memory_per_block(max_shared_memory_per_block), + max_threads_per_block(max_threads_per_block), + max_vthread(max_vthread), + max_vector_size(max_vector_size), + max_vector_bytes(max_vector_bytes) {} + + void RunOnStmt(Stmt stmt) { + Reset_(); + this->VisitStmt(stmt); + } + + void VisitStmt_(const AllocateNode* op) final { + StmtVisitor::VisitStmt_(op); + // visit an allocation of a buffer in shared memory, record its size + auto scope = GetPtrStorageScope(op->buffer_var); + CountBufferSize_(op->extents, op->dtype, scope); + if (op->dtype.lanes() > 1) { + CheckVectorBytes_(op->dtype); + } + } + + void VisitStmt_(const BufferRealizeNode* op) final { + StmtVisitor::VisitStmt_(op); + auto scope = GetPtrStorageScope(op->buffer->data); + Array<PrimExpr> extents; + for (auto& range : op->bounds) { + extents.push_back(range->extent); + } + CountBufferSize_(extents, op->buffer->dtype, scope); + if (op->buffer->dtype.lanes() > 1) { + CheckVectorBytes_(op->buffer->dtype); + } + } + + void VisitStmt_(const AttrStmtNode* op) final { + auto attr = op->attr_key; + bool is_thread = attr == tir::attr::thread_extent; + bool is_vthread = attr == tir::attr::virtual_thread; + bool is_unroll = attr == tir::attr::pragma_auto_unroll_max_step; + if (!is_thread && !is_vthread && !is_unroll) { + StmtVisitor::VisitStmt_(op); + return; + } + + if (this->nest_level == 0) { + // enter a new kernel, reset statistics + Reset_(); + kernels_launched++; + } + + // record the number of threads in a block + Var var = op->node.as<IterVarNode>()->var; + std::string name = var.get()->name_hint; + PrimExpr extent = op->value; + if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z" || + name == "vthread") { + // record the number of threads in a block + if (!this->visited_threads.count(name)) { + this->visited_threads.insert(name); + this->thread_per_block *= extent; + } + // else: the thread should be bound to axes with the same length + // but we don't check this here, as it can be difficult to + // compare the equality of two expressions. + } + + this->nest_level++; + StmtVisitor::VisitStmt_(op); + this->nest_level--; + + if (this->nest_level == 0) { + // exit a kernel, check the validity + AddConstraint_(this->thread_per_block, this->max_threads_per_block); + AddConstraint_(this->local_memory_per_block, this->max_local_memory_per_block); + AddConstraint_(this->shared_memory_per_block, this->max_shared_memory_per_block); + } + } + + void VisitStmt_(const ForNode* op) { + if (op->kind == ForKind::kVectorized) { + AddConstraint_(op->extent, this->max_vector_size); + } + StmtVisitor::VisitStmt_(op); + } + + void VisitExpr_(const LoadNode* op) { + if (op->dtype.lanes() > 1) { + CheckVectorBytes_(op->dtype); + } + ExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const StoreNode* op) { + if (op->value->dtype.lanes() > 1) { + CheckVectorBytes_(op->value->dtype); + } + StmtVisitor::VisitStmt_(op); + } + + private: + size_t max_local_memory_per_block; + size_t max_shared_memory_per_block; + size_t max_threads_per_block; + size_t max_vthread; + size_t max_vector_size; + size_t max_vector_bytes; + + std::unordered_set<std::string> visited_threads{}; + PrimExpr local_memory_per_block = 0; + PrimExpr shared_memory_per_block = 0; + PrimExpr thread_per_block = 1; + size_t kernels_launched = 0; + size_t nest_level = 0; + + public: + std::vector<PrimExpr> constraints{}; + std::vector<String> errors{}; + + private: + void AddConstraint_(PrimExpr lhs, size_t rhs) { + auto con = lhs <= Integer(rhs); + if (auto* simp_bool = con.as<IntImmNode>()) { + if (simp_bool->value == 0) { + std::stringstream s; + s << "Constraint " << lhs << " <= " << rhs << " is trivially False"; + this->errors.push_back(s.str()); + } + } else { + this->constraints.push_back(con); + } + } + + void CountBufferSize_(const Array<PrimExpr>& extents, DataType dtype, const String& scope) { + PrimExpr one = Integer(1); + PrimExpr alloc_count = + std::accumulate(extents.begin(), extents.end(), one, + [](const PrimExpr& a, const PrimExpr& b) { return a * b; }); + PrimExpr alloc_size = alloc_count * dtype.bytes() * dtype.lanes(); + if (scope == "local") { + this->local_memory_per_block += alloc_size; + } else if (scope == "shared") { + this->shared_memory_per_block += alloc_size; + } + } + + void CheckVectorBytes_(DataType dtype) { + if (static_cast<size_t>(dtype.lanes() * dtype.bytes()) > this->max_vector_bytes) { + std::stringstream s; + s << "Number of lanes (" << dtype.lanes() << ") times number of bytes (" << dtype.bytes() + << ") for dtype " << dtype << " is greater than the maximum number of vector bytes (" + << this->max_vector_bytes << ")"; + this->errors.push_back(s.str()); + } + } + + void Reset_() { + this->local_memory_per_block = 0; + this->shared_memory_per_block = 0; + this->visited_threads.clear(); + } +}; + +std::vector<PrimExpr> GetConstraints(const Stmt& code, + const auto_scheduler::HardwareParams& hw_params) { + // Run GPU verification pass to inject constraints. + // HACK: size of 4 is based on src/target/source/codegen_cuda.cc. + // - line 246: bool vector is only supported when size is less than 4 + // - line 351: vector of int/uint is only supported when size is less than 8. + // While this is true for all CUDA devices and will stay so for a while, + // it's not a good idea to hardcode it here. + GPUConstraintsMaker cmaker(hw_params->max_local_memory_per_block, + hw_params->max_shared_memory_per_block, + hw_params->max_threads_per_block, hw_params->max_vthread_extent, + /*max_vector_size*/ 4, hw_params->vector_unit_bytes); + cmaker.RunOnStmt(code); + if (!cmaker.errors.empty()) { + LOG_WARNING << "Code constraint check failed: "; + for (auto& err : cmaker.errors) { + LOG_WARNING << " " << err; + } + LOG_WARNING << "Failed code: " << code; + return {}; + } + return cmaker.constraints; +} +} // namespace felix +} // namespace tvm \ No newline at end of file diff --git a/src/felix/feat_transform.cc b/src/felix/feat_transform.cc new file mode 100644 index 000000000..8a9f04380 --- /dev/null +++ b/src/felix/feat_transform.cc @@ -0,0 +1,1003 @@ +#include <tvm/arith/egg_simpl.h> +#include <tvm/arith/var_context.h> +#include <tvm/auto_scheduler/feature.h> +#include <tvm/support/parallel_for.h> + +#include <optional> +#include <variant> + +#include "features.h" +#include "utils.h" + +namespace tvm { +namespace felix { +using namespace tvm::arith; + +std::unordered_map<uint64_t, uint64_t> Factorize(uint64_t n) { + std::unordered_map<uint64_t, uint64_t> factors; + for (uint64_t i = 2; i <= n; ++i) { + while (n % i == 0) { + factors[i] += 1; + n /= i; + } + } + return factors; +} + +template <typename T> +void CollectSameOps(const T* e, std::vector<PrimExpr>& ret) { + auto *lhs = e->a.template as<T>(), *rhs = e->b.template as<T>(); + if (lhs) { + CollectSameOps(lhs, ret); + } else { + ret.push_back(e->a); + } + if (rhs) { + CollectSameOps(rhs, ret); + } else { + ret.push_back(e->b); + } +} + +template <typename T> +std::vector<PrimExpr> CollectSameOps(const PrimExpr& e) { + std::vector<PrimExpr> ret; + auto* node = e.as<T>(); + if (node) { + CollectSameOps(node, ret); + } else { + ret.push_back(e); + } + return ret; +} + +// size_t is not (de)serializable, so we use uint64_t instead +using DecompT = std::unordered_map<std::string, std::unordered_map<uint64_t, SizeVar>>; + +class LinearExprNode : public Object { + public: + Map<SizeVar, FloatImm> lin_terms; + FloatImm constant; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("lin_terms", &lin_terms); + v->Visit("constant", &constant); + } + + static constexpr const char* _type_key = "ansor.LinearExpr"; + TVM_DECLARE_FINAL_OBJECT_INFO(LinearExprNode, Object); +}; + +TVM_REGISTER_NODE_TYPE(LinearExprNode); + +class LinearExpr : public ObjectRef { + public: + explicit LinearExpr(double constant) { + auto node = make_object<LinearExprNode>(); + node->constant = ToFloatImm(constant); + this->data_ = std::move(node); + } + + explicit LinearExpr(SizeVar var) { + auto node = make_object<LinearExprNode>(); + node->constant = ToFloatImm(0.0f); + node->lin_terms.Set(var, ToFloatImm(1.0f)); + this->data_ = std::move(node); + } + + LinearExpr(double constant, const std::unordered_map<std::string, double>& name2coef) { + auto node = make_object<LinearExprNode>(); + node->constant = ToFloatImm(0.0f); + for (auto& [name, coef] : name2coef) { + node->lin_terms.Set(SizeVar(name, SizeVarKind::kScheduleKnob), ToFloatImm(coef)); + } + this->data_ = std::move(node); + } + + PrimExpr ToPrimExpr() const { + auto node = this->operator->(); + PrimExpr ret = node->constant; + for (auto& [var, coef] : node->lin_terms) { + ret = ret + var * coef; + } + return ret; + } + + LinearExpr& operator+=(const LinearExpr& other) { + ICHECK(this->defined() && other.defined()); + auto this_ = this->CopyOnWrite(); + this_->constant = ToFloatImm(this_->constant->value + other->constant->value); + for (auto& [var, coef1] : other->lin_terms) { + auto coef2 = this_->lin_terms.Get(var).value_or(ToFloatImm(0.0f)); + this_->lin_terms.Set(var, ToFloatImm(coef1->value + coef2->value)); + } + return *this; + } + LinearExpr& operator*=(double other) { + ICHECK(this->defined()); + auto this_ = this->CopyOnWrite(); + this_->constant = ToFloatImm(this_->constant->value * other); + for (auto& [var, coef] : this_->lin_terms) { + this_->lin_terms.Set(var, ToFloatImm(coef->value * other)); + } + return *this; + } + LinearExpr& operator-=(const LinearExpr& other) { return *this += LinearExpr(-1.0f) * other; } + LinearExpr& operator/=(double other) { return *this *= (1.0f / other); } + +#define DEF_BINARY_OP(def, with, other_t, check) \ + LinearExpr operator def(const other_t& other) const { \ + if (!this->defined() || check) { \ + return LinearExpr(); \ + } \ + LinearExpr ret = *this; \ + ret with other; \ + return ret; \ + } + + DEF_BINARY_OP(+, +=, LinearExpr, !other.defined()) + DEF_BINARY_OP(-, -=, LinearExpr, !other.defined()) + DEF_BINARY_OP(*, *=, double, false) + DEF_BINARY_OP(/, /=, double, false) + + LinearExpr operator*(LinearExpr other) { + if (!this->defined() || !other.defined()) { + return LinearExpr(); + } else if ((*this)->lin_terms.empty()) { + return other * (*this)->constant->value; + } else if (other->lin_terms.empty()) { + return (*this) * other->constant->value; + } else { + return LinearExpr(); + } + } + LinearExpr operator/(LinearExpr other) { + if (!this->defined() || !other.defined()) { + return LinearExpr(); + } else if (other->lin_terms.empty()) { + return (*this) / other->constant->value; + } else { + return LinearExpr(); + } + } + + TVM_DEFINE_OBJECT_REF_METHODS(LinearExpr, ObjectRef, LinearExprNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(LinearExprNode); +}; + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch<LinearExprNode>([](const ObjectRef& node, ReprPrinter* p) { + auto* expr = static_cast<const LinearExprNode*>(node.get()); + auto& os = p->stream; + os << "(" << expr->constant << "; "; + bool first = true; + for (auto& [var, coef] : expr->lin_terms) { + if (!first) { + os << ", "; + } + os << "(" << coef << " * " << var << ")"; + first = false; + } + os << ")"; + }); + +class LinExprExtractor : public ExprFunctor<LinearExpr(const PrimExpr&)> { + LinearExpr VisitExpr_(const SizeVarNode* e) override { return LinearExpr(GetRef<SizeVar>(e)); } + LinearExpr VisitExpr_(const IntImmNode* e) override { return LinearExpr((double)e->value); } + LinearExpr VisitExpr_(const FloatImmNode* e) override { return LinearExpr(e->value); } + LinearExpr VisitExpr_(const VarNode* e) override { + LOG_FATAL << "Do not use LinExprExtractor on expressions with non-sizevar variable; got " + << GetRef<Var>(e); + return LinearExpr(); + } + + LinearExpr VisitExpr_(const AddNode* e) override { return VisitExpr(e->a) + VisitExpr(e->b); } + LinearExpr VisitExpr_(const SubNode* e) override { return VisitExpr(e->a) - VisitExpr(e->b); } + LinearExpr VisitExpr_(const MulNode* e) override { return VisitExpr(e->a) * VisitExpr(e->b); } + LinearExpr VisitExpr_(const DivNode* e) override { return VisitExpr(e->a) / VisitExpr(e->b); } + LinearExpr VisitExpr_(const CastNode* e) override { return VisitExpr(e->value); } + + LinearExpr VisitExprDefault_(const Object* e) override { return LinearExpr(); } +}; + +class FloorRemover : public ExprMutator { + PrimExpr VisitExpr_(const FloorDivNode* e) final { + // Simply drop the floor() operator; it's not differentiable. + return div(VisitExpr(e->a), VisitExpr(e->b)); + } + PrimExpr VisitExpr_(const FloorModNode* e) final { + LOG_WARNING << "Mod operator is not differentiable and will be dropped."; + return Integer(0); + } +}; + +class DiffableApprox : public MemoizedExprFunctor<PrimExpr> { + public: + DiffableApprox(const DiffableApprox& other) = delete; + DiffableApprox() = delete; + + // * Important to use Float values (Range(1.0, 5.0)) for range inf default range. + DiffableApprox(const std::unordered_set<std::string>& new_knobs, const VarMapT& exp_subst, + const VarMapT& shorthands, const VarMapT& quotients) + : rinf(Range(ToFloatImm(1.0), ToFloatImm(5.0))), + new_knobs(new_knobs), + exp_subst(exp_subst), + shorthands(shorthands), + quotients(quotients) {} + + // Simplification strategies: + // 1. `K < sp_i_j_b + sp_i_j_b + ...` (K is actual constant) + // - We're good; just return sigmoid(RHS - LHS). + // 2. `1 < sp_i_j` + // - Replace sp_i_j with its exp decomposition, simplify, should give us Form 1. + // 3. Anything expr with shorthand variables + // - Substitute with shorthand vars and simplify; if changed, revisit. + // 4. Now safe to use special simplification (safe on expressions that only contain `sp_i_j` and + // quotient vars `d{i}`, and shape constants) + // - This helps reduce forms such as `1 < sp_i_j * sp_i'_j' * ...` to smaller forms. + // - If expr contains quotients d{i}, try twice: expanding d{i} = E{i} / sp_i_j / sp_i_j' ... , + // see which one is shorter. + // 5. If everything fall through, try taking log on both sides (with range inference to help with + // safety). + // + // Step 3-5 also applies to VisitExpr_(EQ). + + PrimExpr VisitExpr_(const LTNode* e) final { + auto expr = GetRef<PrimExpr>(e); + LinearExpr diff = IsTermLinear(e->b - e->a); + if (diff.defined()) { + return MakeDiffableCond(diff.ToPrimExpr(), /* sigmoid_or_hump */ true); + } + PrimExpr response; + if ((response = ConstVsSingleSchedVar(e->a, e->b)).defined()) { + response = SimplifyExpr(LT(e->a, response), 20, 1000, true); + return VisitExpr(response); + } + if ((response = SubstShorthandsAway(expr)).defined()) { + return VisitExpr(response); + } + if ((response = SpecialSimplAndVisit(expr)).defined()) { + return response; + } + if ((response = TakeLogDiffIfSafe(e->b, e->a)).defined()) { + return MakeDiffableCond(response, /* sigmoid_or_hump */ true); + } + LOG_FATAL << "Cannot simplify " << e->a << " < " << e->b; + return PrimExpr(); + } + + PrimExpr VisitExpr_(const EQNode* e) final { + auto expr = GetRef<PrimExpr>(e); + PrimExpr response; + if ((response = SubstShorthandsAway(expr)).defined()) { // Step 1 + return VisitExpr(response); + } + if ((response = SpecialSimplAndVisit(expr)).defined()) { + return response; + } + if ((response = TakeLogDiffIfSafe(e->a, e->b)).defined()) { + return MakeDiffableCond(response, /* sigmoid_or_hump */ false); + } + LOG_ERROR << "Cannot simplify " << e->a << " == " << e->b; + return PrimExpr(); + } + + PrimExpr VisitExpr_(const IntImmNode* e) final { + // Converts true into 1 and false into 0. + return Integer(e->value); + } + + PrimExpr VisitExpr_(const SelectNode* e) final { + if (ExprIsShapeConstant(e->condition)) { + return select(SimplifyExpr(e->condition, 20, 1000), VisitExpr(e->true_value), + VisitExpr(e->false_value)); + } + auto cond = VisitExpr(e->condition), tv = VisitExpr(e->true_value), + fv = VisitExpr(e->false_value); + return cond * (tv - fv) + fv; + } + + // 1 - x below stands for `not x` + PrimExpr VisitExpr_(const LENode* e) final { return 1 - this->VisitExpr(LT(e->b, e->a)); } + PrimExpr VisitExpr_(const GENode* e) final { return 1 - this->VisitExpr(LT(e->a, e->b)); } + PrimExpr VisitExpr_(const GTNode* e) final { return this->VisitExpr(LT(e->b, e->a)); } + PrimExpr VisitExpr_(const NENode* e) final { return 1 - this->VisitExpr(EQ(e->a, e->b)); } + + PrimExpr VisitExpr_(const AndNode* e) final { return VisitExpr(e->a) * VisitExpr(e->b); } + PrimExpr VisitExpr_(const OrNode* e) final { + std::vector<PrimExpr> chained_ors; + CollectSameOps(e, chained_ors); + // Using de Morgan's law: + PrimExpr ret = Integer(1); + for (auto& expr : chained_ors) { + ret *= 1 - VisitExpr(expr); + } + return 1 - ret; + } + PrimExpr VisitExpr_(const NotNode* e) final { return 1 - VisitExpr(e->a); } + + LinearExpr IsTermLinear(const PrimExpr& diff) { + auto lin_diff = LinExprExtractor()(diff); + if (!lin_diff.defined()) { + return LinearExpr(); + } + bool all_vars_allowed = true; + for (auto& [var, _] : lin_diff->lin_terms) { + if (!this->new_knobs.count(var->name_hint)) { + all_vars_allowed = false; + } + } + return all_vars_allowed ? lin_diff : LinearExpr(); + } + + private: + PrimExpr ConstVsSingleSchedVar(const PrimExpr& lhs, const PrimExpr& rhs) { + auto* rhs_v = rhs.as<SizeVarNode>(); + if (lhs->IsInstance<IntImmNode>() && rhs_v) { + auto it = this->exp_subst.find(rhs_v->name_hint); + if (it != this->exp_subst.end()) { + return it->second; + } + } + return PrimExpr(); + } + + PrimExpr SubstShorthandsAway(const PrimExpr& expr) { + t4.Start(); + bool changed = false; + auto ret = SubstByName(expr, this->shorthands, &changed); + ret = changed ? SimplifyExpr(ret) : PrimExpr(); + t4.Stop(); + return ret; + } + + PrimExpr SpecialSimplAndVisit(const PrimExpr& expr) { + if (!this->progressed) { + t5.Start(); + auto direct_simpl = SpecialSimplIfSafe(expr); + t5.Stop(); + if (direct_simpl.defined()) { + this->progressed = true; + auto ret = VisitExpr(direct_simpl); + this->progressed = false; + return ret; + } + } + return PrimExpr(); + } + + PrimExpr SpecialSimplIfSafe(const PrimExpr& expr) { + bool safe_vars_only = true, has_quotients = false; + PostOrderVisit(expr, [this, &safe_vars_only, &has_quotients](const ObjectRef& obj) { + if (auto* var = obj.as<VarNode>()) { + if (!obj->IsInstance<SizeVarNode>() || this->shorthands.count(var->name_hint)) { + safe_vars_only = false; + } else if (this->quotients.count(var->name_hint)) { + has_quotients = true; + } + } + }); + if (!safe_vars_only) { + return PrimExpr(); + } + auto simpl = SimplifyExpr(expr, 30, 10000, true); + if (has_quotients) { + auto subbed = SubAndSimplify(expr, this->quotients); + return CountOps(subbed) < CountOps(expr) ? subbed : simpl; + } else { + return simpl; + } + } + + PrimExpr TakeLogDiffIfSafe(PrimExpr a, PrimExpr b) { + double amin, amax, bmin, bmax; + ICHECK(ToConstRange(this->rinf(a), amin, amax)); + ICHECK(ToConstRange(this->rinf(b), bmin, bmax)); + if (amin <= 0 || bmin <= 0) { + LOG_WARNING << "Cannot take diff of " << PrintExprPrefix(a) << " (" << amin << ") and " + << PrintExprPrefix(b) << " (" << bmin << ") to a range stable form"; + } + auto ret = SimplifyExpr(log(a) - log(b), 20, 5000, true); + return ret; + } + + PrimExpr MakeDiffableCond(const PrimExpr& e, bool sigmoid_or_hump) { + auto* e_int = e.as<IntImmNode>(); + if (e_int) { + if (sigmoid_or_hump) { + return e_int->value > 0 ? 1 : 0; + } else { + return e_int->value == 0 ? 1 : 0; + } + } + String var_name = "da_" + std::to_string(this->vdefs_out.size()); + this->vdefs_out[var_name] = sigmoid_or_hump ? sigmoid(e) : hump(e); + return SizeVar(var_name, SizeVarKind::kShorthand); + } + + public: + VarMapT vdefs_out; + + private: + Timer t4{"SubstShorthandsAway"}, t5{"SpecialSimplAndVisit"}; + RangeInfer rinf; + std::unordered_set<std::string> new_knobs; + const VarMapT &exp_subst, &shorthands, "ients; + bool progressed{}; +}; + +class FeaturePackPyNode : public Object { + public: + Array<Array<ObjectRef>> expressions; + Array<String> free_vars; + Array<LinearExpr> linear_cons; + Map<String, Map<Integer, SizeVar>> var_decomp; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("expressions", &expressions); + v->Visit("free_vars", &free_vars); + v->Visit("linear_cons", &linear_cons); + v->Visit("var_decomp", &var_decomp); + } + + static constexpr const char* _type_key = "ansor.FeaturePackPy"; + TVM_DECLARE_FINAL_OBJECT_INFO(FeaturePackPyNode, Object); +}; + +TVM_REGISTER_NODE_TYPE(FeaturePackPyNode); + +class FeaturePackPy : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(FeaturePackPy, ObjectRef, FeaturePackPyNode); +}; + +class FeaturePack { + public: + FeaturePack() = default; + + FeaturePack(VarDefStack vdefs, VarDefStack features, VarDefStack constraints, + Array<SplitGroup> sp_groups) + : variables({{"vdefs", std::move(vdefs)}, + {"features", std::move(features)}, + {"constraints", std::move(constraints)}}), + split_groups(std::move(sp_groups)) {} + + void RunRollingSimplify() { + // For every expression `e` in vdefs, + // for each variable `v` in `e` that's defined in vdefs as `v = e_v`, + // if substituting `v = e_v` into `e` (and simplifying) makes `e` smaller, do it. + auto& vdefs = this->variables.at("vdefs"); + vdefs->MapExprs([&vdefs](const PrimExpr& expr) { + std::unordered_set<const VarNode*> vars; + PostOrderVisit(expr, [&vdefs, &vars](const ObjectRef& node) { + auto* var = node.as<SizeVarNode>(); + if (var && vdefs->Contains(var->name_hint)) { + vars.insert(node.as<VarNode>()); + } + }); + auto expr_ = expr; + for (auto& var : vars) { + auto& to_sub = vdefs->GetExprAt(var->name_hint); + auto subbed = arith::SubAndSimplify(expr, {{var->name_hint, to_sub}}); + if (CountOps(subbed) < CountOps(expr)) { + expr_ = subbed; + } + } + return expr_; + }); + } + + void RunExpTransform(const std::vector<size_t>& prime_bases) { + // Concat `features` to the end of `vdefs` to form `vi_and_features`, + // and split out E{i} variables from vdefs into `ei_vars`, the rest into `vi_vars`. + VarDefStack& vi_and_features = this->variables["vi_and_features"]; + { + VarDefStack &vi_vars = this->variables["vi_vars"], &ei_vars = this->variables["ei_vars"]; + for (auto& pair : this->variables.at("vdefs")->GetExprs()) { + if (pair->var->kind == SizeVarKind::kShapeVar) { + ei_vars->Append(pair->var, pair->expr); + } else { + vi_vars->Append(pair->var, pair->expr); + vi_and_features->Append(pair->var, pair->expr); + } + } + for (auto& pair : this->variables.at("features")->GetExprs()) { + vi_and_features->Append(pair->var, pair->expr); + } + this->variables.erase("vdefs"); + this->variables.erase("features"); + } + // List all schedule vars (such as sp_i_j), and create exp decomposition for them, + // inserting into this->var_decomp and `exp_decomp`. + auto& exp_decomp = this->variables["exp_decomp"]; + auto all_var_names = vi_and_features->GetAllUsedVars(SizeVarKind::kScheduleKnob); + // 1. For variables in `SplitGroup`s, we create a new variable for each prime base p: + for (auto& group : this->split_groups) { + auto extent_vname = group->extent->name_hint; + // Create sp_i_j_2, sp_i_j_3, ... for each sp_i_j and prime base p. + for (auto& vname : group->vars) { + exp_decomp->Append(vname, + this->DecomposeOneVar(prime_bases, vname, SizeVarKind::kScheduleKnob)); + all_var_names.erase(vname); + } + // Create Ei_2, Ei_3, ... for the group's extent Ei and each prime base p. + exp_decomp->Append(extent_vname, + this->DecomposeOneVar(prime_bases, extent_vname, SizeVarKind::kShapeVar)); + // Create representation of the quotient Ei / (sp_i_0 * sp_i_1 * ...) + // as a sum in the power: 2**(Ei_2 - sp_i_0_2 - sp_i_1_2 - ...) * 3**(...) * ... + // Also insert the constraints that Ei_b >= sp_i_0_b + sp_i_1_b + ... + // (essentially meaning that each power consisting q{i} is non-negative) + auto quotient = Integer(1); + for (size_t prime : prime_bases) { + LinearExpr q_total_power(this->var_decomp[extent_vname][prime]); + for (auto& vname : group->vars) { + q_total_power -= LinearExpr(this->var_decomp[vname][prime]); + } + this->linear_cons.push_back(q_total_power); // Defaults to >= 0. + quotient *= pow(Integer(prime), q_total_power.ToPrimExpr()); + } + exp_decomp->Append(group->quotient, quotient); + } + // 2. For all other variables, just do a log2. + for (auto& vname : all_var_names) { + SizeVar sv(vname + "_2", SizeVarKind::kScheduleKnob); + this->var_decomp[vname][2] = sv; + exp_decomp->Append(vname, pow(Integer(2), sv)); + } + } + + void RunDiffTransform(size_t n_threads) { + // Remove floordiv and floormod from all variables except Ei variables. + for (auto& [table_name, vdefs] : this->variables) { + if (table_name != "ei_vars") { + vdefs->MapExprs(FloorRemover()); + } + } + // Prepare 2 substitution maps for the differentiability transformation: + // 1. exp_decomp: sp_i_j = 2**sp_i_j_2 * 3**sp_i_j_3 * ..., Ei = 2**Ei_2 * 3**Ei_3 * ..., qi = + // 2**qi_2 * 3**qi_3 * ... + // - Also get a list of all variables on the power (sp_i_j_2, Ei_2, qi_2, ...) from + // exp_decomp. + // 2. shorthands: vi = ... (all variables in vi_vars) + // 3. quotient: all the qi variables in non-exponential form: + // q0 = E0 / (sp_0_0 * sp_0_1 * ...), q1 = E1 / (sp_1_0 * sp_1_1 * ...), ... + // - This can help simplify expressions like (q0 * sp_0_0 * sp_0_1), without having to expand + // everything into exp form. + auto& exp_decomp = this->variables.at("exp_decomp"); + auto exp_sched_vars = exp_decomp->GetAllUsedVars(std::nullopt); + VarMapT exp_subst = exp_decomp->IntoUnwindedVarMap(), + shorthands = this->variables.at("vi_vars")->IntoUnwindedVarMap(); + this->variables.erase("vi_vars"); + VarMapT quotient; + for (auto& group : this->split_groups) { + PrimExpr extent = group->extent; + for (auto& vname : group->vars) { + extent = div(extent, SizeVar(vname, SizeVarKind::kScheduleKnob)); + } + quotient[group->quotient->name_hint] = extent; + } + auto& features = this->variables.at("vi_and_features"); + DiffableApprox da(exp_sched_vars, exp_subst, shorthands, quotient); + { + Timer timer("DA"); + features->MapExprs([&da](const PrimExpr& e) { return da(e); }); + } + auto& diff_approx = this->variables["diff_approx"]; + for (auto& [vname, expr] : da.vdefs_out) { + diff_approx->Append(vname, expr); + } + VarDefStack constraints; + for (auto& pair : this->variables.at("constraints")->GetExprs()) { + auto* expr_lt = pair->expr.as<LENode>(); + ICHECK(expr_lt); + auto diff = log(expr_lt->b) - log(expr_lt->a); + auto linexpr = + LinExprExtractor()(SimplifyExpr(SubstByName(SubstByName(diff, shorthands), exp_subst))); + if (linexpr.defined()) { + this->linear_cons.push_back(linexpr); + } else { + constraints->Append(pair->var, diff); + } + } + this->variables["constraints"] = constraints; + } + + void RunSizeSubstitution(const Map<String, Integer>& size_subst) { + VarMapT size_subst_(size_subst.begin(), size_subst.end()); + std::unordered_map<std::string, std::vector<uint64_t>> shape_value_primes; + for (auto& pair : this->variables.at("ei_vars")->GetExprs()) { + auto& vname = pair->var->name_hint; + auto expr = SubAndSimplify(pair->expr, size_subst_); + auto* expr_int = expr.as<IntImmNode>(); + ICHECK(expr_int) << "All shape variables must be substituted to integers; got " << pair->var + << " = " << expr; + size_subst_[vname] = expr; + std::unordered_map<uint64_t, uint64_t> factors = Factorize(expr_int->value); + auto& shape_var_decomp = this->var_decomp.at(vname); + for (auto& [prime, power] : factors) { + shape_value_primes[vname].push_back(prime); + auto it = shape_var_decomp.find(prime); + ICHECK(it != shape_var_decomp.end()) + << "Shape variable " << vname << " = " << expr_int->value << " contains factor " + << prime << " that the features weren't factorized for."; + size_subst_[it->second->name_hint] = Integer(power); + } + for (auto& [k, v] : shape_var_decomp) { + if (!size_subst_.count(v->name_hint)) { + size_subst_[v->name_hint] = Integer(0); + } + } + } + VarDefStack concated; + for (auto& [table_name, vdefs] : this->variables) { + if (table_name != "ei_vars") { + for (auto& pair : vdefs->GetExprs()) { + concated->Append(pair->var, pair->expr); + } + } + } + concated->MapExprsParallel([this, &size_subst_](const PrimExpr& expr) { + return SubAndSimplify(expr, size_subst_, true); + }); + for (auto& linexpr : this->linear_cons) { + linexpr = LinExprExtractor()(SubAndSimplify(linexpr.ToPrimExpr(), size_subst_, true)); + std::cout << linexpr << "\n"; + ICHECK(linexpr.defined()); + } + + this->variables.clear(); + this->variables["vdefs"] = concated; + } + + FeaturePackPy IntoPythonFeaturePack() const { + auto node = make_object<FeaturePackPyNode>(); + auto& vdefs = this->variables.at("vdefs"); + for (auto& pair : vdefs->GetExprs()) { + node->expressions.push_back({pair->var->name_hint, pair->expr}); + } + node->linear_cons = this->linear_cons; + std::unordered_set<std::string> free_vars; + for (auto& [k1, vs] : this->var_decomp) { + Map<Integer, SizeVar> m; + for (auto& [k2, v] : vs) { + if (v->kind == SizeVarKind::kScheduleKnob) { + free_vars.insert(v->name_hint); + } + m.Set(k2, v); + } + node->var_decomp.Set(k1, m); + } + for (auto& vname : free_vars) { + node->free_vars.push_back(vname); + } + return FeaturePackPy(node); + } + + static std::optional<FeaturePack> LoadFromJsonReader(dmlc::JSONReader& reader) { + bool is_defined; + reader.BeginArray(); + ICHECK(reader.NextArrayItem()); + reader.Read(&is_defined); + if (!is_defined) { + return std::nullopt; + } + FeaturePack fp; + ICHECK(reader.NextArrayItem()); + reader.Read(&fp.variables); + ICHECK(reader.NextArrayItem()); + reader.Read(&fp.linear_cons); + ICHECK(reader.NextArrayItem()); + reader.Read(&fp.split_groups); + ICHECK(reader.NextArrayItem()); + reader.Read(&fp.var_decomp); + ICHECK(!reader.NextArrayItem()); + return fp; + } + + static void SaveAsJson(const String& filepath, const std::optional<FeaturePack>& fp) { + std::ofstream fout(filepath); + dmlc::JSONWriter writer(&fout); + writer.BeginArray(true); + writer.WriteArrayItem((bool)fp); + if (fp) { + writer.WriteArrayItem(fp->variables); + writer.WriteArrayItem(fp->linear_cons); + writer.WriteArrayItem(fp->split_groups); + writer.WriteArrayItem(fp->var_decomp); + } + writer.EndArray(); + } + + PrimExpr DecomposeOneVar(const std::vector<size_t>& prime_bases, const std::string& vname, + SizeVarKind kind) { + PrimExpr subst_expr = Integer(1); + for (size_t prime : prime_bases) { + SizeVar var(vname + "_" + std::to_string(prime), kind); + this->var_decomp[vname][prime] = var; + subst_expr *= pow(Integer(prime), var); + } + return subst_expr; + } + + // void ProcessConstraint() { + // auto* expr_lt = simpl.as<LTNode>(); + // ICHECK(expr_lt); + // LinearExpr linexpr; + // for (auto &piece: CollectSameOps<MulNode>(expr_lt->a)) { + // auto *evar = piece.as<SizeVarNode>(); + // if (!evar) { + // break; + // } + // auto it = this->var_decomp.find(evar->name_hint); + // if (it == this->var_decomp.end()) { + // break; + // } + // it->second + // } + // } + + std::unordered_map<std::string, VarDefStack> variables; + std::vector<LinearExpr> linear_cons{}; + Array<SplitGroup> split_groups; + DecompT var_decomp{}; +}; + +class StmtSimplifier : public StmtExprMutator { + public: + StmtSimplifier(RangeInfer& rinf, const VarDefStack& vdefs) : rinf(rinf) { + for (auto& pair : vdefs->GetExprs()) { + if (pair->var->kind == SizeVarKind::kShapeVar) { + this->e_vars.emplace_back(pair->var, pair->expr); + } + } + } + + Stmt VisitStmt_(const ForNode* node) final { + PrimExpr extent = SimplifyExpr(this->rinf.GetMax(node->extent)); + this->rinf.Bind(node->loop_var, Range::FromMinExtent(0, extent), true); + Stmt body = StmtExprMutator::VisitStmt(node->body); + return For(node->loop_var, 0, extent, node->kind, body); + } + + Stmt VisitStmt_(const AttrStmtNode* node) final { + if (node->attr_key == tir::attr::thread_extent || node->attr_key == tir::attr::virtual_thread) { + PrimExpr extent = SimplifyExpr(this->rinf.GetMax(node->value)); + const Var& var = node->node.as<IterVarNode>()->var; + this->rinf.Bind(var, Range::FromMinExtent(0, extent), true); + Stmt body = StmtExprMutator::VisitStmt(node->body); + return AttrStmt(node->node, node->attr_key, extent, body); + } else { + return StmtExprMutator::VisitStmt_(node); + } + } + + // Remove tir.likely() for if-then-else + // which doesn't do anything for feature extraction + // and is hard to read. + Stmt VisitStmt_(const IfThenElseNode* node) final { + auto* call = node->condition.as<CallNode>(); + static auto op_likely = Op::Get("tir.likely"); + if (!call || !call->op.same_as(op_likely)) { + return StmtExprMutator::VisitStmt_(node); + } + return StmtExprMutator::VisitStmt(node->then_case); + } + + Stmt VisitStmt_(const AllocateNode* op) final { + Array<PrimExpr> extents; + for (const auto& x : op->extents) { + extents.push_back(SimplifyExpr(this->rinf.GetMax(x))); + } + return Allocate(op->buffer_var, op->dtype, extents, op->condition, + StmtExprMutator::VisitStmt(op->body), op->annotations); + } + + Stmt VisitStmt_(const BufferRealizeNode* node) final { + TryChangingBufferShape(node->buffer); + Array<Range> bounds; + for (auto& r : node->bounds) { + auto max = SimplifyExpr(this->rinf.GetMax(r->extent)); + bounds.push_back(Range::FromMinExtent(0, max)); + } + return BufferRealize(node->buffer, bounds, node->condition, + StmtExprMutator::VisitStmt(node->body)); + } + + Stmt VisitStmt_(const BufferStoreNode* node) final { + TryChangingBufferShape(node->buffer); + Array<PrimExpr> indices; + for (auto& index : node->indices) { + indices.push_back(SimplifyExpr(index)); + } + return BufferStore(node->buffer, StmtExprMutator::VisitExpr(node->value), indices); + } + + PrimExpr VisitExpr_(const BufferLoadNode* node) final { + TryChangingBufferShape(node->buffer); + Array<PrimExpr> indices; + for (auto& index : node->indices) { + indices.push_back(SimplifyExpr(index)); + } + return BufferLoad(node->buffer, indices); + } + + private: + void TryChangingBufferShape(const Buffer& buf) { + auto* buf_ = buf.get(); + if (this->touched_bufs.count(buf_)) { + return; + } + Array<PrimExpr>& buf_shape = *const_cast<Array<PrimExpr>*>(&buf_->shape); + for (size_t i = 0; i < buf_shape.size(); ++i) { + // HACK: pattern-match buffer shape stored in Buffer and BufferRealize + // against the expressions of the E{i} variables we defined in VarContext. + // If we were able to define these variables earlier, we wouldn't need to do this. + if (auto size_var = FindEquivalentMatch(buf_shape[i])) { + buf_shape.Set(i, size_var.value()); + } + } + this->touched_bufs.insert(buf_); + } + + std::optional<SizeVar> FindEquivalentMatch(const PrimExpr& expr) { + for (auto& [v, e] : this->e_vars) { + if (arith::IsExprEquivalent(e, expr)) { + return v; + } + } + return std::nullopt; + } + + PrimExpr SimplifyExpr(PrimExpr expr) { + auto it = this->_memo.find(expr); + if (it != this->_memo.end()) { + return it->second; + } + return this->_memo[expr] = arith::SimplifyExpr(expr); + } + + RangeInfer& rinf; + std::unordered_set<const BufferNode*> touched_bufs; + std::vector<std::pair<SizeVar, PrimExpr>> e_vars; + std::unordered_map<PrimExpr, PrimExpr, StructuralHash, StructuralEqual> _memo; +}; + +void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) { + PostOrderVisit(expr, [&vars](const ObjectRef& node) { + if (const VarNode* op = node.as<VarNode>()) { + vars->insert(op); + } + }); +} + +std::optional<FeaturePack> GetFeaturePack(Stmt stmt, VarContext context, + const auto_scheduler::HardwareParams& hw_params, + size_t cache_line_size, size_t max_n_bufs) { + auto st_vdefs = context->var_defs; + FeaturePack fp; + try { + RangeInfer rinf; + { + Timer timer("StmtSimplifier"); + stmt = StmtSimplifier(rinf, st_vdefs)(stmt); + } + // Feature and constraint extraction + Timer timer("FeatureExtraction"); + VarDefStack features = + GetPerStoreFeatureExpr(stmt, *(st_vdefs.get()), rinf, cache_line_size, max_n_bufs); + auto constraints_ = GetConstraints(stmt, hw_params); + VarDefStack constraints; + for (size_t i = 0; i < constraints_.size(); ++i) { + auto name = "con_" + std::to_string(i); + constraints->Append("con_" + std::to_string(i), constraints_[i]); + } + fp = FeaturePack(std::move(context->var_defs), std::move(features), std::move(constraints), + std::move(context->split_groups)); + } catch (const std::exception& e) { + LOG_WARNING << "Feature extraction failed: " << e.what(); + return std::nullopt; + } + // Simplify features that just came out from feature extraction + fp.RunRollingSimplify(); + fp.RunExpTransform({2, 3, 5, 7}); + fp.RunDiffTransform(1); + return fp; +} + +TVM_REGISTER_GLOBAL("auto_scheduler.GetFeaturePack") + .set_body_typed([](Stmt stmt, VarContext context, auto_scheduler::HardwareParams hw_params, + Map<String, Integer> sizes, size_t cache_line_size, size_t max_n_bufs, + bool factorize, String save_load_path) { + std::ifstream fin(save_load_path); + FeaturePack fp; + if (fin.is_open()) { + dmlc::JSONReader reader(&fin); + auto fp_opt = FeaturePack::LoadFromJsonReader(reader); + if (!fp_opt) { + LOG_WARNING << "File " << save_load_path << " notes previous feature extraction failure"; + return FeaturePackPy(); // None + } + fp = fp_opt.value(); + } else { + LOG_INFO << "Extracting features to save to " << save_load_path; + auto fp_opt = GetFeaturePack(stmt, context, hw_params, cache_line_size, max_n_bufs); + if (fp_opt) { + fp = fp_opt.value(); + FeaturePack::SaveAsJson(save_load_path, fp); + } else { + FeaturePack::SaveAsJson(save_load_path, std::nullopt); + return FeaturePackPy(); // None + } + } + fp.RunSizeSubstitution(sizes); + FeaturePack::SaveAsJson(save_load_path + ".json", fp); + return fp.IntoPythonFeaturePack(); + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.LinearExprAsPrimExpr").set_body_typed([](LinearExpr e) { + return e.ToPrimExpr(); +}); + +} // namespace felix +} // namespace tvm + +namespace dmlc { +namespace json { + +template <typename V> +struct Handler<std::unordered_map<uint64_t, V>> { + static void Write(JSONWriter* writer, const std::unordered_map<uint64_t, V>& map) { + writer->BeginArray(map.size() > 1); + for (auto& [k, v] : map) { + writer->WriteArraySeperator(); + writer->BeginArray(false); + writer->WriteArrayItem(k); + writer->WriteArrayItem(v); + writer->EndArray(); + } + writer->EndArray(); + } + + static void Read(JSONReader* reader, std::unordered_map<uint64_t, V>* map) { + map->clear(); + reader->BeginArray(); + while (reader->NextArrayItem()) { + uint64_t k; + V v; + reader->BeginArray(); + ICHECK(reader->NextArrayItem()); + reader->Read(&k); + ICHECK(reader->NextArrayItem()); + reader->Read(&v); + ICHECK(!reader->NextArrayItem()); + map->emplace(k, v); + } + } +}; + +template <> +struct Handler<::tvm::felix::LinearExpr> { + static void Write(JSONWriter* writer, const tvm::felix::LinearExpr& e) { + writer->BeginArray(); + writer->WriteArrayItem(e->constant->value); + std::unordered_map<std::string, double> name2coef; + for (auto& [var, coef] : e->lin_terms) { + name2coef[var->name_hint] = coef->value; + } + writer->WriteArrayItem(name2coef); + writer->EndArray(); + } + static void Read(JSONReader* reader, tvm::felix::LinearExpr* e) { + reader->BeginArray(); + ICHECK(reader->NextArrayItem()); + double constant; + reader->Read(&constant); + ICHECK(reader->NextArrayItem()); + std::unordered_map<std::string, double> name2coef; + reader->Read(&name2coef); + ICHECK(!reader->NextArrayItem()); + *e = tvm::felix::LinearExpr(constant, name2coef); + } +}; +} // namespace json +} // namespace dmlc diff --git a/src/felix/features.cc b/src/felix/features.cc new file mode 100644 index 000000000..fdcfa392d --- /dev/null +++ b/src/felix/features.cc @@ -0,0 +1,1060 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "features.h" + +#include <tvm/arith/analyzer.h> +#include <tvm/tir/analysis.h> + +#include <algorithm> +#include <optional> +#include <unordered_map> +#include <vector> + +#include "utils.h" + +namespace tvm { +namespace felix { + +using namespace tvm::tir; +using namespace tvm::arith; + +// The number of samples to extract for arithmetic intensity curves +// static const constexpr int ARITH_INTENSITY_CURVE_SAMPLE_N = 10; + +// Annotation position encoding +enum class AnnotationPosType : int { + kPosNone = 0, // Does not have this kind of annotation + kPosInnerSpatial = 1, // The annotated iterator is the innermost spatial iterator + kPosMiddleSpatial = 2, // The annotated iterator is a middle spatial iterator + kPosOuterSpatial = 3, // The annotated iterator is the outermost spatial iterator + kPosInnerReduce = 4, // The annotated iterator is the innermost reduce iterator + kPosMiddleReduce = 5, // The annotated iterator is a middle reduce iterator + kPosOuterReduce = 6, // The annotated iterator is the outermost reduce iterator + kPosMixed = 7 // The annotated iterator is a mixed space and reduce iterator +}; + +// Buffer access type +enum class BufferAccessType : int { kRead = 0, kWrite = 1, kReadWrite = 2, kUnknownRW = 3 }; + +// Accesses to a buffer +struct BufferAccess { + // data reuse type + BufferAccessType acc_type{BufferAccessType::kUnknownRW}; + // Use a two-dimensional array to store multiple multi-dimensional accesses. + // The innermost vector stores the multi-dimensional indices of one access. + std::vector<std::vector<PrimExpr>> indices; +}; + +// Feature for an access of a buffer +struct BufferAccessFeature { + std::string buffer_name; // The name of the buffer + BufferAccessType acc_type; // The type of the access + PrimExpr bytes; // The touched memory in bytes + PrimExpr unique_bytes; // The touched unique memory in bytes + PrimExpr lines; // The number of touched cache lines + PrimExpr unique_lines; // The number touched unique cache lines + // Types of data reuse + PrimExpr multi_read_cond, serial_multi_rw_cond, no_reuse_cond; + PrimExpr reuse_dis_iter; // The reuse distance in iterator number + PrimExpr reuse_dis_bytes; // The reuse distance in total touched bytes + PrimExpr reuse_ct; // The reuse ratio + PrimExpr bytes_d_reuse_ct; // bytes / reuse_ct + PrimExpr unique_bytes_d_reuse_ct; // unique_bytes / reuse_ct + PrimExpr lines_d_reuse_ct; // lines / reuse_ct + PrimExpr unique_lines_d_reuse_ct; // unique_lines / reuse_ct + PrimExpr stride; // The stride in access +}; + +// Feature set of a BufferStore statement +struct FeatureSet { + // Group 1: Computation related features + PrimExpr float_mad; // The number of float MAD (Multiply–add) ops + PrimExpr float_addsub; // The number of float add and sub ops + PrimExpr float_mul; // The number of float multiply ops + PrimExpr float_divmod; // The number of float div and mod ops + PrimExpr float_cmp; // The number of float comparison ops + PrimExpr float_math_func; // The number of float math func calls + PrimExpr float_other_func; // The number of other float func calls + PrimExpr int_mad; // The number of integer MAD (Multiply–add) ops + PrimExpr int_addsub; // The number of integer add and sub ops + PrimExpr int_mul; // The number of float multiply ops + PrimExpr int_divmod; // The number of float div and mod ops + PrimExpr int_cmp; // The number of float comparison ops + PrimExpr int_math_func; // The number of float math func calls + PrimExpr int_other_func; // The number of other float func calls + PrimExpr bool_op; // The number of bool ops + PrimExpr select_op; // The number of select ops + PrimExpr vec_num; // The number of vectorized iterators + PrimExpr vec_prod; // The product of the lengths of vectorized iterators + PrimExpr vec_len; // The length of the innermost vectorized iterator + AnnotationPosType vec_type; // The type of vectorization position + PrimExpr unroll_num; // The number of unrolled iterators + PrimExpr unroll_prod; // The product of the lengths of vectorized iterators + PrimExpr unroll_len; // The length of the innermost unrolled iterator + AnnotationPosType unroll_type; // The type of unroll position + PrimExpr parallel_num; // The number of paralleled iterators + PrimExpr parallel_prod; // The product of the lengths of paralleled iterators + PrimExpr parallel_len; // The length of the innermost paralleled iterators + AnnotationPosType parallel_type; // The type of parallel position + PrimExpr is_gpu; // Whether it is a GPU task + PrimExpr blockIdx_x_len; // The length of blockIdx.x + PrimExpr blockIdx_y_len; // The length of blockIdx.y + PrimExpr blockIdx_z_len; // The length of blockIdx.z + PrimExpr threadIdx_x_len; // The length of threadIdx.x + PrimExpr threadIdx_y_len; // The length of threadIdx.y + PrimExpr threadIdx_z_len; // The length of threadIdx.z + PrimExpr vthread_len; // The length of virtual thread + + // Group 2: Buffer access related features (per buffer) + std::vector<BufferAccessFeature> access_feas; + + // Group 3: Arithmetic intensity related features + // PrimExpr arith_intensity_curve[ARITH_INTENSITY_CURVE_SAMPLE_N]; // points sampled from the + // // arithmetic intensity curve + + // Group 4: Allocation related features + PrimExpr alloc_size; // The size of allocated buffer in bytes + PrimExpr alloc_outer_prod; // The product of lengths of loops outside the scope of the allocation + PrimExpr alloc_inner_prod; // The product of lengths of loops inside the score of the allocation + PrimExpr alloc_prod; // alloc_outer_prod * alloc_inner_prod + + // Group 5: Outer scope related features + PrimExpr outer_prod; // The product of lengths of outer loops + PrimExpr num_loops; // The number of outer loops + PrimExpr auto_unroll_max_step; // The value of pragma "auto_unroll_max_step" +}; + +namespace { + +// Return whether a var is in an expr +bool VarInExpr(const Var& var, const PrimExpr& expr) { + bool found = false; + + // Find by name, because TVM duplicates some loops such as threadIdx.x, + // creating 2 loop variables that have the same name but are different as objects. + PostOrderVisit(expr, [&found, &var](const ObjectRef& node) { + const VarNode* op = node.as<VarNode>(); + if (op && op->name_hint == var->name_hint) { + found = true; + } + }); + + return found; +} + +PrimExpr SelectNonZero(const PrimExpr& expr, PrimExpr non_zero) { + auto as_select = expr.as<SelectNode>(); + if (as_select) { + auto false_value = as_select->false_value.as<IntImmNode>(); + if (false_value && false_value->value == 0) { + return select(as_select->condition, CastToFloat(as_select->true_value), non_zero); + } + } + return select(expr == 0, non_zero, CastToFloat(expr)); +} + +PrimExpr SelectLogOr0(PrimExpr cond, PrimExpr value) { return select(cond, log(value), 0); } + +// Count math ops in an expr +class MathOpCounter : public ExprVisitor { + public: + MathOpCounter(RangeInfer& rinf) : rinf(rinf) {} + + PrimExpr FromExprMap(const ExprMap<size_t>& expr_map) const { + PrimExpr result = 0; + for (auto& [cond, count] : expr_map) { + result += select(cond, 0, Integer(count)); + } + return result; + } + + private: + PrimExpr ConstCond(Range lhs, Range rhs, Integer const_goal) { + bool lhs_const = this->ana.CanProveEqual(lhs->extent, 0), + rhs_const = this->ana.CanProveEqual(rhs->extent, 0); + if (lhs_const && rhs_const) { + return const_true(); + } else if (lhs_const && const_goal.defined()) { + return lhs->min == const_goal; + } else if (rhs_const && const_goal.defined()) { + return rhs->min == const_goal; + } else { + return const_false(); + } + } + +#define DefineVisitBinOp(Type, float_ct, int_ct, const_goal) \ + void VisitExpr_(const Type* op) override { \ + if (op->a.dtype().is_float()) { \ + float_ct++; \ + } else { \ + Range lhs = this->rinf(op->a), rhs = this->rinf(op->b); \ + PrimExpr const_cond = ConstCond(lhs, rhs, const_goal); \ + int_ct[const_cond]++; \ + } \ + ExprVisitor::VisitExpr_(op); \ + } + DefineVisitBinOp(AddNode, float_addsub, int_addsub, 0); + DefineVisitBinOp(SubNode, float_addsub, int_addsub, 0); + DefineVisitBinOp(MulNode, float_mul, int_mul, 1); + DefineVisitBinOp(DivNode, float_divmod, int_divmod, 1); + DefineVisitBinOp(FloorDivNode, float_divmod, int_divmod, 1); + DefineVisitBinOp(ModNode, float_divmod, int_divmod, 1); + DefineVisitBinOp(FloorModNode, float_divmod, int_divmod, 1); + DefineVisitBinOp(MaxNode, float_cmp, int_cmp, Integer()); + DefineVisitBinOp(MinNode, float_cmp, int_cmp, Integer()); + +#define BoolOp(Type) \ + void VisitExpr_(const Type* op) override { \ + bool_op++; \ + ExprVisitor::VisitExpr_(op); \ + } + BoolOp(AndNode); + BoolOp(OrNode); + BoolOp(NotNode); +#define NumToBoolCmpOp(Type) \ + void VisitExpr_(const Type* op) override { \ + if (op->a.dtype().is_float()) { \ + float_cmp++; \ + } else { \ + int_cmp[const_false()]++; \ + } \ + ExprVisitor::VisitExpr_(op); \ + } + NumToBoolCmpOp(EQNode); + NumToBoolCmpOp(NENode); + NumToBoolCmpOp(LTNode); + NumToBoolCmpOp(LENode); + NumToBoolCmpOp(GTNode); + NumToBoolCmpOp(GENode); + +#undef DefineVisitBinOp +#undef BoolOp +#undef NumToBoolCmpOp + + void VisitExpr_(const SelectNode* op) override { + select_op++; + ExprVisitor::VisitExpr_(op); + } + + // Returning empty range as we have no idea what the range could be. + void VisitExpr_(const CallNode* op) override { + auto* pop = op->op.as<OpNode>(); + ICHECK(pop != nullptr); + auto effect_kind = op_call_effect_[GetRef<Op>(pop)]; + bool is_pure = + effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation; + + if (is_pure) { + if (op->dtype.is_float()) { + float_math_func++; + } else { + int_math_func++; + } + } else { + if (op->dtype.is_float()) { + float_other_func++; + } else { + int_other_func++; + } + } + ExprVisitor::VisitExpr_(op); + } + + RangeInfer& rinf; + Analyzer ana; + OpAttrMap<TCallEffectKind> op_call_effect_ = Op::GetAttrMap<TCallEffectKind>("TCallEffectKind"); + + public: + size_t float_mad{0}; // The number of float MAD (Multiply–add) ops + size_t float_addsub{0}; // The number of float add and sub ops + size_t float_mul{0}; // The number of float multiply ops + size_t float_divmod{0}; // The number of float div and mod ops + size_t float_cmp{0}; // The number of float comparison ops + size_t float_math_func{0}; // The number of float math func calls + size_t float_other_func{0}; // The number of other float func calls + size_t int_mad{0}; // The number of integer MAD (Multiply–add) ops + ExprMap<size_t> int_addsub; // The number of integer add and sub ops + ExprMap<size_t> int_mul; // The number of integer multiply ops + ExprMap<size_t> int_divmod; // The number of integer div and mod ops + ExprMap<size_t> int_cmp; // The number of integer comparison ops + size_t int_math_func{0}; // The number of float math func calls + size_t int_other_func{0}; // The number of other float func calls + size_t bool_op{0}; // The number of bool ops + size_t select_op{0}; // The number of select ops +}; + +// Extract all buffer accesses in an expr +class BufferAccessExtractor : public StmtExprVisitor { + public: + void ExtractReads(const PrimExpr& expr) { this->VisitExpr(expr); } + + void InsertAccess(const Buffer& buf, BufferAccessType acc_type, const Array<PrimExpr>& indices) { + BufferAccess& acc = buf_accesses[buf]; + acc.acc_type = acc_type; + acc.indices.push_back(std::vector<PrimExpr>(indices.begin(), indices.end())); + } + + void VisitExpr_(const BufferLoadNode* op) final { + BufferAccess& acc = buf_accesses[op->buffer]; + switch (acc.acc_type) { + case BufferAccessType::kRead: + break; + case BufferAccessType::kWrite: + acc.acc_type = BufferAccessType::kReadWrite; + break; + case BufferAccessType::kReadWrite: + break; + case BufferAccessType::kUnknownRW: + default: + acc.acc_type = BufferAccessType::kRead; + break; + } + + if (acc.acc_type != BufferAccessType::kReadWrite) { + // If a buffer is both read and written, in the tvm DSL, it must be a update, + // so the indices should be the same. Then we can skip appending indices for it. + // Otherwise we do the following. + buf_accesses[op->buffer].indices.push_back( + std::vector<PrimExpr>(op->indices.begin(), op->indices.end())); + } + StmtExprVisitor::VisitExpr_(op); + } + + BufferMap<BufferAccess> buf_accesses; +}; + +// Compute the coefficient for an loop iterator in an expression +// Note: we use an approximation strategy to find coefficient. +// Hopefully, it is faster than DetectLinearEquation and can handle more cases (non-linear) +class CoefficientExtractor : public StmtExprVisitor { + public: + void VisitExpr_(const MulNode* node) final { + StmtExprVisitor::VisitExpr_(node); + if (visited_var) { + if (!visited_add) { + if (auto a = node->a.as<IntImmNode>()) { + visited_mul = true; + stride = a->value; + } else if (auto b = node->b.as<IntImmNode>()) { + visited_mul = true; + stride = b->value; + } + } + } + } + + void VisitExpr_(const AddNode* node) final { + StmtExprVisitor::VisitExpr_(node); + if (visited_var) { + if (!visited_mul) { + visited_add = true; + stride = 1; + } + } + } + + void VisitExpr_(const VarNode* node) final { + if (node == var_) { + visited_var = true; + // This is a magic default stride in case our approximation strategy fails + stride = 2; + } + } + + int ExtractCoefficient(const PrimExpr& expr, const VarNode* var) { + visited_var = visited_mul = visited_add = false; + var_ = var; + + this->VisitExpr(expr); + + if (visited_var && !visited_mul && !visited_add) { + return 1; + } else { + return stride; + } + } + + bool visited_var{false}; + bool visited_mul{false}; + bool visited_add{false}; + int stride{0}; + + private: + const VarNode* var_{nullptr}; +}; + +// Compute stride for the accesses to a buffer +std::pair<bool, PrimExpr> ComputeStride(const std::vector<std::vector<PrimExpr>>& indices, + const Array<PrimExpr>& shape, const VarNode* stride_var) { + PrimExpr min_stride{}; + bool found = false; + CoefficientExtractor extractor; + + for (const auto& index : indices) { + PrimExpr shape_stride = 1; + for (int i = static_cast<int>(index.size()) - 1; i >= 0; i--) { + int coefficient = extractor.ExtractCoefficient(index[i], stride_var); + if (extractor.visited_var) { + found = true; + if (min_stride.defined()) { + min_stride = min(min_stride, std::abs(coefficient) * shape_stride); + } else { + min_stride = std::abs(coefficient) * shape_stride; + } + break; + } + shape_stride *= shape[i]; + } + } + return {found, min_stride}; +} + +PrimExpr LoopNonTrivialCond(const ForNode* loop) { + std::string name = loop->loop_var->name_hint; + if (name.substr(0, 8) == "blockIdx" || name.substr(0, 9) == "threadIdx" || name == "vthread") { + // These loops are always there no matter what the loop size is. + return Bool(true); + } + return loop->extent > 1; +} + +std::tuple<PrimExpr, PrimExpr, PrimExpr> ComputeStrideForLoops( + const std::vector<std::vector<PrimExpr>>& indices, const Array<PrimExpr>& shape, + const std::vector<const ForNode*> loops_reversed) { + PrimExpr reduce_ratio_acc = 1; + PrimExpr reduce_ratio = 0, stride = 0, innermost_stride = 0; + PrimExpr found = Bool(false), in_loop = Bool(true), first_loop = Bool(true); + for (const auto& loop : loops_reversed) { + PrimExpr non_trivial_loop = LoopNonTrivialCond(loop); + // If loop is trivial, then the following don't happen and we effectively have a `continue;`. + auto [found_, stride_] = ComputeStride(indices, shape, loop->loop_var.get()); + PrimExpr found_this = in_loop && non_trivial_loop && Bool(found_); + reduce_ratio_acc *= loop->extent; + if (found_) { + // innermost_stride is non-zero only when the stride is found from the innermost loop. + innermost_stride += SelectLogOr0(found_this && first_loop, stride_); + stride += SelectLogOr0(found_this, stride_); + } + reduce_ratio += SelectLogOr0(found_this, reduce_ratio_acc); + found = found || found_this; + // Breaks out when we actually find something. + in_loop = in_loop && (!non_trivial_loop || Bool(!found_)); + first_loop = first_loop && (!non_trivial_loop); + } + // Default value. Can also use !found here, but that expression is more complex. + reduce_ratio += SelectLogOr0(in_loop, reduce_ratio_acc); + ICHECK(innermost_stride.defined()); + return {exp(stride), exp(innermost_stride), exp(reduce_ratio)}; +} + +// Compute touched bytes and cache lines for accesses to a buffer +std::vector<PrimExpr> ComputeRegion(const std::vector<std::vector<PrimExpr>>& indices, + RangeInfer& rinf, VarDefStackNode& vdefs) { + std::vector<PrimExpr> ret; + if (indices.empty()) return ret; + if (indices.size() == 1) { + for (const auto& index : indices[0]) { + Range range = rinf(index); + ret.push_back(vdefs.DefineConstShorthand(range->extent + 1)); + } + } else { + for (const auto& indices_ : indices) { + Range range = rinf(indices_[0]); + PrimExpr size = range->extent + 1; + for (size_t i = 1; i < indices_.size(); ++i) { + size = max(size, rinf(indices_[i])->extent + 1); + } + ret.push_back(vdefs.DefineConstShorthand(size)); + } + } + return ret; +} + +using BufferInfo3 = std::tuple<BufferAccessType, PrimExpr, int>; +using ForTouchRegionT = std::unordered_map<const ForNode*, BufferMap<std::vector<BufferInfo3>>>; + +bool LoopIterInIndices(Var for_var, const std::vector<std::vector<PrimExpr>>& indices) { + for (size_t j = 0; j < indices.size(); j++) { + for (size_t k = 0; k < indices[j].size(); k++) { + if (VarInExpr(for_var, indices[j][k])) { + return true; + } + } + } + return false; +} + +PrimExpr ReuseDistInBytes(const BufferMap<std::vector<BufferInfo3>>& this_for_region, + bool include_n_elems) { + PrimExpr reuse_dis_bytes = 0; + for (const auto& iter : this_for_region) { + for (auto& [_, n_elem, dtype_bytes] : iter.second) { + if (include_n_elems) { + reuse_dis_bytes += n_elem * dtype_bytes; + } else { + reuse_dis_bytes += dtype_bytes; + } + } + } + return reuse_dis_bytes; +} + +PrimExpr MultiRWReuseDistance(const std::vector<BufferInfo3>& buffers, PrimExpr for_extent) { + ICHECK(!buffers.empty()); + PrimExpr reuse_dis_iter = std::get<1>(buffers[0]); + for (size_t i = 1; i < buffers.size(); ++i) { + reuse_dis_iter = min(reuse_dis_iter, std::get<1>(buffers[i])); + } + return div(reuse_dis_iter, for_extent); +} + +// Compute reuse distance and reuse ratio for accesses to a buffer +// return values: reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct +std::tuple<PrimExpr, PrimExpr, PrimExpr, PrimExpr, PrimExpr> ComputeReuse( + const Buffer& buf, const std::vector<std::vector<PrimExpr>>& indices, + const std::vector<const ForNode*>& for_loops, const ForTouchRegionT& for_touch_regions) { + // for (i = 0; i < N; i++) { + // if (trivial_loop[i]) continue; + // if (has_loop_iter[i]) { ... } + // else return x_i; // kLoopMultipleRead + // if (has_serial_uses[i]) return y_i; // kSerialMultipleReadWrite + // } + // return 0; // NoReuse + // + // Denote the condition that we're still in the loop for i-th iteration by `in_loop[i]`. Then + // multi_read_reuse[i] := in_loop[i] && !trivial_loop[i] && !has_loop_iter[i] + // serial_rw_reuse[i] = in_loop[i] && !trivial_loop[i] && has_serial_uses[i] + // Also + // in_loop[i + 1] := in_loop[i] && (!multi_read_reuse[i] && !serial_rw_reuse[i]) + // in_loop[0] = true + // The return value R from this function is then + // \sum_i select(multi_read_reuse[i], x_i, 0) + select(serial_rw_reuse[i], y_i, 0) + // This function has multiple return values, they all follow this same idea. + int n_loops = static_cast<int>(for_loops.size()); + PrimExpr in_loop = Bool(true), multi_read_reuse = Bool(false), serial_rw_reuse = Bool(false); + PrimExpr read_reuse_dist_iter = 1; + PrimExpr reuse_dist_iter = 0, reuse_dist_bytes = 0, reuse_count = 0; + for (int i = n_loops - 1; i >= 0; --i) { + auto* loop = for_loops[i]; + PrimExpr extent = loop->extent; + const auto& this_for_region = for_touch_regions.at(loop); + const auto& this_buffer = this_for_region.at(buf); + int serial_reuse = (int)this_buffer.size() - 1; + PrimExpr has_loop_iter = Bool(LoopIterInIndices(loop->loop_var, indices)); + // Use extent > 1 here (instead of LoopNonTrivialCond) to skip _all_ loops with extent 1 + // including threadIdx/blockIdx, because that's what the concrete version does + // (and it makes more sense because if extent is 1 then there won't really be a "reuse"). + PrimExpr non_trivial_loop = extent > 1, has_serial_uses = Bool(serial_reuse > 0); + PrimExpr multi_read_reuse_ = in_loop && non_trivial_loop && !has_loop_iter, + serial_rw_reuse_ = in_loop && non_trivial_loop && has_serial_uses; + PrimExpr no_exit = !non_trivial_loop || (has_loop_iter && !has_serial_uses); + in_loop = in_loop && no_exit; + + // accumulate/update reuse distance + PrimExpr rw_reuse_dist_iter = MultiRWReuseDistance(this_buffer, extent); + PrimExpr reuse_dist_iter_ = SelectLogOr0(multi_read_reuse_, read_reuse_dist_iter) + + SelectLogOr0(serial_rw_reuse_, rw_reuse_dist_iter); + read_reuse_dist_iter *= extent; // This after reuse_dist_iter_ + // For multi read, reuse_dist_bytes is computed based on the previous (1-level inner) loop. + // When this is the innermost loop, it's computed from this loop with a slightly different + // algorithm (`n_elems` is not counted). + PrimExpr read_reuse_dist_bytes_ = + i == n_loops - 1 ? ReuseDistInBytes(this_for_region, false) + : ReuseDistInBytes(for_touch_regions.at(for_loops[i + 1]), true); + PrimExpr rw_reuse_dist_bytes_ = div(ReuseDistInBytes(this_for_region, true), extent); + PrimExpr reuse_dist_bytes_ = SelectLogOr0(multi_read_reuse_, read_reuse_dist_bytes_) + + SelectLogOr0(serial_rw_reuse_, rw_reuse_dist_bytes_); + PrimExpr reuse_count_ = + SelectLogOr0(multi_read_reuse_, extent) + SelectLogOr0(serial_rw_reuse_, serial_reuse); + + multi_read_reuse = multi_read_reuse || multi_read_reuse_; + serial_rw_reuse = serial_rw_reuse || serial_rw_reuse_; + reuse_dist_bytes += reuse_dist_bytes_; + reuse_dist_iter += reuse_dist_iter_; + reuse_count += reuse_count_; + } + return std::make_tuple(multi_read_reuse, serial_rw_reuse, exp(reuse_dist_iter), + exp(reuse_dist_bytes), exp(reuse_count)); +} + +// Extract features for every BufferStore statement +class PerStoreFeatureExtractor : public StmtExprVisitor { + public: + explicit PerStoreFeatureExtractor(VarDefStackNode& vdefs, RangeInfer& rinf, int cache_line_size) + : vdefs(vdefs), rinf(rinf), cache_line_size_(cache_line_size) {} + + void VisitStmt_(const AttrStmtNode* node) final { + if (node->attr_key == tir::attr::thread_extent || node->attr_key == tir::attr::virtual_thread) { + const Var& var = node->node.as<IterVarNode>()->var; + this->is_gpu = true; + + // make a fake for node for blockIdx.x or threadIdx.x + For fake_for(var, 0, node->value, ForKind::kParallel, node->body); + auto fake_node = fake_for.as<ForNode>(); + this->for_loop_stack.push_back(fake_node); + StmtExprVisitor::VisitStmt_(node); + for_loop_stack.pop_back(); + } else if (node->attr_key == "pragma_auto_unroll_max_step") { + PrimExpr old_value = cur_auto_unroll_max_step_; + cur_auto_unroll_max_step_ = node->value; + StmtExprVisitor::VisitStmt_(node); + cur_auto_unroll_max_step_ = old_value; + } else { + StmtExprVisitor::VisitStmt_(node); + } + } + + void VisitStmt_(const ForNode* node) final { + for_loop_stack.push_back(node); + StmtExprVisitor::VisitStmt(node->body); + for_loop_stack.pop_back(); + } + + void VisitStmt_(const BufferStoreNode* node) final { + PrimExpr loop_prod = LoopProd(FilterForLoops(std::nullopt)); + + // Group 1: Computation related features + MathOpCounter moc(this->rinf); + moc(node->value); + ExtractComputationFeature(node, moc, loop_prod); + + // Group 2: Buffer access related features (per buffer) + std::vector<PrimExpr> mem_bytes_list, compute_ops_list; + PrimExpr cur_compute_ops; + ExtractBufferAccessFeature(node, moc, loop_prod, &cur_compute_ops, &compute_ops_list, + &mem_bytes_list); + + // Group 3: Arithmetic intensity related features + // LOG_WARNING << "ExtractArithmeticIntensityFeature is unsupported yet"; + // ExtractArithmeticIntensityFeature(node, cur_compute_ops, compute_ops_list, mem_bytes_list); + + // Group 5: Outer scope related features + ExtractOuterScopeFeature(node, loop_prod); + } + + void VisitStmt_(const BufferRealizeNode* node) final { + StmtExprVisitor::VisitStmt_(node); + + // Group 4: Allocation related features + ExtractAllocationFeature(node); + } + + // Extract computation related features (group 1) + void ExtractComputationFeature(const BufferStoreNode* node, const MathOpCounter& moc, + const PrimExpr& loop_prod) { + // Computation related features + FeatureSet& fea = bufstore_feats[node->buffer]; + + fea.float_mad = loop_prod * (int)moc.float_mad; + fea.float_addsub = loop_prod * (int)moc.float_addsub; + fea.float_mul = loop_prod * (int)moc.float_mul; + fea.float_divmod = loop_prod * (int)moc.float_divmod; + fea.float_cmp = loop_prod * (int)moc.float_cmp; + fea.float_math_func = loop_prod * (int)moc.float_math_func; + fea.float_other_func = loop_prod * (int)moc.float_other_func; + fea.int_mad = loop_prod * (int)moc.int_mad; + fea.int_addsub = loop_prod * moc.FromExprMap(moc.int_addsub); + fea.int_mul = loop_prod * moc.FromExprMap(moc.int_mul); + fea.int_divmod = loop_prod * moc.FromExprMap(moc.int_divmod); + fea.int_math_func = loop_prod * (int)moc.int_math_func; + fea.int_cmp = loop_prod * moc.FromExprMap(moc.int_cmp); + fea.int_other_func = loop_prod * (int)moc.int_other_func; + fea.bool_op = loop_prod * (int)moc.bool_op; + fea.select_op = loop_prod * (int)moc.select_op; + + FillLoopFeatures(ForKind::kVectorized, fea.vec_num, fea.vec_len, fea.vec_prod, fea.vec_type); + FillLoopFeatures(ForKind::kUnrolled, fea.unroll_num, fea.unroll_len, fea.unroll_prod, + fea.unroll_type); + FillLoopFeatures(ForKind::kParallel, fea.parallel_num, fea.parallel_len, fea.parallel_prod, + fea.parallel_type); + + // GPU threads + fea.is_gpu = Bool(this->is_gpu); + Map<String, PrimExpr> loop_sizes; + for (const auto& loop : this->for_loop_stack) { + loop_sizes.Set(loop->loop_var->name_hint, loop->extent); + } + fea.blockIdx_x_len = loop_sizes.Get("blockIdx.x").value_or(1); + fea.blockIdx_y_len = loop_sizes.Get("blockIdx.y").value_or(1); + fea.blockIdx_z_len = loop_sizes.Get("blockIdx.z").value_or(1); + fea.threadIdx_x_len = loop_sizes.Get("threadIdx.x").value_or(1); + fea.threadIdx_y_len = loop_sizes.Get("threadIdx.y").value_or(1); + fea.threadIdx_z_len = loop_sizes.Get("threadIdx.z").value_or(1); + fea.vthread_len = loop_sizes.Get("vthread").value_or(1); + } + + // Extract buffer access related features (group 2) + void ExtractBufferAccessFeature(const BufferStoreNode* node, const MathOpCounter& moc, + PrimExpr loop_prod, PrimExpr* cur_compute_ops, + std::vector<PrimExpr>* compute_ops_list, + std::vector<PrimExpr>* mem_bytes_list) { + std::vector<BufferAccessFeature>& acc_feas = bufstore_feats[node->buffer].access_feas; + // We may have multiple bufferstore nodes for the same buffer (e.g., 1 for initializing an + // array, and 1 for computing it). In that case, delibrately overwrite the previous result. + acc_feas.clear(); + + BufferAccessExtractor buf_extractor; + buf_extractor.InsertAccess(node->buffer, BufferAccessType::kWrite, node->indices); + buf_extractor.ExtractReads(node->value); + auto for_loops = FilterForLoops(std::nullopt), for_loops_rev = for_loops; + std::reverse(for_loops_rev.begin(), for_loops_rev.end()); + + // Compute touched region for all outer loops + // * Make a copy of our global RangeInfer and override all loop variables to be [min, min] + RangeInfer rangeinf = this->rinf; + for (auto* loop : for_loops) { + // Using [a, b] convension for range, this means [x->min, x->min]. + rangeinf.Bind(loop->loop_var, Range::FromMinExtent(loop->min, 0), true); + } + + mem_bytes_list->reserve(for_loops.size()); + compute_ops_list->reserve(for_loops.size()); + + *cur_compute_ops = (int)(moc.float_mad + moc.float_addsub + moc.float_mul + moc.float_divmod + + moc.float_cmp + moc.float_math_func + moc.float_other_func); + + // std::cout << "In BufferStoreNode " << node->buffer->name << "\n"; + std::vector<PrimExpr> tmp_region; + for (auto* loop : for_loops_rev) { + rangeinf.BindLoop(loop, true); + // std::cout << " in for loop " << loop->loop_var->name_hint << "\n"; + // Note, here we do overwrite. + // So if there are multiple BufferStoreNode, the last one will overwrite the first few. + // e.g. The update part in gemm will overwrite the init part. + BufferMap<std::vector<BufferInfo3>>& buffer_regions_map = for_touch_regions_[loop]; + PrimExpr mem_bytes = 0; + for (const auto& x : buf_extractor.buf_accesses) { + const Buffer& t = x.first; + const BufferAccess& acc = x.second; + // std::cout << " in buffer access name " << t->name << "\n"; + tmp_region = ComputeRegion(acc.indices, rangeinf, this->vdefs); + PrimExpr touched_size = ElementProduct(tmp_region); + buffer_regions_map[t].emplace_back(acc.acc_type, touched_size, t->dtype.bytes()); + mem_bytes += touched_size * t->dtype.bytes(); + } + + mem_bytes_list->push_back(log2(mem_bytes)); + *cur_compute_ops *= loop->extent; + compute_ops_list->push_back(log2(*cur_compute_ops)); + } + + // Buffer access related features (per buffer) + auto bufmap = buf_extractor.buf_accesses; + std::vector<std::pair<Buffer, BufferAccess>> buf_accs(bufmap.begin(), bufmap.end()); + for (size_t i = 0; i < buf_accs.size(); ++i) { + auto [buf, acc] = buf_accs[i]; + Integer ele_bytes = buf->dtype.bytes(); + // calculate bytes + PrimExpr bytes = loop_prod * ele_bytes, unique_bytes; + // calculate cache lines + PrimExpr stride, lines, unique_lines; + if (for_loops.empty()) { + unique_bytes = ele_bytes; + stride = 0; + lines = 1.0f; + unique_lines = 1.0f; + } else { + unique_bytes = this->vdefs.DefineConstShorthand( + std::get<1>(for_touch_regions_[for_loops.front()][buf].front()) * ele_bytes); + auto [stride_, innermost_stride, reduce_ratio] = + ComputeStrideForLoops(acc.indices, buf->shape, for_loops_rev); + // convert `stride` back to the stride of the innermost iterator + stride = innermost_stride; + auto term1 = min(1.0f, div(CastToFloat(stride_ * ele_bytes), (float)cache_line_size_)); + lines = max(div(CastToFloat(loop_prod), CastToFloat(reduce_ratio)) * term1, 1.0f); + + // Modeled after this: + // PrimExpr n_continuous = ele_bytes; + // for (int i = std::min(tmp_region.size() - 1, t->shape.size() - 1); i >= 0; i--) { + // if (this->ana_.CanProveEqual(tmp_region[i], t->shape[i])) { + // n_continuous *= tmp_region[i]; + // break; + // } + // } + PrimExpr n_continuous = 0, in_loop = Bool(true); + for (int i = std::min(tmp_region.size() - 1, buf->shape.size() - 1); i >= 0; i--) { + PrimExpr is_equal = tmp_region[i] == buf->shape[i]; + n_continuous += SelectLogOr0(in_loop && is_equal, ele_bytes * tmp_region[i]); + in_loop = in_loop && (!is_equal); + } + // If we've done the whole loop without `is_equal == True`, then the value + // should just be `ele_bytes`. + n_continuous += SelectLogOr0(in_loop, ele_bytes); + unique_lines = + max(div(CastToFloat(unique_bytes), min(exp(n_continuous), cache_line_size_)), 1.0f); + } + + auto [multi_read_cond, serial_multi_rw_cond, reuse_dis_iter, reuse_dis_bytes, reuse_ct] = + ComputeReuse(buf, acc.indices, for_loops, for_touch_regions_); + multi_read_cond = SimplifyExpr(multi_read_cond); + serial_multi_rw_cond = SimplifyExpr(serial_multi_rw_cond); + reuse_dis_iter = SimplifyExpr(reuse_dis_iter); + reuse_dis_bytes = SimplifyExpr(reuse_dis_bytes); + reuse_ct = SimplifyExpr(reuse_ct); + PrimExpr no_reuse_cond = SimplifyExpr(!(serial_multi_rw_cond || multi_read_cond)); + + acc_feas.emplace_back(); + BufferAccessFeature& acc_fea = acc_feas.back(); + + acc_fea.buffer_name = buf->name; + acc_fea.acc_type = acc.acc_type; + acc_fea.stride = stride; + acc_fea.bytes = bytes; + acc_fea.unique_bytes = unique_bytes; + acc_fea.lines = lines; + acc_fea.unique_lines = unique_lines; + acc_fea.multi_read_cond = multi_read_cond; + acc_fea.serial_multi_rw_cond = serial_multi_rw_cond; + acc_fea.no_reuse_cond = no_reuse_cond; + acc_fea.reuse_dis_iter = reuse_dis_iter; + acc_fea.reuse_dis_bytes = reuse_dis_bytes; + acc_fea.reuse_ct = reuse_ct; + // no reuse, multiply by a magic number '2' + PrimExpr coef = SelectNonZero(reuse_ct, 0.5f); + acc_fea.bytes_d_reuse_ct = bytes / coef; + acc_fea.unique_bytes_d_reuse_ct = unique_bytes / coef; + acc_fea.lines_d_reuse_ct = lines / coef; + acc_fea.unique_lines_d_reuse_ct = unique_lines / coef; + } + } + + // Extract allocation related features (group 4) + void ExtractAllocationFeature(const BufferRealizeNode* node) { + FeatureSet& fea = bufstore_feats[node->buffer]; + PrimExpr allocation_size = 1; + for (const auto& x : node->bounds) { + allocation_size *= this->rinf.GetMax(x->extent); + } + // allocation feature + allocation_size = this->vdefs.DefineConstShorthand(allocation_size); + auto loop_prod = LoopProd(FilterForLoops(std::nullopt)); + fea.alloc_size = allocation_size * node->buffer->dtype.bytes(); + fea.alloc_prod = allocation_size * loop_prod; + fea.alloc_outer_prod = loop_prod; + fea.alloc_inner_prod = div(fea.outer_prod, loop_prod); + } + + // Extract outer scope related features (group 5) + void ExtractOuterScopeFeature(const BufferStoreNode* node, const PrimExpr& loop_prod) { + FeatureSet& fea = bufstore_feats[node->buffer]; + fea.outer_prod = loop_prod; + fea.num_loops = CountLoops(for_loop_stack); + fea.auto_unroll_max_step = cur_auto_unroll_max_step_; + } + + void FillLoopFeatures(ForKind kind, PrimExpr& num, PrimExpr& len, PrimExpr& prod, + AnnotationPosType& type) { + auto loops = FilterForLoops(kind); + num = CountLoops(loops); + if (loops.empty()) { + len = prod = 0; + type = AnnotationPosType::kPosNone; + } else { + len = loops.back()->extent; + prod = 1; + for (auto* loop : loops) { + prod *= loop->extent; + } + type = AnnotationPosType::kPosMixed; + } + } + + std::vector<const ForNode*> FilterForLoops(std::optional<ForKind> kind) { + std::unordered_set<std::string> var_registered; + std::vector<const ForNode*> loops; + for (auto* loop : this->for_loop_stack) { + if (kind && kind.value() != loop->kind) { + continue; + } + std::string var_name = loop->loop_var->name_hint; + if (var_registered.count(var_name)) { + ICHECK(var_name.substr(0, 8) == "blockIdx" || var_name.substr(0, 9) == "threadIdx") + << "Duplicate non-gpu-grid loop var: " << var_name; + continue; + } + loops.push_back(loop); + } + return loops; + } + + PrimExpr CountLoops(const std::vector<const ForNode*>& loops) { + PrimExpr num = 0; + for (auto* loop : loops) { + num += select(LoopNonTrivialCond(loop), 1, 0); + } + return num; + } + + PrimExpr LoopProd(const std::vector<const ForNode*>& loops) { + PrimExpr ret = 1.0f; + for (auto* loop : loops) { + ret *= loop->extent; + } + return ret; + } + + public: + BufferMap<FeatureSet> bufstore_feats; + + private: + // The shared arithmetic analyzers + VarDefStackNode& vdefs; + RangeInfer& rinf; + VarMapT flatmap; + + std::vector<const ForNode*> for_loop_stack; + PrimExpr previous_outer; + // GPU-related features + bool is_gpu; + PrimExpr cur_auto_unroll_max_step_{0}; + + // Store touch region information for all for loops. The format of this nested map: + // For a loop, for all its touched buffers, for all different accesses to the buffers, + // its (access type, number of touched elements, number of bytes of single element) + ForTouchRegionT for_touch_regions_; + + // The default cache line size in bytes + const int cache_line_size_ = 64; +}; + +PrimExpr slog(PrimExpr x) { return x.dtype().is_bool() ? x : log(x); } + +} // namespace + +VarDefStack GetPerStoreFeatureExpr(const Stmt& stmt, VarDefStackNode& vdefs, RangeInfer& rinf, + size_t cache_line_size, size_t max_n_bufs) { + // Extract features + PerStoreFeatureExtractor extractor(vdefs, rinf, cache_line_size); + extractor(stmt); + std::vector<std::pair<Buffer, FeatureSet>> buffer_features(extractor.bufstore_feats.begin(), + extractor.bufstore_feats.end()); + std::sort(buffer_features.begin(), buffer_features.end(), + [](auto& a, auto& b) { return a.first->name < b.first->name; }); + + // Define features in context, and put the resulted variable names in ret. + VarDefStack feats; + for (size_t i = 0; i < buffer_features.size(); ++i) { + auto& [buf, fea_set] = buffer_features[i]; + auto PushFeature = [&i, &feats](const std::string& name, const PrimExpr& val) { + auto name_ = "BS" + std::to_string(i) + "." + name; + feats->Append(name_, val); + }; + auto PushEnumFeature = [&i, &feats](const std::string& field, + const std::vector<std::string>& kind_names, auto val) { + for (size_t j = 0; j < kind_names.size(); j++) { + auto name_ = "BS" + std::to_string(i) + "." + field + "." + kind_names[j]; + feats->Append(name_, Bool(static_cast<size_t>(val) == j)); + } + }; + +#define PUSH_FEATURE(feature) PushFeature(#feature, fea_set.feature); +#define PUSH_ENUM_FEATURE(feature, names) PushEnumFeature(#feature, names, fea_set.feature); + /***** Group 1: Computation related features *****/ + PUSH_FEATURE(float_mad); + PUSH_FEATURE(float_addsub); + PUSH_FEATURE(float_mul); + PUSH_FEATURE(float_divmod); + PUSH_FEATURE(float_cmp); + PUSH_FEATURE(float_math_func); + PUSH_FEATURE(float_other_func); + PUSH_FEATURE(int_mad); + PUSH_FEATURE(int_addsub); + PUSH_FEATURE(int_mul); + PUSH_FEATURE(int_divmod); + PUSH_FEATURE(int_cmp); + PUSH_FEATURE(int_math_func); + PUSH_FEATURE(int_other_func); + PUSH_FEATURE(bool_op); + PUSH_FEATURE(select_op); + + static const std::vector<std::string> annot_pos_names = { + "kPosNone", "kPosInnerSpatial", "kPosMiddleSpatial", "kPosOuterSpatial", + "kPosInnerReduce", "kPosMiddleReduce", "kPosOuterReduce", "kPosMixed", + }; + PUSH_FEATURE(vec_num); + PUSH_FEATURE(vec_prod); + PUSH_FEATURE(vec_len); + PUSH_ENUM_FEATURE(vec_type, annot_pos_names); + PUSH_FEATURE(unroll_num); + PUSH_FEATURE(unroll_prod); + PUSH_FEATURE(unroll_len); + PUSH_ENUM_FEATURE(unroll_type, annot_pos_names); + PUSH_FEATURE(parallel_num); + PUSH_FEATURE(parallel_prod); + PUSH_FEATURE(parallel_len); + PUSH_ENUM_FEATURE(parallel_type, annot_pos_names); + + PUSH_FEATURE(is_gpu); + PUSH_FEATURE(blockIdx_x_len); + PUSH_FEATURE(blockIdx_y_len); + PUSH_FEATURE(blockIdx_z_len); + PUSH_FEATURE(threadIdx_x_len); + PUSH_FEATURE(threadIdx_y_len); + PUSH_FEATURE(threadIdx_z_len); + PUSH_FEATURE(vthread_len); + + /***** Group 2: Buffer access related features *****/ + static const std::vector<std::string> acc_type_names = {"kRead", "kWrite", "kReadWrite"}; + auto& buf_feats = fea_set.access_feas; + std::sort(buf_feats.begin(), buf_feats.end(), + [](auto& a, auto& b) { return a.buffer_name < b.buffer_name; }); + buf_feats.resize(max_n_bufs); + for (size_t j = 0; j < max_n_bufs; ++j) { + const auto& acc_fea = buf_feats[j]; + +#define PUSH_BUF_FEATURE(feature) \ + PushFeature("B" + std::to_string(j) + "." + #feature, \ + acc_fea.feature.defined() ? slog(acc_fea.feature) : Integer(0)); + + PushEnumFeature("B" + std::to_string(j) + ".acc_type", acc_type_names, acc_fea.acc_type); + PUSH_BUF_FEATURE(bytes); + PUSH_BUF_FEATURE(unique_bytes); + PUSH_BUF_FEATURE(lines); + PUSH_BUF_FEATURE(unique_lines); + PUSH_BUF_FEATURE(multi_read_cond); + PUSH_BUF_FEATURE(serial_multi_rw_cond); + PUSH_BUF_FEATURE(no_reuse_cond); + PUSH_BUF_FEATURE(reuse_dis_iter); + PUSH_BUF_FEATURE(reuse_dis_bytes); + PUSH_BUF_FEATURE(reuse_ct); + PUSH_BUF_FEATURE(bytes_d_reuse_ct); + PUSH_BUF_FEATURE(unique_bytes_d_reuse_ct); + PUSH_BUF_FEATURE(lines_d_reuse_ct); + PUSH_BUF_FEATURE(unique_lines_d_reuse_ct); + PUSH_BUF_FEATURE(stride); + } + + /***** Group 4: Allocation related features *****/ + PUSH_FEATURE(alloc_size); + PUSH_FEATURE(alloc_prod); + PUSH_FEATURE(alloc_outer_prod); + PUSH_FEATURE(alloc_inner_prod); + + /***** Group 5: Outer scope related features *****/ + PUSH_FEATURE(outer_prod); + PUSH_FEATURE(num_loops); + PUSH_FEATURE(auto_unroll_max_step); + } + return feats; +} + +} // namespace felix +} // namespace tvm diff --git a/src/felix/features.h b/src/felix/features.h new file mode 100644 index 000000000..586c879a9 --- /dev/null +++ b/src/felix/features.h @@ -0,0 +1,29 @@ +#ifndef TVM_AUTO_SCHEDULER_FEATURES_H_ +#define TVM_AUTO_SCHEDULER_FEATURES_H_ + +#include <tvm/auto_scheduler/search_task.h> +#include <tvm/tir/stmt.h> + +#include <unordered_map> + +#include "rangeinfer.h" + +namespace tvm { +namespace felix { + +template <class T> +using BufferMap = std::unordered_map<tir::Buffer, T, ObjectHash, ObjectEqual>; +template <class T> +using ExprMap = std::unordered_map<PrimExpr, T, StructuralHash, StructuralEqual>; + +arith::VarDefStack GetPerStoreFeatureExpr(const tir::Stmt& stmt, arith::VarDefStackNode& vdefs, + RangeInfer& rinf, size_t cache_line_size, + size_t max_n_bufs); + +std::vector<PrimExpr> GetConstraints(const tir::Stmt& code, + const auto_scheduler::HardwareParams& hw_params); + +} // namespace felix +} // namespace tvm + +#endif // TVM_AUTO_SCHEDULER_FEATURE_H_ diff --git a/src/felix/rangeinfer.h b/src/felix/rangeinfer.h new file mode 100644 index 000000000..2d86370c4 --- /dev/null +++ b/src/felix/rangeinfer.h @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <tvm/arith/analyzer.h> +#include <tvm/tir/expr_functor.h> +#include <tvm/tir/op.h> +#include <tvm/tir/stmt.h> + +#include <algorithm> + +namespace tvm { +namespace felix { + +using namespace tir; + +inline FloatImm ToFloatImm(double e) { return FloatImm(DataType::Float(32), e); } + +inline PrimExpr CastToFloat(const PrimExpr& value) { + return value->dtype.is_float() ? value : cast(DataType::Float(32), value); +} + +inline bool ToConstNumber(const PrimExpr& x, double& val) { + if (!x.defined()) { + return false; + } + if (const auto op = x.as<IntImmNode>()) { + val = static_cast<float>(op->value); + return true; + } else if (const auto op = x.as<FloatImmNode>()) { + val = op->value; + return true; + } else { + return false; + } +} + +inline bool ToConstRange(const Range& range, double& min, double& max) { + double extent; + if (!ToConstNumber(range->min, min) || !ToConstNumber(range->extent, extent)) { + return false; + } + max = min + extent; + return true; +} + +inline bool ToConstNumber(const Range& range, double& x) { + double extent; + return ToConstNumber(range->min, x) && ToConstNumber(range->extent, extent) && extent == 0; +} + +template <typename Ret> +class MemoizedExprFunctor : public ExprFunctor<Ret(const PrimExpr&)> { + public: + Ret VisitExpr(const PrimExpr& expr) override { + auto it = this->memo.find(expr); + if (it != this->memo.end()) { + return it->second; + } + return this->memo[expr] = ExprFunctor<Ret(const PrimExpr&)>::VisitExpr(expr); + } + + protected: + std::unordered_map<PrimExpr, Ret, StructuralHash, StructuralEqual> memo; +}; + +template <> +class MemoizedExprFunctor<PrimExpr> : public ExprMutator { + public: + PrimExpr VisitExpr(const PrimExpr& expr) override { + auto it = this->memo.find(expr); + if (it != this->memo.end()) { + return it->second; + } + return this->memo[expr] = ExprMutator::VisitExpr(expr); + } + + protected: + std::unordered_map<PrimExpr, PrimExpr, StructuralHash, StructuralEqual> memo; +}; + +class RangeInfer : public MemoizedExprFunctor<Range> { + public: + RangeInfer(Range range_for_sizevar = Range()) : range_for_sizevar(range_for_sizevar) {} + + void BindLoop(const ForNode* loop, bool allow_override) { + // To tighten the bound, we adopt an [a, b] convension for the range instead of [a, b). + Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent - 1), allow_override); + } + + void Bind(Var var, Range range, bool allow_override) { + ICHECK(arith::ExprIsConstant(range->min) && arith::ExprIsConstant(range->extent)) + << "Cannot bind non-constant range for variable " << var << " = " << range; + if (this->var_bind.count(var->name_hint)) { + if (allow_override) { + this->memo.clear(); + } else { + LOG_FATAL << "Cannot override range for " << var->name_hint << " (already defined)"; + } + } + this->var_bind[var->name_hint] = range; + } + + PrimExpr GetMax(PrimExpr e) { + auto range = VisitExpr(e); + return range->min + range->extent; + } + + private: +#define VisitBinOpSameMono(Type, Func) \ + Range VisitExpr_(const Type##Node* op) final { \ + Range lhs = VisitExpr(op->a), rhs = VisitExpr(op->b); \ + PrimExpr lmax = lhs->min + lhs->extent, rmax = rhs->min + rhs->extent; \ + PrimExpr begin = Func(lhs->min, rhs->min), end = Func(lmax, rmax); \ + return Range(begin, end); \ + } + +#define VisitBinOpRevMono(Type, Func) \ + Range VisitExpr_(const Type##Node* op) final { \ + Range lhs = VisitExpr(op->a), rhs = VisitExpr(op->b); \ + PrimExpr lmax = lhs->min + lhs->extent, rmax = rhs->min + rhs->extent; \ + PrimExpr begin = Func(lhs->min, rmax), end = Func(lmax, rhs->min); \ + return Range(begin, end); \ + } + +#define VisitMods(Type) \ + Range VisitExpr_(const Type##Node* op) final { \ + Range lhs = VisitExpr(op->a), rhs = VisitExpr(op->b); \ + PrimExpr zero = Integer(0); \ + if (is_const_int(lhs->min, 0) && is_const_int(lhs->extent, 0)) { \ + return Range(zero, zero); \ + } \ + return Range(zero, rhs->min + rhs->extent - 1); \ + } + + VisitBinOpSameMono(Add, add); + VisitBinOpSameMono(Mul, mul); + VisitBinOpRevMono(Sub, sub); + VisitBinOpRevMono(Div, div); + VisitBinOpRevMono(FloorDiv, floordiv); + VisitBinOpSameMono(Min, min); + VisitBinOpSameMono(Max, max); + VisitMods(Mod); + VisitMods(FloorMod); + + Range VisitExpr_(const IntImmNode* op) final { + return Range::FromMinExtent(GetRef<IntImm>(op), Integer(0)); + } + Range VisitExpr_(const FloatImmNode* op) final { + return Range::FromMinExtent(GetRef<FloatImm>(op), ToFloatImm(0.0)); + } + // SizeVar that we insert as schedule vars are seen as constants + // as they eventually will be constants for each given configuration. + Range VisitExpr_(const SizeVarNode* op) final { + if (this->range_for_sizevar.defined()) { + return this->range_for_sizevar; + } + return Range::FromMinExtent(GetRef<SizeVar>(op), 0); + } + Range VisitExpr_(const VarNode* op) final { + auto it = this->var_bind.find(op->name_hint); + if (it != this->var_bind.end()) { + return it->second; + } else { + LOG_FATAL << "Cannot find var \'" << op->name_hint << "\' in range inference"; + return Range(); + } + } + + Range VisitExpr_(const SelectNode* op) final { + // Don't have much use for range of condition. + Range lhs = VisitExpr(op->true_value), rhs = VisitExpr(op->false_value); + PrimExpr lmax = lhs->min + lhs->extent, rmax = rhs->min + rhs->extent; + return Range(min(lhs->min, rhs->min), max(lmax, rmax)); + } + + Range VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); } + + Range VisitExpr_(const CallNode* op) final { + auto* pop = op->op.as<OpNode>(); + ICHECK(pop != nullptr); + if (pop->name == "tir.exp") { + Range r = VisitExpr(op->args[0]); + return Range(exp(r->min), exp(r->min + r->extent)); + } else if (pop->name == "tir.log") { + Range r = VisitExpr(op->args[0]); + return Range(log(r->min), log(r->min + r->extent)); + } else if (pop->name == "tir.logk") { + double base_ = 0; + if (!ToConstNumber(VisitExpr(op->args[0]), base_)) { + LOG_FATAL << "logk only supports constant base"; + } + Range x = VisitExpr(op->args[1]); + double x_min = 0, x_max = 0; + if (ToConstRange(x, x_min, x_max)) { + return Range(ToFloatImm(std::log(x_min) / std::log(base_)), + ToFloatImm(std::log(x_max) / std::log(base_))); + } else { + return Range(div(log(x->min), std::log(base_)), + div(log(x->min + x->extent), std::log(base_))); + } + } else if (pop->name == "tir.pow") { + Range r1 = VisitExpr(op->args[0]), r2 = VisitExpr(op->args[1]); + double r1_min = 0, r1_max = 0, r2_min = 0, r2_max = 0; + bool is_r1_const = ToConstRange(r1, r1_min, r1_max), + is_r2_const = ToConstRange(r2, r2_min, r2_max); + if (is_r1_const && is_r2_const) { + return Range(ToFloatImm(std::pow(r1_min, r2_min)), ToFloatImm(std::pow(r1_max, r2_max))); + } else if (is_r1_const && r1_min == r1_max) { + FloatImm r1f = ToFloatImm(r1_min); + if (r1_min <= 0) { + LOG_FATAL << "only pow with base >= 1 is supported; got base " << r1 << " in " + << GetRef<PrimExpr>(op); + } else if (r1_min < 1) { + return Range(pow(r1f, r2->min + r2->extent), pow(r1f, r2->min)); + } else { + return Range(pow(r1f, r2->min), pow(r1f, r2->min + r2->extent)); + } + } else if (is_r2_const && r2_min == r2_max) { + FloatImm r2f = ToFloatImm(r2_min); + if (r2_min < 0) { + return Range(pow(r1->min + r1->extent, r2f), pow(r1->min, r2f)); + } else { + return Range(pow(r1->min, r2f), pow(r1->min + r1->extent, r2f)); + } + } else { + LOG_FATAL << "pow with non-constant base and exponent is unsupported: " + << GetRef<PrimExpr>(op); + } + } else { + LOG_FATAL << "Call to " << pop->name << " not supported"; + } + return Range(); // unreachable + } + + Range VisitExprDefault_(const Object* op) final { + LOG_FATAL << "Expression of type " << op->GetTypeKey() << " is unsupported in RangeInfer"; + return Range(); + } + + public: + std::unordered_map<std::string, Range> var_bind; + Range range_for_sizevar; +}; + +} // namespace felix +} // namespace tvm diff --git a/src/felix/utils.h b/src/felix/utils.h index c06d7b4f5..5f6f28134 100644 --- a/src/felix/utils.h +++ b/src/felix/utils.h @@ -5,6 +5,8 @@ #include <tvm/auto_scheduler/transform_step.h> #include <tvm/runtime/container/string.h> +#include <numeric> + namespace tvm { namespace felix { @@ -14,6 +16,12 @@ using auto_scheduler::Step; String PrintTrStep(const Step& step); +/*! \brief Compute the product of all elements in a vector */ +inline PrimExpr ElementProduct(const std::vector<PrimExpr>& array) { + return std::accumulate(array.begin(), array.end(), PrimExpr(1), + [](const PrimExpr& a, const PrimExpr& b) { return a * b; }); +} + inline std::pair<PrimExpr, PrimExpr> GetCumulativeSpaceAndReductionLength_(const Stage& stage) { PrimExpr cum_space_len = 1, cum_reduce_len = 1; for (const auto& iter : stage->iters) { @@ -26,6 +34,47 @@ inline std::pair<PrimExpr, PrimExpr> GetCumulativeSpaceAndReductionLength_(const return std::make_pair(cum_space_len, cum_reduce_len); } +#define LOG_TIME + +class Timer { + public: + Timer(std::string event_name) + : event_name(std::move(event_name)), + duration(), + now(std::chrono::high_resolution_clock::now()), + stopped(false) {} + + void Start() { +#ifdef LOG_TIME + this->now = std::chrono::high_resolution_clock::now(); + this->stopped = false; +#endif + } + + void Stop() { +#ifdef LOG_TIME + if (!this->stopped) { + this->duration += std::chrono::high_resolution_clock::now() - this->now; + this->stopped = true; + } +#endif + } + + ~Timer() { +#ifdef LOG_TIME + Stop(); + auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(this->duration); + LOG_INFO << this->event_name << " -- " << duration.count() << " ms"; +#endif + } + + private: + std::string event_name; + std::chrono::duration<double> duration; + std::chrono::time_point<std::chrono::high_resolution_clock> now; + bool stopped; +}; + } // namespace felix } // namespace tvm diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 64142c621..d27160b82 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -31,34 +31,30 @@ namespace tvm { namespace tir { -#define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \ - Name::Name(PrimExpr a, PrimExpr b, Span span) { \ - using T = Name::ContainerType; \ - ICHECK(a.defined()) << "ValueError: a is undefined\n"; \ - ICHECK(b.defined()) << "ValueError: b is undefined\n"; \ - CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types. " << a.dtype() << " vs. " \ - << b.dtype() << "\n"; \ - ObjectPtr<T> node = make_object<T>(); \ - node->dtype = a.dtype(); \ - node->a = std::move(a); \ - node->b = std::move(b); \ - node->span = std::move(span); \ - data_ = std::move(node); \ +#define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \ + Name::Name(PrimExpr a, PrimExpr b, Span span) { \ + using T = Name::ContainerType; \ + ICHECK(a.defined()) << "ValueError: a is undefined\n"; \ + ICHECK(b.defined()) << "ValueError: b is undefined\n"; \ + ObjectPtr<T> node = make_object<T>(); \ + node->dtype = a.dtype(); \ + node->a = std::move(a); \ + node->b = std::move(b); \ + node->span = std::move(span); \ + data_ = std::move(node); \ } -#define TVM_DEFINE_CMPOP_CONSTRUCTOR(Name) \ - Name::Name(PrimExpr a, PrimExpr b, Span span) { \ - using T = Name::ContainerType; \ - ICHECK(a.defined()) << "ValueError: a is undefined\n"; \ - ICHECK(b.defined()) << "ValueError: b is undefined\n"; \ - CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types. " << a.dtype() << " vs. " \ - << b.dtype() << "\n"; \ - ObjectPtr<T> node = make_object<T>(); \ - node->dtype = DataType::Bool(a.dtype().lanes()); \ - node->a = std::move(a); \ - node->b = std::move(b); \ - node->span = std::move(span); \ - data_ = std::move(node); \ +#define TVM_DEFINE_CMPOP_CONSTRUCTOR(Name) \ + Name::Name(PrimExpr a, PrimExpr b, Span span) { \ + using T = Name::ContainerType; \ + ICHECK(a.defined()) << "ValueError: a is undefined\n"; \ + ICHECK(b.defined()) << "ValueError: b is undefined\n"; \ + ObjectPtr<T> node = make_object<T>(); \ + node->dtype = DataType::Bool(a.dtype().lanes()); \ + node->a = std::move(a); \ + node->b = std::move(b); \ + node->span = std::move(span); \ + data_ = std::move(node); \ } // Var @@ -610,7 +606,6 @@ Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Sp ICHECK(false_value.defined()) << "ValueError: true_value is undefined"; ICHECK(condition.dtype().is_bool()); ICHECK(condition.dtype().lanes() == true_value.dtype().lanes() || condition.dtype().lanes() == 1); - ICHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types"; ObjectPtr<SelectNode> node = make_object<SelectNode>(); node->dtype = true_value.dtype(); diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index d9d6093f2..1b4ed4c9e 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -372,8 +372,8 @@ PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span) { return floordiv(a, b, spa PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span) { return floormod(a, b, span); } PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span) { - ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; - ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; + // ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; + // ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; BinaryOpMatchTypes(a, b, span); PrimExpr ret = arith::TryConstFold<tir::FloorDiv>(a, b); if (ret.defined()) return ret; @@ -381,8 +381,8 @@ PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span) { } PrimExpr floormod(PrimExpr a, PrimExpr b, Span span) { - ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; - ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; + // ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; + // ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; BinaryOpMatchTypes(a, b, span); PrimExpr ret = arith::TryConstFold<tir::FloorMod>(a, b); if (ret.defined()) return ret; @@ -612,8 +612,6 @@ TVM_REGISTER_GLOBAL("tir.bitwise_not").set_body_typed([](PrimExpr a, Span span) // pow PrimExpr pow(PrimExpr x, PrimExpr y, Span span) { - BinaryOpMatchTypes(x, y, span); - ICHECK(x.dtype().is_float()) << "power only applies to float"; static auto op = Op::Get("tir.pow"); return tir::Call(x.dtype(), op, {x, y}, span); } @@ -879,6 +877,8 @@ TIR_REGISTER_PURE_UNARY_OP("tir.atanh"); TIR_REGISTER_PURE_UNARY_OP("tir.clz"); +TIR_REGISTER_PURE_UNARY_OP("tir.hump"); + // binary intrinsics TIR_REGISTER_PURE_BINARY_OP("tir.atan2"); @@ -890,6 +890,8 @@ TIR_REGISTER_PURE_BINARY_OP("tir.copysign"); TIR_REGISTER_PURE_BINARY_OP("tir.ldexp"); +TIR_REGISTER_PURE_BINARY_OP("tir.logk"); + // expose basic functions to node namespace TVM_REGISTER_GLOBAL("node._const").set_body([](TVMArgs args, TVMRetValue* ret) { if (args[0].type_code() == kDLInt) { -- GitLab