فهرست منبع

fix: improve type safety in httpWithTLS for custom RoundTripper (#861)

* fix: improve type safety in httpWithTLS for custom RoundTripper

Add type assertion check to handle cases where DefaultHTTPClient.Transport
is a custom http.RoundTripper implementation (not *http.Transport).

This improves upon the fix in PR #844 which only handled nil Transport.
The previous code would still panic if users set a custom RoundTripper:

  trans := baseTransport.(*http.Transport).Clone()  // panic if not *http.Transport

Now safely handles three scenarios:
1. Transport is nil -> use http.DefaultTransport
2. Transport is *http.Transport -> clone it
3. Transport is custom RoundTripper -> use http.DefaultTransport

Added comprehensive test cases:
- TestHttpWithTLS_NilTransport
- TestHttpWithTLS_CustomTransport
- TestHttpWithTLS_CustomRoundTripper

Related to #803

* refactor: reduce code duplication and complexity in httpWithTLS

- Eliminate duplicate http.DefaultTransport.Clone() calls
- Reduce cyclomatic complexity by simplifying conditional logic
- Use nil check pattern instead of nested else branches
- Maintain same functionality with cleaner code structure

This addresses golangci-lint warnings for dupl and gocyclo.

* fix: add newline at end of http_test.go

Fix gofmt -s compliance issue:
- File must end with newline character
- Addresses golangci-lint gofmt error on line 81

This fixes CI check failure.
is-Xiaoen 8 ماه پیش
والد
کامیت
30c8e77246
2فایلهای تغییر یافته به همراه92 افزوده شده و 5 حذف شده
  1. 11 5
      util/http.go
  2. 81 0
      util/http_test.go

+ 11 - 5
util/http.go

@@ -292,13 +292,19 @@ func httpWithTLS(rootCa, key string) (*http.Client, error) {
 		Certificates: []tls.Certificate{cert},
 	}
 
-	var baseTransport http.RoundTripper
+	// 安全地获取 *http.Transport
+	var trans *http.Transport
+	// 尝试从 DefaultHTTPClient 获取 Transport,如果失败则使用默认值
 	if DefaultHTTPClient.Transport != nil {
-		baseTransport = DefaultHTTPClient.Transport
-	} else {
-		baseTransport = http.DefaultTransport
+		if t, ok := DefaultHTTPClient.Transport.(*http.Transport); ok {
+			trans = t.Clone()
+		}
+	}
+	// 如果无法获取有效的 Transport,使用默认值
+	if trans == nil {
+		trans = http.DefaultTransport.(*http.Transport).Clone()
 	}
-	trans := baseTransport.(*http.Transport).Clone()
+
 	trans.TLSClientConfig = config
 	trans.DisableCompression = true
 	client = &http.Client{Transport: trans}

+ 81 - 0
util/http_test.go

@@ -0,0 +1,81 @@
+package util
+
+import (
+	"net/http"
+	"testing"
+)
+
+// TestHttpWithTLS_NilTransport tests the scenario where DefaultHTTPClient.Transport is nil
+func TestHttpWithTLS_NilTransport(t *testing.T) {
+	// Save original transport
+	originalTransport := DefaultHTTPClient.Transport
+	defer func() {
+		DefaultHTTPClient.Transport = originalTransport
+	}()
+
+	// Set Transport to nil to simulate the bug scenario
+	DefaultHTTPClient.Transport = nil
+
+	// This should not panic after the fix
+	// Note: This will fail due to invalid cert path, but shouldn't panic on type assertion
+	_, err := httpWithTLS("./testdata/invalid_cert.p12", "password")
+
+	// We expect an error (cert file not found), but NOT a panic
+	if err == nil {
+		t.Error("Expected error due to invalid cert path, but got nil")
+	}
+}
+
+// TestHttpWithTLS_CustomTransport tests the scenario where DefaultHTTPClient has a custom Transport
+func TestHttpWithTLS_CustomTransport(t *testing.T) {
+	// Save original transport
+	originalTransport := DefaultHTTPClient.Transport
+	defer func() {
+		DefaultHTTPClient.Transport = originalTransport
+	}()
+
+	// Set a custom http.Transport
+	customTransport := &http.Transport{
+		MaxIdleConns: 100,
+	}
+	DefaultHTTPClient.Transport = customTransport
+
+	// This should not panic
+	_, err := httpWithTLS("./testdata/invalid_cert.p12", "password")
+
+	// We expect an error (cert file not found), but NOT a panic
+	if err == nil {
+		t.Error("Expected error due to invalid cert path, but got nil")
+	}
+}
+
+// CustomRoundTripper is a custom implementation of http.RoundTripper
+type CustomRoundTripper struct{}
+
+func (c *CustomRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+	return http.DefaultTransport.RoundTrip(req)
+}
+
+// TestHttpWithTLS_CustomRoundTripper tests the edge case where DefaultHTTPClient has a custom RoundTripper
+// that is NOT *http.Transport
+func TestHttpWithTLS_CustomRoundTripper(t *testing.T) {
+	// Save original transport
+	originalTransport := DefaultHTTPClient.Transport
+	defer func() {
+		DefaultHTTPClient.Transport = originalTransport
+	}()
+
+	// Set a custom RoundTripper that is NOT *http.Transport
+	customRoundTripper := &CustomRoundTripper{}
+	DefaultHTTPClient.Transport = customRoundTripper
+
+	// Create a recovery handler to catch potential panic
+	defer func() {
+		if r := recover(); r != nil {
+			t.Errorf("httpWithTLS panicked with custom RoundTripper: %v", r)
+		}
+	}()
+
+	// This might panic if the code doesn't handle non-*http.Transport RoundTripper properly
+	_, _ = httpWithTLS("./testdata/invalid_cert.p12", "password")
+}