You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/11/16 05:05:40 UTC

[tvm] branch main updated: [ONNX][Relay] Support "tf_crop_and_resize" in relay Resize op. (#9475)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new eda12cb  [ONNX][Relay] Support "tf_crop_and_resize" in relay Resize op. (#9475)
eda12cb is described below

commit eda12cb52ce8cdc2027a7342d0b39b61727c5d89
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Mon Nov 15 22:04:56 2021 -0700

    [ONNX][Relay] Support "tf_crop_and_resize" in relay Resize op. (#9475)
    
    * add fallback to opset 11
    
    * Support tf_crop_and_resize in resize op
    
    * change api use in the rest of the codebase
    
    really fix the tests
    
    * respond to review comments, improve doc strings
    
    * fix docstring indentation
    
    * remove N and C from resize roi
---
 include/tvm/relay/attrs/image.h                   |  24 +++
 python/tvm/relay/frontend/onnx.py                 |  66 ++++++--
 python/tvm/relay/frontend/pytorch.py              |   8 +-
 python/tvm/relay/frontend/tensorflow_ops.py       |   4 +-
 python/tvm/relay/frontend/tflite.py               |   2 +-
 python/tvm/relay/op/dyn/image/_image.py           |   6 +
 python/tvm/relay/op/image/_image.py               |  12 ++
 python/tvm/relay/op/image/image.py                |  97 +++++++++---
 python/tvm/topi/image/resize.py                   | 181 ++++++++++++++++++----
 python/tvm/topi/nn/upsampling.py                  |   2 +
 src/relay/op/dyn/image/resize.cc                  |  16 +-
 src/relay/op/image/resize.cc                      |  33 ++--
 src/relay/op/make_op.h                            |   7 +-
 src/relay/transforms/dynamic_to_static.cc         |  26 ++--
 src/relay/transforms/pattern_utils.h              |  17 ++
 tests/python/contrib/test_onnx.py                 |   1 +
 tests/python/frontend/onnx/test_forward.py        |   1 -
 tests/python/relay/dyn/test_dynamic_op_level5.py  |   2 +-
 tests/python/relay/test_any.py                    |   2 +-
 tests/python/relay/test_op_level5.py              |  12 +-
 tests/python/relay/test_pass_dynamic_to_static.py |   2 +-
 tests/python/topi/python/test_topi_image.py       |   2 +
 22 files changed, 419 insertions(+), 104 deletions(-)

diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h
index b851add..78687b3 100644
--- a/include/tvm/relay/attrs/image.h
+++ b/include/tvm/relay/attrs/image.h
@@ -35,16 +35,21 @@ namespace relay {
 /*! \brief Attributes used in image resize1d operator */
 struct Resize1DAttrs : public tvm::AttrsNode<Resize1DAttrs> {
   Array<IndexExpr> size;
+  Array<FloatImm> roi;
   std::string layout;
   std::string method;
   std::string coordinate_transformation_mode;
   std::string rounding_method;
   double cubic_alpha;
   int cubic_exclude;
+  double extrapolation_value;
   DataType out_dtype;
 
   TVM_DECLARE_ATTRS(Resize1DAttrs, "relay.attrs.Resize1DAttrs") {
     TVM_ATTR_FIELD(size).set_default(NullValue<Array<IndexExpr> >()).describe("Output Size.");
+    TVM_ATTR_FIELD(roi)
+        .set_default(NullValue<Array<FloatImm> >())
+        .describe("Region of Interest for coordinate transformation mode 'tf_crop_and_resize'");
     TVM_ATTR_FIELD(layout).set_default("NCW").describe(
         "Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
         "'N', 'C', 'W' stands for batch, channel and width"
@@ -73,6 +78,9 @@ struct Resize1DAttrs : public tvm::AttrsNode<Resize1DAttrs> {
     TVM_ATTR_FIELD(cubic_exclude)
         .set_default(0)
         .describe("Flag to exclude exterior of the image during cubic interpolation");
+    TVM_ATTR_FIELD(extrapolation_value)
+        .set_default(0.0)
+        .describe("Value to return when roi is outside of the image");
     TVM_ATTR_FIELD(out_dtype).set_default(NullValue<DataType>()).describe("Output data type.");
   }
 };
@@ -80,16 +88,21 @@ struct Resize1DAttrs : public tvm::AttrsNode<Resize1DAttrs> {
 /*! \brief Attributes used in image resize2d operator */
 struct Resize2DAttrs : public tvm::AttrsNode<Resize2DAttrs> {
   Array<IndexExpr> size;
+  Array<FloatImm> roi;
   std::string layout;
   std::string method;
   std::string coordinate_transformation_mode;
   std::string rounding_method;
   double cubic_alpha;
   int cubic_exclude;
+  double extrapolation_value;
   DataType out_dtype;
 
   TVM_DECLARE_ATTRS(Resize2DAttrs, "relay.attrs.Resize2DAttrs") {
     TVM_ATTR_FIELD(size).set_default(NullValue<Array<IndexExpr> >()).describe("Output Size.");
+    TVM_ATTR_FIELD(roi)
+        .set_default(NullValue<Array<FloatImm> >())
+        .describe("Region of Interest for coordinate transformation mode 'tf_crop_and_resize'");
     TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
         "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
         "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
@@ -118,6 +131,9 @@ struct Resize2DAttrs : public tvm::AttrsNode<Resize2DAttrs> {
     TVM_ATTR_FIELD(cubic_exclude)
         .set_default(0)
         .describe("Flag to exclude exterior of the image during bicubic interpolation");
+    TVM_ATTR_FIELD(extrapolation_value)
+        .set_default(0.0)
+        .describe("Value to return when roi is outside of the image");
     TVM_ATTR_FIELD(out_dtype).set_default(NullValue<DataType>()).describe("Output data type.");
   }
 };
@@ -125,16 +141,21 @@ struct Resize2DAttrs : public tvm::AttrsNode<Resize2DAttrs> {
 /*! \brief Attributes used in image resize3d operator */
 struct Resize3DAttrs : public tvm::AttrsNode<Resize3DAttrs> {
   Array<IndexExpr> size;
+  Array<FloatImm> roi;
   std::string layout;
   std::string method;
   std::string coordinate_transformation_mode;
   std::string rounding_method;
   double cubic_alpha;
   int cubic_exclude;
+  double extrapolation_value;
   DataType out_dtype;
 
   TVM_DECLARE_ATTRS(Resize3DAttrs, "relay.attrs.Resize3DAttrs") {
     TVM_ATTR_FIELD(size).set_default(NullValue<Array<IndexExpr> >()).describe("Output Size.");
+    TVM_ATTR_FIELD(roi)
+        .set_default(NullValue<Array<FloatImm> >())
+        .describe("Region of Interest for coordinate transformation mode 'tf_crop_and_resize'");
     TVM_ATTR_FIELD(layout).set_default("NCDHW").describe(
         "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
         "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
@@ -163,6 +184,9 @@ struct Resize3DAttrs : public tvm::AttrsNode<Resize3DAttrs> {
     TVM_ATTR_FIELD(cubic_exclude)
         .set_default(0)
         .describe("Flag to exclude exterior of the image during tricubic interpolation");
+    TVM_ATTR_FIELD(extrapolation_value)
+        .set_default(0.0)
+        .describe("Value to return when roi is outside of the image");
     TVM_ATTR_FIELD(out_dtype).set_default(NullValue<DataType>()).describe("Output data type.");
   }
 };
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 8eda1c9..0dc08d5 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -2610,13 +2610,13 @@ class Resize(OnnxOpConverter):
         out = None
         if ndims == 3:
             out_size = fold_constant(_op.strided_slice(size, [2], [3]))
-            out = _op.image.resize1d(inputs[0], out_size, "NCW", method, "asymmetric")
+            out = _op.image.resize1d(inputs[0], out_size, None, "NCW", method, "asymmetric")
         elif ndims == 4:
             out_size = fold_constant(_op.strided_slice(size, [2], [4]))
-            out = _op.image.resize2d(inputs[0], out_size, "NCHW", method, "asymmetric")
+            out = _op.image.resize2d(inputs[0], out_size, None, "NCHW", method, "asymmetric")
         elif ndims == 5:
             out_size = fold_constant(_op.strided_slice(size, [2], [5]))
-            out = _op.image.resize3d(inputs[0], out_size, "NCDHW", method, "asymmetric")
+            out = _op.image.resize3d(inputs[0], out_size, None, "NCDHW", method, "asymmetric")
         else:
             raise NotImplementedError("Resize only supports 3, 4, or 5 dims")
         return out
@@ -2639,6 +2639,12 @@ class Resize(OnnxOpConverter):
     def _impl_v13(cls, inputs, attr, params):
         scale = inputs[2]
         size = inputs[3]
+
+        # Some versions of onnx exporters produce an opset 13 model with the opset 11
+        # resize op, handle that edge case
+        if scale is not None and size is not None:
+            return cls._impl_v11(inputs, attr, params)
+
         if size is not None:
             assert scale is None, "One of scale or size should be passed, not both."
         else:
@@ -2657,6 +2663,9 @@ class Resize(OnnxOpConverter):
         they handle the passing of scale and size. This utility
         provides the implementation for both
         """
+        roi = inputs[1]
+        if roi is not None and infer_shape(roi)[0] == 0:
+            roi = None
         ndims = len(infer_shape(inputs[0]))
         mode = attr.get("mode").decode("ascii")
         if mode == "nearest":
@@ -2674,23 +2683,60 @@ class Resize(OnnxOpConverter):
         nearest_mode = attr.get("nearest_mode", b"round_prefer_floor").decode("ascii")
         alpha = attr.get("cubic_coeff_a", -0.75)
         exclude = attr.get("exclude_outside", 0)
+        extrapolation_value = attr.get("extrapolation_value", 0.0)
+
+        if roi is not None:
+            roi = fold_constant(
+                _op.concatenate(
+                    [
+                        _op.strided_slice(roi, [2], [ndims]),
+                        _op.strided_slice(roi, [ndims + 2], [2 * ndims]),
+                    ],
+                    axis=0,
+                )
+            )
+
+        out_size = fold_constant(_op.strided_slice(size, [2], [ndims]))
 
-        out_size = fold_constant(_op.strided_slice(size, [2], [4]))
         out = None
         if ndims == 3:
-            out_size = fold_constant(_op.strided_slice(size, [2], [3]))
             out = _op.image.resize1d(
-                inputs[0], out_size, "NCW", method, coord_trans, nearest_mode, alpha, exclude
+                inputs[0],
+                out_size,
+                roi,
+                "NCW",
+                method,
+                coord_trans,
+                nearest_mode,
+                alpha,
+                exclude,
+                extrapolation_value,
             )
         elif ndims == 4:
-            out_size = fold_constant(_op.strided_slice(size, [2], [4]))
             out = _op.image.resize2d(
-                inputs[0], out_size, "NCHW", method, coord_trans, nearest_mode, alpha, exclude
+                inputs[0],
+                out_size,
+                roi,
+                "NCHW",
+                method,
+                coord_trans,
+                nearest_mode,
+                alpha,
+                exclude,
+                extrapolation_value,
             )
         elif ndims == 5:
-            out_size = fold_constant(_op.strided_slice(size, [2], [5]))
             out = _op.image.resize3d(
-                inputs[0], out_size, "NCDHW", method, coord_trans, nearest_mode, alpha, exclude
+                inputs[0],
+                out_size,
+                roi,
+                "NCDHW",
+                method,
+                coord_trans,
+                nearest_mode,
+                alpha,
+                exclude,
+                extrapolation_value,
             )
         else:
             raise NotImplementedError("Resize only supports 3, 4, or 5 dims")
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index a17a10e..2c9268e 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1822,7 +1822,7 @@ class PyTorchOpConverter:
 
             def func(x):
                 return _op.image.resize2d(
-                    x, out_size, "NCHW", method, coord_trans, cubic_alpha=-0.75
+                    x, out_size, None, "NCHW", method, coord_trans, cubic_alpha=-0.75
                 )
 
             if self.is_quantized_tensor(data):
@@ -1854,7 +1854,7 @@ class PyTorchOpConverter:
             else:
                 coord_trans = "half_pixel"
 
-            return _op.image.resize3d(data, out_size, "NCDHW", method, coord_trans)
+            return _op.image.resize3d(data, out_size, None, "NCDHW", method, coord_trans)
 
         return upsample3d
 
@@ -2186,7 +2186,9 @@ class PyTorchOpConverter:
         else:
             coord_trans = "half_pixel"
 
-        return _op.image.resize2d(data, out_size, "NCHW", method, coord_trans, cubic_alpha=-0.75)
+        return _op.image.resize2d(
+            data, out_size, None, "NCHW", method, coord_trans, cubic_alpha=-0.75
+        )
 
     def numel(self, inputs, input_types):
         return _op.ndarray_size(inputs[0])
diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py
index 26ea4f4..df8b743 100644
--- a/python/tvm/relay/frontend/tensorflow_ops.py
+++ b/python/tvm/relay/frontend/tensorflow_ops.py
@@ -1090,7 +1090,9 @@ def _resize(method):
 
         # Ignore the new attributes from TF2.0, for now.
         return AttrCvt(
-            op_name="resize2d", ignores=["Tdim", "half_pixel_centers"], extras={"method": method}
+            op_name="resize2d",
+            ignores=["Tdim", "half_pixel_centers"],
+            extras={"method": method, "roi": None},
         )(inputs, attr)
 
     return _impl
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 12beca5..f0f20e1 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -658,7 +658,7 @@ class OperatorConverter(object):
         if bilinear_method and input_tensor.qnn_params:
             in_expr = self.dequantize(in_expr, input_tensor)
         out = _op.image.resize2d(
-            in_expr, target_size, "NHWC", method, coordinate_transformation_mode=coord_trans
+            in_expr, target_size, None, "NHWC", method, coordinate_transformation_mode=coord_trans
         )
         if bilinear_method and output_tensor.qnn_params:
             out = self.quantize(out, output_tensor)
diff --git a/python/tvm/relay/op/dyn/image/_image.py b/python/tvm/relay/op/dyn/image/_image.py
index 5e97d24..faebde0 100644
--- a/python/tvm/relay/op/dyn/image/_image.py
+++ b/python/tvm/relay/op/dyn/image/_image.py
@@ -28,16 +28,21 @@ from ... import op as reg
 # resize
 @reg.register_compute("dyn.image.resize2d")
 def compute_resize2d(attrs, inputs, out_type):
+    """
+    Compute function calls into topi
+    """
     layout = attrs.layout
     method = attrs.method
     coord_trans = attrs.coordinate_transformation_mode
     rounding_method = attrs.rounding_method
     cubic_alpha = attrs.cubic_alpha
     cubic_exclude = attrs.cubic_exclude
+    extrapolation_value = attrs.extrapolation_value
     out_dtype = attrs.out_dtype
     return [
         tvm.topi.image.resize2d(
             inputs[0],
+            inputs[2],
             inputs[1],
             layout,
             method,
@@ -45,6 +50,7 @@ def compute_resize2d(attrs, inputs, out_type):
             rounding_method,
             cubic_alpha,
             cubic_exclude,
+            extrapolation_value,
             out_dtype,
             out_type.shape,
         )
diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py
index ec24ff7..d992986 100644
--- a/python/tvm/relay/op/image/_image.py
+++ b/python/tvm/relay/op/image/_image.py
@@ -34,16 +34,19 @@ from .image import resize2d
 def compute_resize1d(attrs, inputs, out_type):
     """compute definition for resize1d op"""
     size = attrs.size
+    roi = attrs.roi
     layout = attrs.layout
     method = attrs.method
     coord_trans = attrs.coordinate_transformation_mode
     rounding_method = attrs.rounding_method
     cubic_alpha = attrs.cubic_alpha
     cubic_exclude = attrs.cubic_exclude
+    extrapolation_value = attrs.extrapolation_value
     out_dtype = attrs.out_dtype
     return [
         topi.image.resize1d(
             inputs[0],
+            roi,
             size,
             layout,
             method,
@@ -51,6 +54,7 @@ def compute_resize1d(attrs, inputs, out_type):
             rounding_method,
             cubic_alpha,
             cubic_exclude,
+            extrapolation_value,
             out_dtype,
         )
     ]
@@ -128,16 +132,19 @@ def resize1d_shape_func(attrs, inputs, _):
 def compute_resize2d(attrs, inputs, out_type):
     """compute definition for resize2d op"""
     size = attrs.size
+    roi = attrs.roi
     layout = attrs.layout
     method = attrs.method
     coord_trans = attrs.coordinate_transformation_mode
     rounding_method = attrs.rounding_method
     cubic_alpha = attrs.cubic_alpha
     cubic_exclude = attrs.cubic_exclude
+    extrapolation_value = attrs.extrapolation_value
     out_dtype = attrs.out_dtype
     return [
         topi.image.resize2d(
             inputs[0],
+            roi,
             size,
             layout,
             method,
@@ -145,6 +152,7 @@ def compute_resize2d(attrs, inputs, out_type):
             rounding_method,
             cubic_alpha,
             cubic_exclude,
+            extrapolation_value,
             out_dtype,
         )
     ]
@@ -225,16 +233,19 @@ def resize2d_shape_func(attrs, inputs, _):
 def compute_resize3d(attrs, inputs, out_type):
     """compute definition for resize3d op"""
     size = attrs.size
+    roi = attrs.roi
     layout = attrs.layout
     method = attrs.method
     coord_trans = attrs.coordinate_transformation_mode
     rounding_method = attrs.rounding_method
     cubic_alpha = attrs.cubic_alpha
     cubic_exclude = attrs.cubic_exclude
+    extrapolation_value = attrs.extrapolation_value
     out_dtype = attrs.out_dtype
     return [
         topi.image.resize3d(
             inputs[0],
+            roi,
             size,
             layout,
             method,
@@ -242,6 +253,7 @@ def compute_resize3d(attrs, inputs, out_type):
             rounding_method,
             cubic_alpha,
             cubic_exclude,
+            extrapolation_value,
             out_dtype,
         )
     ]
diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py
index 7f5bd80..30f8a9c 100644
--- a/python/tvm/relay/op/image/image.py
+++ b/python/tvm/relay/op/image/image.py
@@ -17,18 +17,20 @@
 """Image operations."""
 from . import _make
 from ..dyn.image import _make as _dyn_make
-from ...expr import Expr, Constant
+from ...expr import Expr, Constant, const
 
 
 def resize1d(
     data,
     size,
+    roi=None,
     layout="NCW",
     method="linear",
     coordinate_transformation_mode="half_pixel",
     rounding_method="",
     cubic_alpha=-0.5,
     cubic_exclude=0,
+    extrapolation_value=0.0,
     out_dtype=None,
 ):
     """Image resize1d operator.
@@ -49,6 +51,11 @@ def resize1d(
     size: Tuple of Int or Expr
         The out size to which the image will be resized.
 
+    roi: Tuple of Float or Expr, optional
+        The region of interest for cropping the input image. Expected to be of
+        size 2, and format [start_w, end_w].
+        Only used if coordinate_transformation_mode is tf_crop_and_resize.
+
     layout : str, optional
         Layout of the input.
 
@@ -57,9 +64,10 @@ def resize1d(
 
     coordinate_transformation_mode : string, optional
         Describes how to transform the coordinate in the resized tensor
-        to the coordinate in the original tensor.
-        Refer to the ONNX Resize operator specification for details.
-        [half_pixel, align_corners, asymmetric]
+        to the coordinate in the original tensor. Definitions can be found
+        in topi/image/resize.py.
+        [half_pixel, align_corners, asymmetric, pytorch_half_pixel,
+        tf_half_pixel_for_nn, and tf_crop_and_resize].
 
     rounding_method: string, optional
         indicates how to find the "nearest" pixel in nearest_neighbor method
@@ -69,7 +77,10 @@ def resize1d(
         Spline Coefficient for cubic interpolation
 
     cubic_exclude: int
-            Flag to exclude exterior of the image during cubic interpolation
+        Flag to exclude exterior of the image during cubic interpolation
+
+    extrapolation_value: float
+        Fill value to use when roi is outside of the image
 
     out_dtype : str, optional
         Type to return. If left None returns the same type as input.
@@ -79,19 +90,27 @@ def resize1d(
     result: relay.Expr
         The resized result.
     """
+    if roi is None:
+        roi = [0.0] * 2
     if isinstance(size, Constant):
         size = list(size.data.numpy().astype("int32"))
-    if isinstance(size, Expr):
-        raise NotImplementedError("dyn.resize1d is not yet implemented, got size", size)
+    if isinstance(roi, Constant):
+        roi = list(roi.data.numpy().astype("int32"))
+    if isinstance(size, Expr) or isinstance(roi, Expr):
+        raise NotImplementedError(
+            "dyn.resize1d is not yet implemented, got size", size, "and roi", roi
+        )
     return _make.resize1d(
         data,
         size,
+        roi,
         layout,
         method,
         coordinate_transformation_mode,
         rounding_method,
         cubic_alpha,
         cubic_exclude,
+        extrapolation_value,
         out_dtype,
     )
 
@@ -99,12 +118,14 @@ def resize1d(
 def resize2d(
     data,
     size,
+    roi=None,
     layout="NCHW",
     method="linear",
     coordinate_transformation_mode="half_pixel",
     rounding_method="",
     cubic_alpha=-0.5,
     cubic_exclude=0,
+    extrapolation_value=0.0,
     out_dtype=None,
 ):
     """Image resize2d operator.
@@ -125,6 +146,11 @@ def resize2d(
     size: Tuple of Int or Expr
         The out size to which the image will be resized.
 
+    roi: Tuple of Float or Expr, optional
+        The region of interest for cropping the input image. Expected to be of
+        size 4, and format [start_h, start_w, end_h, end_w].
+        Only used if coordinate_transformation_mode is tf_crop_and_resize.
+
     layout : str, optional
         Layout of the input.
 
@@ -133,9 +159,10 @@ def resize2d(
 
     coordinate_transformation_mode : string, optional
         Describes how to transform the coordinate in the resized tensor
-        to the coordinate in the original tensor.
-        Refer to the ONNX Resize operator specification for details.
-        [half_pixel, align_corners, asymmetric]
+        to the coordinate in the original tensor. Definitions can be found
+        in topi/image/resize.py.
+        [half_pixel, align_corners, asymmetric, pytorch_half_pixel,
+        tf_half_pixel_for_nn, and tf_crop_and_resize].
 
     rounding_method: string, optional
         indicates how to find the "nearest" pixel in nearest_neighbor method
@@ -145,7 +172,10 @@ def resize2d(
         Spline Coefficient for bicubic interpolation
 
     cubic_exclude: int
-            Flag to exclude exterior of the image during bicubic interpolation
+        Flag to exclude exterior of the image during bicubic interpolation
+
+    extrapolation_value: float
+        Fill value to use when roi is outside of the image
 
     out_dtype : str, optional
         Type to return. If left None returns the same type as input.
@@ -155,29 +185,41 @@ def resize2d(
     result: relay.Expr
         The resized result.
     """
+    if roi is None:
+        roi = [0.0] * 4
     if isinstance(size, Constant):
         size = list(size.data.numpy().astype("int32"))
-    if isinstance(size, Expr):
+    if isinstance(roi, Constant):
+        roi = list(roi.data.numpy().astype("float32"))
+    if isinstance(size, Expr) or isinstance(roi, Expr):
+        if not isinstance(size, Expr):
+            size = const(size, "int64")
+        if not isinstance(roi, Expr):
+            roi = const(roi, "float32")
         return _dyn_make.resize2d(
             data,
             size,
+            roi,
             layout,
             method,
             coordinate_transformation_mode,
             rounding_method,
             cubic_alpha,
             cubic_exclude,
+            extrapolation_value,
             out_dtype,
         )
     return _make.resize2d(
         data,
         size,
+        roi,
         layout,
         method,
         coordinate_transformation_mode,
         rounding_method,
         cubic_alpha,
         cubic_exclude,
+        extrapolation_value,
         out_dtype,
     )
 
@@ -185,12 +227,14 @@ def resize2d(
 def resize3d(
     data,
     size,
+    roi=None,
     layout="NCDHW",
     method="linear",
     coordinate_transformation_mode="half_pixel",
     rounding_method="",
     cubic_alpha=-0.5,
     cubic_exclude=0,
+    extrapolation_value=0.0,
     out_dtype=None,
 ):
     """Image resize3d operator.
@@ -211,6 +255,11 @@ def resize3d(
     size: Tuple of Int or Expr
         The out size to which the image will be resized.
 
+    roi: Tuple of Float or Expr, optional
+        The region of interest for cropping the input image. Expected to be of
+        size 6, and format [start_d, start_h, start_w, end_d, end_h, end_w].
+        Only used if coordinate_transformation_mode is tf_crop_and_resize.
+
     layout : str, optional
         Layout of the input.
 
@@ -219,9 +268,10 @@ def resize3d(
 
     coordinate_transformation_mode : string, optional
         Describes how to transform the coordinate in the resized tensor
-        to the coordinate in the original tensor.
-        Refer to the ONNX Resize operator specification for details.
-        [half_pixel, align_corners, asymmetric]
+        to the coordinate in the original tensor. Definitions can be found
+        in topi/image/resize.py.
+        [half_pixel, align_corners, asymmetric, pytorch_half_pixel,
+        tf_half_pixel_for_nn, and tf_crop_and_resize].
 
     rounding_method: string, optional
         indicates how to find the "nearest" pixel in nearest_neighbor method
@@ -231,7 +281,10 @@ def resize3d(
         Spline Coefficient for cubic interpolation
 
     cubic_exclude: int
-            Flag to exclude exterior of the image during cubic interpolation
+        Flag to exclude exterior of the image during cubic interpolation
+
+    extrapolation_value: float
+        Fill value to use when roi is outside of the image
 
     out_dtype : str, optional
         Type to return. If left None returns the same type as input.
@@ -241,19 +294,27 @@ def resize3d(
     result: relay.Expr
         The resized result.
     """
+    if roi is None:
+        roi = [0.0] * 6
     if isinstance(size, Constant):
         size = list(size.data.numpy().astype("int32"))
-    if isinstance(size, Expr):
-        raise NotImplementedError("dyn.resize3d is not yet implemented, got size", size)
+    if isinstance(roi, Constant):
+        roi = list(roi.data.numpy().astype("int32"))
+    if isinstance(size, Expr) or isinstance(roi, Expr):
+        raise NotImplementedError(
+            "dyn.resize3d is not yet implemented, got size", size, "and roi", roi
+        )
     return _make.resize3d(
         data,
         size,
+        roi,
         layout,
         method,
         coordinate_transformation_mode,
         rounding_method,
         cubic_alpha,
         cubic_exclude,
+        extrapolation_value,
         out_dtype,
     )
 
diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py
index d1abffd..4fc6845 100644
--- a/python/tvm/topi/image/resize.py
+++ b/python/tvm/topi/image/resize.py
@@ -119,7 +119,7 @@ def get_3d_pixel(data, layout, image_depth, image_height, image_width, n, c, z,
     return data(n, c, z, y, x, cc).astype("float")
 
 
-def get_inx(x, image_width, target_width, coordinate_transformation_mode):
+def get_inx(x, image_width, target_width, coordinate_transformation_mode, start_x=0, end_x=-1):
     """Infer input x from output x with various coordinate transformation methods"""
     scale_x = te.div(image_width.astype("float"), target_width.astype("float"))
     if coordinate_transformation_mode == "half_pixel":
@@ -132,6 +132,13 @@ def get_inx(x, image_width, target_width, coordinate_transformation_mode):
         in_x = te.if_then_else(target_width > 1, (x + 0.5) * scale_x - 0.5, 0.0)
     elif coordinate_transformation_mode == "tf_half_pixel_for_nn":
         in_x = (x + 0.5) * scale_x
+    elif coordinate_transformation_mode == "tf_crop_and_resize":
+        in_x = te.if_then_else(
+            target_width > 1,
+            start_x * (image_width - 1)
+            + x * (end_x - start_x) * (image_width - 1).astype("float") / (target_width - 1),
+            0.5 * (start_x + end_x) * (image_width - 1),
+        )
     else:
         raise ValueError(
             "Unsupported coordinate_transformation_mode: {}".format(coordinate_transformation_mode)
@@ -184,12 +191,13 @@ def _cubic_kernel(inputs, w):
 def _resize_1d(
     indices,
     data,
+    roi,
     image_width,
     target_width,
     boxes=None,
     box_indices=None,
     method=None,
-    extrapolation_value=None,
+    extrapolation_value=0.0,
     layout="NCW",
     coordinate_transformation_mode="align_corners",
     rounding_method="",
@@ -210,6 +218,11 @@ def _resize_1d(
         [batch, channel, in_width]
         or  [batch, in_width, channel]
 
+    roi: Tuple of Float or Expr
+        The region of interest for cropping the input image. Expected to be of
+        size 2, and format [start_w, end_w].
+        Only used if coordinate_transformation_mode is tf_crop_and_resize.
+
     image_width : integer
         Input image width
 
@@ -230,11 +243,14 @@ def _resize_1d(
     layout: string, optional
         "NCW", "NWC", or "NCWc".
 
-    coordinate_transformation_mode: string, optional
+    method: string, optional
+        method of interpolation ("nearest_neighbor", "linear", "cubic")
+
+    coordinate_transformation_mode : string, optional
         Describes how to transform the coordinate in the resized tensor
         to the coordinate in the original tensor.
-        Refer to the ONNX Resize operator specification for details.
-        Available options are "half_pixel", "align_corners" and "asymmetric".
+        [half_pixel, align_corners, asymmetric, pytorch_half_pixel,
+        tf_half_pixel_for_nn, and tf_crop_and_resize].
 
     rounding_method: string, optional
         indicates how to find the "nearest" pixel in nearest_neighbor method
@@ -243,7 +259,7 @@ def _resize_1d(
     alpha: float, optional
         Bicubic spline coefficient
 
-    exclude_oiutside: bool, optional:
+    exclude_outside: bool, optional:
         Exclude values outside the image fdor bicubic interpolation
 
     out_dtype: string, optional
@@ -272,6 +288,8 @@ def _resize_1d(
         image_width,
         target_width,
         coordinate_transformation_mode,
+        roi[0],
+        roi[1],
     )
 
     if method == "nearest_neighbor":
@@ -347,7 +365,7 @@ def _resize_1d(
     else:
         raise ValueError("Unknown resize method:", method)
 
-    if extrapolation_value is not None:
+    if coordinate_transformation_mode == "tf_crop_and_resize":
         # use extrapolation_value if in_x is out of boundary
         value = tvm.tir.if_then_else(
             in_x < 0,
@@ -359,6 +377,7 @@ def _resize_1d(
 
 def resize1d(
     data,
+    roi,
     size,
     layout="NCW",
     method="linear",
@@ -366,6 +385,7 @@ def resize1d(
     rounding_method="",
     bicubic_alpha=-0.5,
     bicubic_exclude=0,
+    extrapolation_value=0.0,
     out_dtype=None,
     output_shape=None,
 ):
@@ -378,6 +398,11 @@ def resize1d(
         [batch, channel in_width]
         or  [batch in_width, channel]
 
+    roi: Tuple of Float or Expr
+        The region of interest for cropping the input image. Expected to be of
+        size 2, and format [start_w, end_w].
+        Only used if coordinate_transformation_mode is tf_crop_and_resize.
+
     size: Tuple
         Output resolution scale to
 
@@ -390,8 +415,26 @@ def resize1d(
         Refer to the ONNX Resize operator specification for details.
         Available options are "half_pixel", "align_corners" and "asymmetric".
 
-    method: {"linear", "nearest_neighbor", "cubic"}
-        Method to be used for resizing.
+    method: string, optional
+        method of interpolation ("nearest_neighbor", "linear", "cubic")
+
+    coordinate_transformation_mode : string, optional
+        Describes how to transform the coordinate in the resized tensor
+        to the coordinate in the original tensor.
+        [half_pixel, align_corners, asymmetric, pytorch_half_pixel,
+        tf_half_pixel_for_nn, and tf_crop_and_resize].
+
+    rounding_method: string, optional
+        Method for rounding coordinate locations
+
+    bicubic_alpha: float, optional
+        Bicubic spline coefficient
+
+    bicubic_exclude: bool, optional:
+        Exclude values outside the image for bicubic interpolation
+
+    extrapolation_value: float, optional
+        Value used for extrapolation, when applicable.
 
     out_dtype: string, optional
         Type to return. If left None will be same as input type.
@@ -438,6 +481,7 @@ def resize1d(
         return _resize_1d(
             indices,
             data,
+            roi,
             in_w,
             size[0],
             method=method,
@@ -446,6 +490,7 @@ def resize1d(
             rounding_method=rounding_method,
             alpha=bicubic_alpha,
             exclude_outside=bicubic_exclude,
+            extrapolation_value=extrapolation_value,
             out_dtype=out_dtype,
         )
 
@@ -455,6 +500,7 @@ def resize1d(
 def _resize_2d(
     indices,
     data,
+    roi,
     image_height,
     image_width,
     target_height,
@@ -462,7 +508,7 @@ def _resize_2d(
     boxes=None,
     box_indices=None,
     method=None,
-    extrapolation_value=None,
+    extrapolation_value=0.0,
     layout="NCHW",
     coordinate_transformation_mode="align_corners",
     rounding_method="",
@@ -483,6 +529,11 @@ def _resize_2d(
         [batch, channel, in_height, in_width]
         or  [batch, in_height, in_width, channel]
 
+    roi: Tuple of Float or Expr
+        The region of interest for cropping the input image. Expected to be of
+        size 4, and format [start_h, start_w, end_h, end_w].
+        Only used if coordinate_transformation_mode is tf_crop_and_resize.
+
     image_height : integer
         Input image height
 
@@ -499,6 +550,9 @@ def _resize_2d(
         A 2-D tensor of shape [num_boxes, 4]. Each row of the tensor specifies
         the coordinates of a box.
 
+    method: string, optional
+        method of interpolation ("nearest_neighbor", "linear", "cubic")
+
     box_indices : tvm.te.Tensor, optional
         A 1-D tensor of shape [num_boxes], box_indices[i] specifies the data that
         the i-th box refers to.
@@ -509,11 +563,11 @@ def _resize_2d(
     layout: string, optional
         "NCHW", "NHWC", or "NCHWc".
 
-    coordinate_transformation_mode: string, optional
+    coordinate_transformation_mode : string, optional
         Describes how to transform the coordinate in the resized tensor
         to the coordinate in the original tensor.
-        Refer to the ONNX Resize operator specification for details.
-        Available options are "half_pixel", "align_corners" and "asymmetric".
+        [half_pixel, align_corners, asymmetric, pytorch_half_pixel,
+        tf_half_pixel_for_nn, and tf_crop_and_resize].
 
     rounding_method: string, optional
         indicates how to find the "nearest" pixel in nearest_neighbor method
@@ -522,7 +576,7 @@ def _resize_2d(
     alpha: float, optional
         Bicubic spline coefficient
 
-    exclude_oiutside: bool, optional:
+    exclude_outside: bool, optional:
         Exclude values outside the image fdor bicubic interpolation
 
     out_dtype: string, optional
@@ -555,8 +609,10 @@ def _resize_2d(
         in_y = y1 * (image_height - 1) + h_scale * y
         in_x = x1 * (image_width - 1) + w_scale * x
     else:
-        in_x = get_inx(x, image_width, target_width, coordinate_transformation_mode)
-        in_y = get_inx(y, image_height, target_height, coordinate_transformation_mode)
+        in_x = get_inx(x, image_width, target_width, coordinate_transformation_mode, roi[1], roi[3])
+        in_y = get_inx(
+            y, image_height, target_height, coordinate_transformation_mode, roi[0], roi[2]
+        )
 
     if method == "nearest_neighbor":
         if rounding_method == "":
@@ -657,7 +713,7 @@ def _resize_2d(
     else:
         raise ValueError("Unknown resize method:", method)
 
-    if extrapolation_value is not None:
+    if coordinate_transformation_mode == "tf_crop_and_resize":
         out = tvm.tir.if_then_else(
             in_y < 0,
             extrapolation_value,
@@ -674,6 +730,7 @@ def _resize_2d(
 
 def resize2d(
     data,
+    roi,
     size,
     layout="NCHW",
     method="linear",
@@ -681,6 +738,7 @@ def resize2d(
     rounding_method="",
     bicubic_alpha=-0.5,
     bicubic_exclude=0,
+    extrapolation_value=0.0,
     out_dtype=None,
     output_shape=None,
 ):
@@ -693,6 +751,11 @@ def resize2d(
         [batch, channel, in_height, in_width]
         or  [batch, in_height, in_width, channel]
 
+    roi: Tuple of Float or Expr
+        The region of interest for cropping the input image. Expected to be of
+        size 4, and format [start_h, start_w, end_h, end_w].
+        Only used if coordinate_transformation_mode is tf_crop_and_resize.
+
     size: Tuple
         Output resolution scale to
 
@@ -705,8 +768,26 @@ def resize2d(
         Refer to the ONNX Resize operator specification for details.
         Available options are "half_pixel", "align_corners" and "asymmetric".
 
-    method: {"linear", "nearest_neighbor", "cubic"}
-        Method to be used for resizing.
+    method: string, optional
+        method of interpolation ("nearest_neighbor", "linear", "cubic")
+
+    coordinate_transformation_mode : string, optional
+        Describes how to transform the coordinate in the resized tensor
+        to the coordinate in the original tensor.
+        [half_pixel, align_corners, asymmetric, pytorch_half_pixel,
+        tf_half_pixel_for_nn, and tf_crop_and_resize].
+
+    rounding_method: string, optional
+        Method for rounding coordinate locations
+
+    bicubic_alpha: float, optional
+        Bicubic spline coefficient
+
+    bicubic_exclude: bool, optional:
+        Exclude values outside the image for bicubic interpolation
+
+    extrapolation_value: float, optional
+        Value used for extrapolation, when applicable.
 
     out_dtype: string, optional
         Type to return. If left None will be same as input type.
@@ -753,6 +834,7 @@ def resize2d(
         return _resize_2d(
             indices,
             data,
+            roi,
             in_h,
             in_w,
             size[0],
@@ -763,6 +845,7 @@ def resize2d(
             rounding_method=rounding_method,
             alpha=bicubic_alpha,
             exclude_outside=bicubic_exclude,
+            extrapolation_value=extrapolation_value,
             out_dtype=out_dtype,
         )
 
@@ -776,7 +859,7 @@ def crop_and_resize(
     crop_size,
     layout="NCHW",
     method="bilinear",
-    extrapolation_value=0,
+    extrapolation_value=None,
     out_dtype=None,
 ):
     """Perform crop and resize operation on the data.
@@ -847,6 +930,7 @@ def crop_and_resize(
         return _resize_2d(
             indices,
             data,
+            [0.0] * 4,
             image_h,
             image_w,
             target_h,
@@ -856,6 +940,7 @@ def crop_and_resize(
             method=method,
             extrapolation_value=extrapolation_value,
             layout=layout,
+            coordinate_transformation_mode="tf_crop_and_resize",
             out_dtype=out_dtype,
         )
 
@@ -865,6 +950,7 @@ def crop_and_resize(
 def _resize_3d(
     indices,
     data,
+    roi,
     image_depth,
     image_height,
     image_width,
@@ -874,7 +960,7 @@ def _resize_3d(
     boxes=None,
     box_indices=None,
     method=None,
-    extrapolation_value=None,
+    extrapolation_value=0.0,
     layout="NCHW",
     coordinate_transformation_mode="align_corners",
     rounding_method="",
@@ -895,6 +981,11 @@ def _resize_3d(
         [batch, channel, in_height, in_width]
         or  [batch, in_height, in_width, channel]
 
+    roi: Tuple of Float or Expr
+        The region of interest for cropping the input image. Expected to be of
+        size 6, and format [start_d, start_h, start_w, end_d, end_h, end_w].
+        Only used if coordinate_transformation_mode is tf_crop_and_resize.
+
     image_depth : integer
         Input image depth
 
@@ -921,17 +1012,20 @@ def _resize_3d(
         A 1-D tensor of shape [num_boxes], box_indices[i] specifies the data that
         the i-th box refers to.
 
+    method: string, optional
+        method of interpolation ("nearest_neighbor", "linear", "cubic")
+
     extrapolation_value: float, optional
         Value used for extrapolation, when applicable.
 
     layout: string, optional
         "NCHW", "NHWC", or "NCHWc".
 
-    coordinate_transformation_mode: string, optional
+    coordinate_transformation_mode : string, optional
         Describes how to transform the coordinate in the resized tensor
         to the coordinate in the original tensor.
-        Refer to the ONNX Resize operator specification for details.
-        Available options are "half_pixel", "align_corners" and "asymmetric".
+        [half_pixel, align_corners, asymmetric, pytorch_half_pixel,
+        tf_half_pixel_for_nn, and tf_crop_and_resize].
 
     rounding_method: string, optional
         indicates how to find the "nearest" pixel in nearest_neighbor method
@@ -964,9 +1058,9 @@ def _resize_3d(
     if boxes is not None:
         # TODO(mbrookhart): Find an example of this
         raise NotImplementedError("resize1d with image boxes not yet implemented")
-    in_z = get_inx(z, image_depth, target_depth, coordinate_transformation_mode)
-    in_y = get_inx(y, image_height, target_height, coordinate_transformation_mode)
-    in_x = get_inx(x, image_width, target_width, coordinate_transformation_mode)
+    in_z = get_inx(z, image_depth, target_depth, coordinate_transformation_mode, roi[2], roi[5])
+    in_y = get_inx(y, image_height, target_height, coordinate_transformation_mode, roi[1], roi[4])
+    in_x = get_inx(x, image_width, target_width, coordinate_transformation_mode, roi[0], roi[3])
 
     if method == "nearest_neighbor":
         if rounding_method == "":
@@ -1090,7 +1184,7 @@ def _resize_3d(
     else:
         raise ValueError("Unknown resize method:", method)
 
-    if extrapolation_value is not None:
+    if coordinate_transformation_mode == "tf_crop_and_resize":
         out = tvm.tir.if_then_else(
             in_z < 0,
             extrapolation_value,
@@ -1112,6 +1206,7 @@ def _resize_3d(
 
 def resize3d(
     data,
+    roi,
     size,
     layout="NCDHW",
     method="linear",
@@ -1119,6 +1214,7 @@ def resize3d(
     rounding_method="",
     bicubic_alpha=-0.5,
     bicubic_exclude=0,
+    extrapolation_value=0.0,
     out_dtype=None,
     output_shape=None,
 ):
@@ -1131,20 +1227,37 @@ def resize3d(
         [batch, channel, in_depth, in_height, in_width]
         or  [batch, in_depth, in_height, in_width, channel]
 
+    roi: Tuple of Float or Expr
+        The region of interest for cropping the input image. Expected to be of
+        size 6, and format [start_d, start_h, start_w, end_d, end_h, end_w].
+        Only used if coordinate_transformation_mode is tf_crop_and_resize.
+
     size: Tuple
         Output resolution scale to
 
     layout: string, optional
         "NCDHW", "NDHWC", or "NCDHWc".
 
-    coordinate_transformation_mode: string, optional
+    method: string, optional
+        method of interpolation ("nearest_neighbor", "linear", "cubic")
+
+    coordinate_transformation_mode : string, optional
         Describes how to transform the coordinate in the resized tensor
         to the coordinate in the original tensor.
-        Refer to the ONNX Resize operator specification for details.
-        Available options are "half_pixel", "align_corners" and "asymmetric".
+        [half_pixel, align_corners, asymmetric, pytorch_half_pixel,
+        tf_half_pixel_for_nn, and tf_crop_and_resize].
 
-    method: {"linear", "nearest_neighbor", "cubic"}
-        Method to be used for resizing.
+    rounding_method: string, optional
+        Method for rounding coordinate locations
+
+    bicubic_alpha: float, optional
+        Bicubic spline coefficient
+
+    bicubic_exclude: bool, optional:
+        Exclude values outside the image for bicubic interpolation
+
+    extrapolation_value: float, optional
+        Value used for extrapolation, when applicable.
 
     out_dtype: string, optional
         Type to return. If left None will be same as input type.
@@ -1185,6 +1298,7 @@ def resize3d(
         return _resize_3d(
             indices,
             data,
+            roi,
             in_d,
             in_h,
             in_w,
@@ -1197,6 +1311,7 @@ def resize3d(
             rounding_method=rounding_method,
             alpha=bicubic_alpha,
             exclude_outside=bicubic_exclude,
+            extrapolation_value=extrapolation_value,
             out_dtype=out_dtype,
         )
 
diff --git a/python/tvm/topi/nn/upsampling.py b/python/tvm/topi/nn/upsampling.py
index 36b9349..e9c810c 100644
--- a/python/tvm/topi/nn/upsampling.py
+++ b/python/tvm/topi/nn/upsampling.py
@@ -96,6 +96,7 @@ def upsampling(
         method = method[2:]
     return topi.image.resize2d(
         data,
+        [0.0] * 4,
         reshape_size,
         layout=layout,
         method=method,
@@ -194,6 +195,7 @@ def upsampling3d(
         method = method[3:]
     return topi.image.resize3d(
         data,
+        [0.0] * 6,
         resize_shape,
         layout=layout,
         method=method,
diff --git a/src/relay/op/dyn/image/resize.cc b/src/relay/op/dyn/image/resize.cc
index 002105f..1f5f6b4 100644
--- a/src/relay/op/dyn/image/resize.cc
+++ b/src/relay/op/dyn/image/resize.cc
@@ -35,8 +35,8 @@ TVM_REGISTER_NODE_TYPE(Resize2DAttrs);
 
 bool Resize2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                  const TypeReporter& reporter) {
-  // {data, size, out}
-  ICHECK_EQ(types.size(), 3);
+  // {data, size, roi, out}
+  ICHECK_EQ(types.size(), 4);
   const auto* data = types[0].as<TensorTypeNode>();
   if (data == nullptr) return false;
 
@@ -60,15 +60,15 @@ bool Resize2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   }
 
   // assign output type
-  reporter->Assign(types[2], TensorType(layout_converter.BackwardShape(oshape), out_dtype));
+  reporter->Assign(types[3], TensorType(layout_converter.BackwardShape(oshape), out_dtype));
   return true;
 }
 
 // Positional relay function to create image operator
 // used by frontend FFI.
-Expr MakeResize2D(Expr data, Expr size, String layout, String method,
+Expr MakeResize2D(Expr data, Expr size, Expr roi, String layout, String method,
                   String coordinate_transformation_mode, String rounding_method, double cubic_alpha,
-                  double cubic_exclude, DataType out_dtype) {
+                  double cubic_exclude, double extrapolation_value, DataType out_dtype) {
   auto attrs = make_object<Resize2DAttrs>();
   attrs->layout = std::move(layout);
   attrs->method = std::move(method);
@@ -76,9 +76,10 @@ Expr MakeResize2D(Expr data, Expr size, String layout, String method,
   attrs->rounding_method = rounding_method;
   attrs->cubic_alpha = cubic_alpha;
   attrs->cubic_exclude = cubic_exclude;
+  attrs->extrapolation_value = extrapolation_value;
   attrs->out_dtype = out_dtype;
   static const Op& op = Op::Get("dyn.image.resize2d");
-  return Call(op, {data, size}, Attrs(attrs), {});
+  return Call(op, {data, size, roi}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relay.op.dyn.image._make.resize2d").set_body_typed(MakeResize2D);
@@ -101,9 +102,10 @@ RELAY_REGISTER_OP("dyn.image.resize2d")
            (batch_size, size[0], size[1], channels)
 )code" TVM_ADD_FILELINE)
     .set_attrs_type<Resize2DAttrs>()
-    .set_num_inputs(2)
+    .set_num_inputs(3)
     .add_argument("data", "Tensor", "The input tensor.")
     .add_argument("size", "Tensor", "The output size tensor.")
+    .add_argument("roi", "Tensor", "The region of interest for tf_crop_and_resize.")
     .set_support_level(5)
     .add_type_rel("DynResize2D", Resize2DRel)
     .set_attr<TOpPattern>("TOpPattern", kInjective);
diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc
index ee77984..ca05a4b 100644
--- a/src/relay/op/image/resize.cc
+++ b/src/relay/op/image/resize.cc
@@ -68,6 +68,8 @@ bool Resize1DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 
   const Resize1DAttrs* param = attrs.as<Resize1DAttrs>();
   ICHECK(param != nullptr);
+  ICHECK(param->size.size() == 1);
+  ICHECK(param->roi.size() == 2);
   const Layout in_layout(param->layout);
   auto layout_converter = tir::BijectiveLayout(in_layout, kNCW);
   ICHECK(layout_converter.defined())
@@ -89,17 +91,20 @@ bool Resize1DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 
 // Positional relay function to create image operator
 // used by frontend FFI.
-Expr MakeResize1D(Expr data, Array<IndexExpr> size, String layout, String method,
-                  String coordinate_transformation_mode, String rounding_method, double cubic_alpha,
-                  int cubic_exclude, DataType out_dtype) {
+Expr MakeResize1D(Expr data, Array<IndexExpr> size, Array<FloatImm> roi, String layout,
+                  String method, String coordinate_transformation_mode, String rounding_method,
+                  double cubic_alpha, int cubic_exclude, double extrapolation_value,
+                  DataType out_dtype) {
   auto attrs = make_object<Resize1DAttrs>();
   attrs->size = std::move(size);
+  attrs->roi = std::move(roi);
   attrs->layout = std::move(layout);
   attrs->method = std::move(method);
   attrs->coordinate_transformation_mode = coordinate_transformation_mode;
   attrs->rounding_method = rounding_method;
   attrs->cubic_alpha = cubic_alpha;
   attrs->cubic_exclude = cubic_exclude;
+  attrs->extrapolation_value = extrapolation_value;
   attrs->out_dtype = out_dtype;
   static const Op& op = Op::Get("image.resize1d");
   return Call(op, {data}, Attrs(attrs), {});
@@ -141,6 +146,8 @@ bool Resize2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 
   const Resize2DAttrs* param = attrs.as<Resize2DAttrs>();
   ICHECK(param != nullptr);
+  ICHECK(param->size.size() == 2);
+  ICHECK(param->roi.size() == 4);
   const Layout in_layout(param->layout);
   auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
   ICHECK(layout_converter.defined())
@@ -163,17 +170,20 @@ bool Resize2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 
 // Positional relay function to create image operator
 // used by frontend FFI.
-Expr MakeResize2D(Expr data, Array<IndexExpr> size, String layout, String method,
-                  String coordinate_transformation_mode, String rounding_method, double cubic_alpha,
-                  int cubic_exclude, DataType out_dtype) {
+Expr MakeResize2D(Expr data, Array<IndexExpr> size, Array<FloatImm> roi, String layout,
+                  String method, String coordinate_transformation_mode, String rounding_method,
+                  double cubic_alpha, int cubic_exclude, double extrapolation_value,
+                  DataType out_dtype) {
   auto attrs = make_object<Resize2DAttrs>();
   attrs->size = std::move(size);
+  attrs->roi = std::move(roi);
   attrs->layout = std::move(layout);
   attrs->method = std::move(method);
   attrs->coordinate_transformation_mode = coordinate_transformation_mode;
   attrs->rounding_method = rounding_method;
   attrs->cubic_alpha = cubic_alpha;
   attrs->cubic_exclude = cubic_exclude;
+  attrs->extrapolation_value = extrapolation_value;
   attrs->out_dtype = out_dtype;
   static const Op& op = Op::Get("image.resize2d");
   return Call(op, {data}, Attrs(attrs), {});
@@ -215,6 +225,8 @@ bool Resize3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 
   const Resize3DAttrs* param = attrs.as<Resize3DAttrs>();
   ICHECK(param != nullptr);
+  ICHECK(param->size.size() == 3);
+  ICHECK(param->roi.size() == 6);
   const Layout in_layout(param->layout);
   auto layout_converter = tir::BijectiveLayout(in_layout, kNCDHW);
   ICHECK(layout_converter.defined())
@@ -238,17 +250,20 @@ bool Resize3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 
 // Positional relay function to create image operator
 // used by frontend FFI.
-Expr MakeResize3D(Expr data, Array<IndexExpr> size, String layout, String method,
-                  String coordinate_transformation_mode, String rounding_method, double cubic_alpha,
-                  int cubic_exclude, DataType out_dtype) {
+Expr MakeResize3D(Expr data, Array<IndexExpr> size, Array<FloatImm> roi, String layout,
+                  String method, String coordinate_transformation_mode, String rounding_method,
+                  double cubic_alpha, int cubic_exclude, double extrapolation_value,
+                  DataType out_dtype) {
   auto attrs = make_object<Resize3DAttrs>();
   attrs->size = std::move(size);
+  attrs->roi = std::move(roi);
   attrs->layout = std::move(layout);
   attrs->method = std::move(method);
   attrs->coordinate_transformation_mode = coordinate_transformation_mode;
   attrs->rounding_method = rounding_method;
   attrs->cubic_alpha = cubic_alpha;
   attrs->cubic_exclude = cubic_exclude;
+  attrs->extrapolation_value = extrapolation_value;
   attrs->out_dtype = out_dtype;
   static const Op& op = Op::Get("image.resize3d");
   return Call(op, {data}, Attrs(attrs), {});
diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index 43ce665..d02aed7 100644
--- a/src/relay/op/make_op.h
+++ b/src/relay/op/make_op.h
@@ -101,9 +101,10 @@ Expr MakeZeros(Array<Integer> shape, DataType dtype);
 
 Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, int depth, int axis, DataType dtype);
 
-Expr MakeResize2D(Expr data, Array<IndexExpr> size, String layout, String method,
-                  String coordinate_transformation_mode, String rounding_method, double cubic_alpha,
-                  int cubic_exclude, DataType out_dtype);
+Expr MakeResize2D(Expr data, Array<IndexExpr> size, Array<FloatImm> roi, String layout,
+                  String method, String coordinate_transformation_mode, String rounding_method,
+                  double cubic_alpha, int cubic_exclude, double extrapolation_value,
+                  DataType out_dtype);
 
 Expr MakeSparseToDense(Expr indices, Array<Integer> output_shape, Expr values, Expr default_value);
 
diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc
index 751271d..f3c53cf 100644
--- a/src/relay/transforms/dynamic_to_static.cc
+++ b/src/relay/transforms/dynamic_to_static.cc
@@ -119,16 +119,24 @@ class DynamicToStaticMutator : public MixedModeMutator {
          [this](const CallNode* call_node) {
            auto args = PrepareArgs(call_node);
            if (const ConstantNode* size = args[1].as<ConstantNode>()) {
-             const Resize2DAttrs* param = call_node->attrs.as<Resize2DAttrs>();
-             ICHECK(param);
-             auto size_int = ToVector(size->data);
-             Array<PrimExpr> size_prim;
-             for (size_t i = 0; i < size_int.size(); ++i) {
-               size_prim.push_back(size_int[i]);
+             if (const ConstantNode* roi = args[2].as<ConstantNode>()) {
+               const Resize2DAttrs* param = call_node->attrs.as<Resize2DAttrs>();
+               ICHECK(param);
+               auto size_int = ToVector(size->data);
+               Array<PrimExpr> size_prim;
+               for (size_t i = 0; i < size_int.size(); ++i) {
+                 size_prim.push_back(size_int[i]);
+               }
+               auto roi_vec = ToFloatVector(roi->data);
+               Array<FloatImm> roi_prim;
+               for (size_t i = 0; i < roi_vec.size(); ++i) {
+                 roi_prim.push_back(roi_vec[i]);
+               }
+               return MakeResize2D(call_node->args[0], size_prim, roi_prim, param->layout,
+                                   param->method, param->coordinate_transformation_mode,
+                                   param->rounding_method, param->cubic_alpha, param->cubic_exclude,
+                                   param->extrapolation_value, param->out_dtype);
              }
-             return MakeResize2D(call_node->args[0], size_prim, param->layout, param->method,
-                                 param->coordinate_transformation_mode, param->rounding_method,
-                                 param->cubic_alpha, param->cubic_exclude, param->out_dtype);
            }
            return Expr(nullptr);
          }},
diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h
index 03b8ee6..69ad20a 100644
--- a/src/relay/transforms/pattern_utils.h
+++ b/src/relay/transforms/pattern_utils.h
@@ -451,6 +451,23 @@ static inline Array<Integer> ToVector(const runtime::NDArray& array) {
 }
 
 /*!
+ * \brief Convert a NDArray with type int or float to Array<FloatImm>.
+ * \param array Input NDArray
+ * \return Converted Array.
+ */
+static inline Array<FloatImm> ToFloatVector(const runtime::NDArray& array) {
+  size_t ndim = array.Shape().size();
+  ICHECK_EQ(ndim, 1) << "This function should only be used for 1D NDArrays";
+  size_t len = array.Shape().front();
+  Array<FloatImm> out;
+  for (size_t i = 0; i < len; ++i) {
+    long double elem_val = ToScalar(array, i);
+    out.push_back(FloatImm(DataType::Float(32), static_cast<float>(elem_val)));
+  }
+  return out;
+}
+
+/*!
  * \brief Convert a NDArray with type int or float to Array<Array<Integer>>.
  * \param array Input NDArray
  * \return Converted Array.
diff --git a/tests/python/contrib/test_onnx.py b/tests/python/contrib/test_onnx.py
index 6f23228..214166c 100644
--- a/tests/python/contrib/test_onnx.py
+++ b/tests/python/contrib/test_onnx.py
@@ -663,6 +663,7 @@ def test_resize():
         y = relay.image.resize2d(
             x,
             outsize,
+            None,
             layout="NCHW",
             method=method,
             coordinate_transformation_mode=coord_trans,
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index f8870ed..0531bfc 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -4958,7 +4958,6 @@ unsupported_onnx_tests = [
     "test_reduce_sum_keepdims_random",
     "test_reduce_sum_negative_axes_keepdims_example",
     "test_reduce_sum_negative_axes_keepdims_random",
-    "test_resize_tf_crop_and_resize",
     "test_rnn_seq_length",
     "test_round",
     "test_scan9_sum",
diff --git a/tests/python/relay/dyn/test_dynamic_op_level5.py b/tests/python/relay/dyn/test_dynamic_op_level5.py
index c29ea2c..2eeeb1d 100644
--- a/tests/python/relay/dyn/test_dynamic_op_level5.py
+++ b/tests/python/relay/dyn/test_dynamic_op_level5.py
@@ -51,7 +51,7 @@ def test_resize2d():
 
         coord_trans = "asymmetric" if method == "nearest_neighbor" else "align_corners"
         z = relay.image.resize2d(
-            x, size_var, layout, method, coordinate_transformation_mode=coord_trans
+            x, size_var, None, layout, method, coordinate_transformation_mode=coord_trans
         )
 
         zz = run_infer_type(z)
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index f42f7ad..c04520f 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -1563,7 +1563,7 @@ def verify_any_resize2d(data_shape, scale, layout, static_data_shape, ref_out_sh
         size = (data_shape[1] * scale, data_shape[2] * scale)
     else:
         size = (data_shape[2] * scale, data_shape[3] * scale)
-    y = relay.image.resize2d(data, size, layout)
+    y = relay.image.resize2d(data, size, None, layout)
     mod["main"] = relay.Function([data], y)
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     check_result([data_np], mod, ref_out_shape, assert_shape=True)
diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py
index c968c5a..6ad3b38 100644
--- a/tests/python/relay/test_op_level5.py
+++ b/tests/python/relay/test_op_level5.py
@@ -40,7 +40,7 @@ def test_resize1d_infer_type():
     assert zz.checked_type == relay.TensorType((n, c, tw), "int8")
 
     x = relay.var("x", relay.TensorType((n, c, w), "int8"))
-    z = relay.image.resize1d(x, (200,), "NCW", "linear", "align_corners")
+    z = relay.image.resize1d(x, (200,), None, "NCW", "linear", "align_corners")
     assert "size=" in z.astext()
     zz = run_infer_type(z)
     assert zz.checked_type == relay.TensorType((n, c, 200), "int8")
@@ -83,7 +83,7 @@ class TestResize1D:
         )
         x = relay.var("x", relay.TensorType(dshape, "float32"))
         z = relay.image.resize1d(
-            x, size, layout, interpolate_method, coordinate_transformation_mode=coord_trans
+            x, size, None, layout, interpolate_method, coordinate_transformation_mode=coord_trans
         )
         assert "size=" in z.astext()
         zz = run_infer_type(z)
@@ -104,7 +104,7 @@ def test_resize2d_infer_type():
     assert zz.checked_type == relay.TensorType((n, c, th, tw), "int8")
 
     x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
-    z = relay.image.resize2d(x, (100, 200), "NCHW", "linear", "align_corners")
+    z = relay.image.resize2d(x, (100, 200), None, "NCHW", "linear", "align_corners")
     assert "size=" in z.astext()
     zz = run_infer_type(z)
     assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8")
@@ -148,7 +148,7 @@ class TestResize2D:
         )
         x = relay.var("x", relay.TensorType(dshape, "float32"))
         z = relay.image.resize2d(
-            x, size, layout, interpolate_method, coordinate_transformation_mode=coord_trans
+            x, size, None, layout, interpolate_method, coordinate_transformation_mode=coord_trans
         )
         assert "size=" in z.astext()
         zz = run_infer_type(z)
@@ -175,7 +175,7 @@ def test_resize3d_infer_type():
     assert zz.checked_type == relay.TensorType((n, c, td, th, tw), "int8")
 
     x = relay.var("x", relay.TensorType((n, c, d, h, w), "int8"))
-    z = relay.image.resize3d(x, (10, 10, 20), "NCDHW", "linear", "align_corners")
+    z = relay.image.resize3d(x, (10, 10, 20), None, "NCDHW", "linear", "align_corners")
     assert "size=" in z.astext()
     zz = run_infer_type(z)
     assert zz.checked_type == relay.TensorType((n, c, 10, 10, 20), "int8")
@@ -204,7 +204,7 @@ class TestResize3D:
             x_data, (scale, scale, scale), layout, interpolate_method, coord_trans
         )
         x = relay.var("x", relay.TensorType(dshape, "float32"))
-        z = relay.image.resize3d(x, size, layout, interpolate_method, coord_trans)
+        z = relay.image.resize3d(x, size, None, layout, interpolate_method, coord_trans)
         assert "size=" in z.astext()
         zz = run_infer_type(z)
         assert zz.checked_type == relay.TensorType(ref_res.shape, "float32")
diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py
index 5b61733..f523ad2 100644
--- a/tests/python/relay/test_pass_dynamic_to_static.py
+++ b/tests/python/relay/test_pass_dynamic_to_static.py
@@ -291,7 +291,7 @@ def test_dynamic_to_static_resize2d():
         size_var = relay.var("size", relay.TensorType((len(size),), "float32"))
         coord_trans = "asymmetric" if method == "nearest_neighbor" else "align_corners"
         z = relay.image.resize2d(
-            x, size_var, layout, method, coordinate_transformation_mode=coord_trans
+            x, size_var, None, layout, method, coordinate_transformation_mode=coord_trans
         )
         params = {"size": np.array(size).astype("float32")}
 
diff --git a/tests/python/topi/python/test_topi_image.py b/tests/python/topi/python/test_topi_image.py
index fe7fba5..062cf79 100644
--- a/tests/python/topi/python/test_topi_image.py
+++ b/tests/python/topi/python/test_topi_image.py
@@ -49,6 +49,7 @@ def verify_resize2d(
         raise NotImplementedError("Layout not supported {} ".format(layout))
     B = topi.image.resize2d(
         A,
+        [0.0] * 4,
         (out_height, out_width),
         layout=layout,
         coordinate_transformation_mode=coord_trans,
@@ -127,6 +128,7 @@ def verify_resize3d(
 
     B = topi.image.resize3d(
         A,
+        [0.0] * 6,
         (out_depth, out_height, out_width),
         layout=layout,
         coordinate_transformation_mode=coordinate_transformation_mode,