You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by cs...@apache.org on 2022/11/16 16:48:35 UTC
[tvm] branch main updated: [TOPI] Update names for pooling ops (#13401)
This is an automated email from the ASF dual-hosted git repository.
csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 271ad43029 [TOPI] Update names for pooling ops (#13401)
271ad43029 is described below
commit 271ad4302917011a54b257ca2a78c563a7ba652c
Author: abhikran-quic <63...@users.noreply.github.com>
AuthorDate: Wed Nov 16 22:18:28 2022 +0530
[TOPI] Update names for pooling ops (#13401)
[TOPI] Specify names for pooling ops
- Explicit names make it easier to fetch the compute stage
when scheduling pooling ops.
- Specify meta_schedule `schedule_rule` attributes for the max and avg pooling computes.
---
include/tvm/topi/nn/pooling.h | 18 ++++++++++++------
1 file changed, 12 insertions(+), 6 deletions(-)
diff --git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h
index c81c7cda7d..3503584687 100644
--- a/include/tvm/topi/nn/pooling.h
+++ b/include/tvm/topi/nn/pooling.h
@@ -353,7 +353,9 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const Array<PrimExpr>& output_
return std::make_tuple(indices, reduce_axes);
};
+ Map<String, ObjectRef> attrs;
if (pool_type == kMaxPool) {
+ attrs.Set("schedule_rule", tvm::runtime::String("meta_schedule.adaptive_pool_max"));
return tvm::te::compute(
out_shape,
[&](const Array<Var>& output) {
@@ -362,8 +364,9 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const Array<PrimExpr>& output_
std::tie(indices, reduce_axes) = get_iter_vars(output, true);
return tvm::max(x(indices), reduce_axes); // NOLINT(*)
},
- "tensor", "adaptive_pool_max");
+ "adaptive_pool_max", "adaptive_pool_max", attrs);
} else if (pool_type == kAvgPool) {
+ attrs.Set("schedule_rule", tvm::runtime::String("meta_schedule.adaptive_pool_avg"));
auto pool_sum = tvm::te::compute(
out_shape,
[&](const Array<Var>& output) {
@@ -372,7 +375,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const Array<PrimExpr>& output_
std::tie(indices, reduce_axes) = get_iter_vars(output, true);
return tvm::sum(x(indices), reduce_axes);
},
- "tensor", "adaptive_pool_sum");
+ "adaptive_pool_sum", "adaptive_pool_sum");
return tvm::te::compute(
out_shape,
@@ -388,7 +391,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const Array<PrimExpr>& output_
return div(pool_sum(indices), divide_factor);
},
- "tensor", kElementWise);
+ "adaptive_pool_avg", kElementWise, attrs);
} else {
LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
return x;
@@ -566,8 +569,10 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
out_shape.Set(ii, out_dim);
}
+ Map<String, ObjectRef> attrs;
if (pool_type == kMaxPool) {
auto temp = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
+ attrs.Set("schedule_rule", tvm::runtime::String("meta_schedule.pool_max"));
return tvm::te::compute(
out_shape,
[&](const Array<Var>& output) {
@@ -580,8 +585,9 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
}
return tvm::max(temp(indices), daxis);
},
- "tensor", "pool_max");
+ "pool_max", "pool_max", attrs);
} else if (pool_type == kAvgPool) {
+ attrs.Set("schedule_rule", tvm::runtime::String("meta_schedule.pool_avg"));
// Pad the inputs
auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x;
@@ -598,7 +604,7 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
}
return tvm::sum(temp(indices), daxis);
},
- "tensor", "pool_sum");
+ "pool_sum", "pool_sum");
// TVM compute for dividing the reduced window sum by kernel size.
return tvm::te::compute(
@@ -650,7 +656,7 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
return div(pool_sum(indices), divide_factor);
}
},
- "tensor", kElementWise);
+ "pool_avg", kElementWise, attrs);
} else {
LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
return x;