You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text copy.)
Posted to commits@thrift.apache.org by yu...@apache.org on 2022/02/23 17:24:01 UTC

[thrift] branch master updated: THRIFT-5527: Don't swallow idl exceptions in Process function

This is an automated email from the ASF dual-hosted git repository.

yuxuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/thrift.git


The following commit(s) were added to refs/heads/master by this push:
     new 9bee877  THRIFT-5527: Don't swallow idl exceptions in Process function
9bee877 is described below

commit 9bee877e663f11f4cbdd3a4f02938c8ab9fe8976
Author: Yuxuan 'fishy' Wang <yu...@reddit.com>
AuthorDate: Tue Feb 22 18:48:17 2022 -0800

    THRIFT-5527: Don't swallow idl exceptions in Process function
    
    Client: go
    
    This allows ProcessorMiddlewares to access such exceptions, unless
    there's a network error writing the response (which takes priority).
    
    While I'm here, also make the indentation of the generated Process
    function more consistent, and consistently return false and an error
    when reading or writing fails.
---
 compiler/cpp/src/thrift/generate/t_go_generator.cc | 144 +++++++++++++++------
 lib/go/test/Makefile.am                            |   8 +-
 lib/go/test/ProcessorMiddlewareTest.thrift         |  32 +++++
 lib/go/test/tests/processor_middleware_test.go     | 108 ++++++++++++++++
 4 files changed, 249 insertions(+), 43 deletions(-)

diff --git a/compiler/cpp/src/thrift/generate/t_go_generator.cc b/compiler/cpp/src/thrift/generate/t_go_generator.cc
index 7897b62..3b885f1 100644
--- a/compiler/cpp/src/thrift/generate/t_go_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_go_generator.cc
@@ -959,8 +959,8 @@ string t_go_generator::go_imports_begin(bool consts) {
   // If not writing constants, and there are enums, need extra imports.
   if (!consts && get_program()->get_enums().size() > 0) {
     system_packages.push_back("database/sql/driver");
-    system_packages.push_back("errors");
   }
+  system_packages.push_back("errors");
   system_packages.push_back("fmt");
   system_packages.push_back("time");
   // For the thrift import, always do rename import to make sure it's called thrift.
@@ -980,6 +980,7 @@ string t_go_generator::go_imports_end() {
       "// (needed to ensure safety because of naive import list construction.)\n"
       "var _ = thrift.ZERO\n"
       "var _ = fmt.Printf\n"
+      "var _ = errors.New\n"
       "var _ = context.Background\n"
       "var _ = time.Now\n"
       "var _ = bytes.Equal\n\n");
@@ -2964,21 +2965,27 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function*
              << ") Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err "
                 "thrift.TException) {" << endl;
   indent_up();
+  string write_err;
+  if (!tfunction->is_oneway()) {
+    write_err = tmp("_write_err");
+    f_types_ << indent() << "var " << write_err << " error" << endl;
+  }
   f_types_ << indent() << "args := " << argsname << "{}" << endl;
-  f_types_ << indent() << "var err2 error" << endl;
-  f_types_ << indent() << "if err2 = args." << read_method_name_ <<  "(ctx, iprot); err2 != nil {" << endl;
-  f_types_ << indent() << "  iprot.ReadMessageEnd(ctx)" << endl;
+  f_types_ << indent() << "if err2 := args." << read_method_name_ <<  "(ctx, iprot); err2 != nil {" << endl;
+  indent_up();
+  f_types_ << indent() << "iprot.ReadMessageEnd(ctx)" << endl;
   if (!tfunction->is_oneway()) {
     f_types_ << indent()
-               << "  x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err2.Error())"
+               << "x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err2.Error())"
                << endl;
-    f_types_ << indent() << "  oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name())
+    f_types_ << indent() << "oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name())
                << "\", thrift.EXCEPTION, seqId)" << endl;
-    f_types_ << indent() << "  x.Write(ctx, oprot)" << endl;
-    f_types_ << indent() << "  oprot.WriteMessageEnd(ctx)" << endl;
-    f_types_ << indent() << "  oprot.Flush(ctx)" << endl;
+    f_types_ << indent() << "x.Write(ctx, oprot)" << endl;
+    f_types_ << indent() << "oprot.WriteMessageEnd(ctx)" << endl;
+    f_types_ << indent() << "oprot.Flush(ctx)" << endl;
   }
-  f_types_ << indent() << "  return false, thrift.WrapTException(err2)" << endl;
+  f_types_ << indent() << "return false, thrift.WrapTException(err2)" << endl;
+  indent_down();
   f_types_ << indent() << "}" << endl;
   f_types_ << indent() << "iprot.ReadMessageEnd(ctx)" << endl << endl;
 
@@ -3037,9 +3044,6 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function*
     f_types_ << indent() << "result := " << resultname << "{}" << endl;
   }
   bool need_reference = type_need_reference(tfunction->get_returntype());
-  if (!tfunction->is_oneway() && !tfunction->get_returntype()->is_void()) {
-    f_types_ << indent() << "var retval " << type_to_go_type(tfunction->get_returntype()) << endl;
-  }
 
   f_types_ << indent() << "if ";
 
@@ -3053,7 +3057,7 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function*
   t_struct* arg_struct = tfunction->get_arglist();
   const std::vector<t_field*>& fields = arg_struct->get_members();
   vector<t_field*>::const_iterator f_iter;
-  f_types_ << "err2 = p.handler." << publicize(tfunction->get_name()) << "(";
+  f_types_ << "err2 := p.handler." << publicize(tfunction->get_name()) << "(";
   bool first = true;
 
   f_types_ << "ctx";
@@ -3069,7 +3073,9 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function*
   }
 
   f_types_ << "); err2 != nil {" << endl;
-  f_types_ << indent() << "  tickerCancel()" << endl;
+  indent_up();
+  f_types_ << indent() << "tickerCancel()" << endl;
+  f_types_ << indent() << "err = thrift.WrapTException(err2)" << endl;
 
   t_struct* exceptions = tfunction->get_xceptions();
   const vector<t_field*>& x_fields = exceptions->get_members();
@@ -3079,36 +3085,74 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function*
     vector<t_field*>::const_iterator xf_iter;
 
     for (xf_iter = x_fields.begin(); xf_iter != x_fields.end(); ++xf_iter) {
-      f_types_ << indent() << "  case " << type_to_go_type(((*xf_iter)->get_type())) << ":"
+      f_types_ << indent() << "case " << type_to_go_type(((*xf_iter)->get_type())) << ":"
                  << endl;
+      indent_up();
       f_types_ << indent() << "result." << publicize((*xf_iter)->get_name()) << " = v" << endl;
+      indent_down();
     }
 
-    f_types_ << indent() << "  default:" << endl;
+    f_types_ << indent() << "default:" << endl;
+    indent_up();
   }
 
   if (!tfunction->is_oneway()) {
     // Avoid writing the error to the wire if it's ErrAbandonRequest
-    f_types_ << indent() << "  if err2 == thrift.ErrAbandonRequest {" << endl;
-    f_types_ << indent() << "    return false, thrift.WrapTException(err2)" << endl;
-    f_types_ << indent() << "  }" << endl;
+    f_types_ << indent() << "if errors.Is(err2, thrift.ErrAbandonRequest) {" << endl;
+    indent_up();
+    f_types_ << indent() << "return false, thrift.WrapTException(err2)" << endl;
+    indent_down();
+    f_types_ << indent() << "}" << endl;
 
-    f_types_ << indent() << "  x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "
+    string exc(tmp("_exc"));
+    f_types_ << indent() << exc << " := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "
                               "\"Internal error processing " << escape_string(tfunction->get_name())
                << ": \" + err2.Error())" << endl;
-    f_types_ << indent() << "  oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name())
-               << "\", thrift.EXCEPTION, seqId)" << endl;
-    f_types_ << indent() << "  x.Write(ctx, oprot)" << endl;
-    f_types_ << indent() << "  oprot.WriteMessageEnd(ctx)" << endl;
-    f_types_ << indent() << "  oprot.Flush(ctx)" << endl;
-  }
 
-  f_types_ << indent() << "  return true, thrift.WrapTException(err2)" << endl;
+    f_types_ << indent() << "if err2 := oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name())
+               << "\", thrift.EXCEPTION, seqId); err2 != nil {" << endl;
+    indent_up();
+    f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+    indent_down();
+    f_types_ << indent() << "}" << endl;
 
-  if (!x_fields.empty()) {
+    f_types_ << indent() << "if err2 := " << exc << ".Write(ctx, oprot); "
+               << write_err << " == nil && err2 != nil {" << endl;
+    indent_up();
+    f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+    indent_down();
+    f_types_ << indent() << "}" << endl;
+
+    f_types_ << indent() << "if err2 := oprot.WriteMessageEnd(ctx); "
+               << write_err << " == nil && err2 != nil {" << endl;
+    indent_up();
+    f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+    indent_down();
+    f_types_ << indent() << "}" << endl;
+
+    f_types_ << indent() << "if err2 := oprot.Flush(ctx); "
+               << write_err << " == nil && err2 != nil {" << endl;
+    indent_up();
+    f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+    indent_down();
+    f_types_ << indent() << "}" << endl;
+
+    f_types_ << indent() << "if " << write_err << " != nil {" << endl;
+    indent_up();
+    f_types_ << indent() << "return false, thrift.WrapTException(" << write_err << ")" << endl;
+    indent_down();
     f_types_ << indent() << "}" << endl;
+
+    // return success=true as long as writing to the wire was successful.
+    f_types_ << indent() << "return true, err" << endl;
   }
 
+  if (!x_fields.empty()) {
+    indent_down();
+    f_types_ << indent() << "}" << endl; // closes switch
+  }
+
+  indent_down();
   f_types_ << indent() << "}"; // closes err2 != nil
 
   if (!tfunction->is_oneway()) {
@@ -3126,29 +3170,47 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function*
       f_types_ << endl;
     }
     f_types_ << indent() << "tickerCancel()" << endl;
-    f_types_ << indent() << "if err2 = oprot.WriteMessageBegin(ctx, \""
+
+    f_types_ << indent() << "if err2 := oprot.WriteMessageBegin(ctx, \""
                << escape_string(tfunction->get_name()) << "\", thrift.REPLY, seqId); err2 != nil {"
                << endl;
-    f_types_ << indent() << "  err = thrift.WrapTException(err2)" << endl;
+    indent_up();
+    f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+    indent_down();
     f_types_ << indent() << "}" << endl;
-    f_types_ << indent() << "if err2 = result." << write_method_name_ << "(ctx, oprot); err == nil && err2 != nil {" << endl;
-    f_types_ << indent() << "  err = thrift.WrapTException(err2)" << endl;
+
+    f_types_ << indent() << "if err2 := result." << write_method_name_ << "(ctx, oprot); "
+               << write_err << " == nil && err2 != nil {" << endl;
+    indent_up();
+    f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+    indent_down();
     f_types_ << indent() << "}" << endl;
-    f_types_ << indent() << "if err2 = oprot.WriteMessageEnd(ctx); err == nil && err2 != nil {"
-               << endl;
-    f_types_ << indent() << "  err = thrift.WrapTException(err2)" << endl;
+
+    f_types_ << indent() << "if err2 := oprot.WriteMessageEnd(ctx); "
+               << write_err << " == nil && err2 != nil {" << endl;
+    indent_up();
+    f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+    indent_down();
     f_types_ << indent() << "}" << endl;
-    f_types_ << indent() << "if err2 = oprot.Flush(ctx); err == nil && err2 != nil {" << endl;
-    f_types_ << indent() << "  err = thrift.WrapTException(err2)" << endl;
+
+    f_types_ << indent() << "if err2 := oprot.Flush(ctx); " << write_err << " == nil && err2 != nil {" << endl;
+    indent_up();
+    f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+    indent_down();
     f_types_ << indent() << "}" << endl;
-    f_types_ << indent() << "if err != nil {" << endl;
-    f_types_ << indent() << "  return" << endl;
+
+    f_types_ << indent() << "if " << write_err << " != nil {" << endl;
+    indent_up();
+    f_types_ << indent() << "return false, thrift.WrapTException(" << write_err << ")" << endl;
+    indent_down();
     f_types_ << indent() << "}" << endl;
+
+    // return success=true as long as writing to the wire was successful.
     f_types_ << indent() << "return true, err" << endl;
   } else {
     f_types_ << endl;
     f_types_ << indent() << "tickerCancel()" << endl;
-    f_types_ << indent() << "return true, nil" << endl;
+    f_types_ << indent() << "return true, err" << endl;
   }
   indent_down();
   f_types_ << indent() << "}" << endl << endl;
diff --git a/lib/go/test/Makefile.am b/lib/go/test/Makefile.am
index 4b3ecda..2cca411 100644
--- a/lib/go/test/Makefile.am
+++ b/lib/go/test/Makefile.am
@@ -52,7 +52,8 @@ gopath: $(THRIFT) $(THRIFTTEST) \
 				EqualsTest.thrift \
 				ConflictArgNamesTest.thrift \
 				ConstOptionalFieldImport.thrift \
-				ConstOptionalField.thrift
+				ConstOptionalField.thrift \
+				ProcessorMiddlewareTest.thrift
 	mkdir -p gopath/src
 	grep -v list.*map.*list.*map $(THRIFTTEST) | grep -v 'set<Insanity>' > ThriftTest.thrift
 	$(THRIFT) $(THRIFTARGS) -r IncludesTest.thrift
@@ -84,6 +85,7 @@ gopath: $(THRIFT) $(THRIFTTEST) \
 	$(THRIFT) $(THRIFTARGS) EqualsTest.thrift
 	$(THRIFT) $(THRIFTARGS) ConflictArgNamesTest.thrift
 	$(THRIFT) $(THRIFTARGS) -r ConstOptionalField.thrift
+	$(THRIFT) $(THRIFTARGS) ProcessorMiddlewareTest.thrift
 	ln -nfs ../../tests gopath/src/tests
 	cp -r ./dontexportrwtest gopath/src
 	touch gopath
@@ -106,7 +108,8 @@ check: gopath
 				./gopath/src/servicestest/container_test-remote \
 				./gopath/src/duplicateimportstest \
 				./gopath/src/equalstest \
-				./gopath/src/conflictargnamestest
+				./gopath/src/conflictargnamestest \
+				./gopath/src/processormiddlewaretest
 	$(GO) test -mod=mod github.com/apache/thrift/lib/go/thrift
 	$(GO) test -mod=mod ./gopath/src/tests ./gopath/src/dontexportrwtest
 
@@ -145,6 +148,7 @@ EXTRA_DIST = \
 	NamesTest.thrift \
 	OnewayTest.thrift \
 	OptionalFieldsTest.thrift \
+	ProcessorMiddlewareTest.thrift \
 	RefAnnotationFieldsTest.thrift \
 	RequiredFieldTest.thrift \
 	ServicesTest.thrift \
diff --git a/lib/go/test/ProcessorMiddlewareTest.thrift b/lib/go/test/ProcessorMiddlewareTest.thrift
new file mode 100644
index 0000000..2d4f5f4
--- /dev/null
+++ b/lib/go/test/ProcessorMiddlewareTest.thrift
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ * Contains some contributions under the Thrift Software License.
+ * Please see doc/old-thrift-license.txt in the Thrift distribution for
+ * details.
+ */
+
+exception Error {
+  1: optional string foo,
+}
+
+service Service {
+  void ping() throws (
+    1: Error error,
+  );
+}
diff --git a/lib/go/test/tests/processor_middleware_test.go b/lib/go/test/tests/processor_middleware_test.go
new file mode 100644
index 0000000..1bd911c
--- /dev/null
+++ b/lib/go/test/tests/processor_middleware_test.go
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package tests
+
+import (
+	"context"
+	"errors"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/apache/thrift/lib/go/test/gopath/src/processormiddlewaretest"
+	"github.com/apache/thrift/lib/go/thrift"
+)
+
+const errorMessage = "foo error"
+
+type serviceImpl struct{}
+
+func (serviceImpl) Ping(_ context.Context) (err error) {
+	return &processormiddlewaretest.Error{
+		Foo: thrift.StringPtr(errorMessage),
+	}
+}
+
+func middleware(t *testing.T) thrift.ProcessorMiddleware {
+	return func(name string, next thrift.TProcessorFunction) thrift.TProcessorFunction {
+		return thrift.WrappedTProcessorFunction{
+			Wrapped: func(ctx context.Context, seqId int32, in, out thrift.TProtocol) (_ bool, err thrift.TException) {
+				defer func() {
+					checkError(t, err)
+				}()
+				return next.Process(ctx, seqId, in, out)
+			},
+		}
+	}
+}
+
+func checkError(tb testing.TB, err error) {
+	tb.Helper()
+
+	var idlErr *processormiddlewaretest.Error
+	if !errors.As(err, &idlErr) {
+		tb.Errorf("expected error to be of type *processormiddlewaretest.Error, actual %T, %#v", err, err)
+		return
+	}
+	if actual := idlErr.GetFoo(); actual != errorMessage {
+		tb.Errorf("expected error message to be %q, actual %q", errorMessage, actual)
+	}
+}
+
+// TestProcessorMiddleware checks that an IDL exception returned by the handler reaches both the middleware and the client.
+func TestProcessorMiddleware(t *testing.T) {
+	const timeout = time.Second
+	processor := processormiddlewaretest.NewServiceProcessor(&serviceImpl{})
+	serverTransport, err := thrift.NewTServerSocket("127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("Could not find available server port: %v", err)
+	}
+	// Bind up front so Addr() reports the real port without a sleep race; Serve's own Listen is then a no-op.
+	if err := serverTransport.Listen(); err != nil {
+		t.Fatalf("Could not listen on server socket: %v", err)
+	}
+	server := thrift.NewTSimpleServer4(
+		thrift.WrapProcessor(processor, middleware(t)),
+		serverTransport,
+		thrift.NewTHeaderTransportFactoryConf(nil, nil),
+		thrift.NewTHeaderProtocolFactoryConf(nil),
+	)
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		server.Serve()
+	}()
+	// Stop the server and wait for Serve to return so middleware checks cannot run after the test ends.
+	defer func() { server.Stop(); wg.Wait() }()
+	cfg := &thrift.TConfiguration{
+		ConnectTimeout: timeout,
+		SocketTimeout:  timeout,
+	}
+	transport := thrift.NewTSocketFromAddrConf(serverTransport.Addr(), cfg)
+	if err := transport.Open(); err != nil {
+		t.Fatalf("Could not open client transport: %v", err)
+	}
+	defer transport.Close()
+	protocol := thrift.NewTHeaderProtocolConf(transport, nil)
+	client := processormiddlewaretest.NewServiceClient(thrift.NewTStandardClient(protocol, protocol))
+	err = client.Ping(context.Background())
+	checkError(t, err)
+}