You are viewing a plain-text version of this content. The hyperlink to its canonical version was lost in this plain-text rendering.
Posted to commits@commons.apache.org by gg...@apache.org on 2023/11/10 20:36:33 UTC

(commons-compress) branch master updated: [COMPRESS-648] Add ability to restrict autodetection in CompressorStreamFactory (#433)

This is an automated email from the ASF dual-hosted git repository.

ggregory pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-compress.git


The following commit(s) were added to refs/heads/master by this push:
     new ab8316c57 [COMPRESS-648] Add ability to restrict autodetection in CompressorStreamFactory (#433)
ab8316c57 is described below

commit ab8316c57b0a1ce62b6c58c7d459132e6e735be9
Author: Yakov Shafranovich <ya...@users.noreply.github.com>
AuthorDate: Fri Nov 10 15:36:28 2023 -0500

    [COMPRESS-648] Add ability to restrict autodetection in CompressorStreamFactory (#433)
    
    * Changes for COMPRESS-648
    
    * Added comment
    
    * refactored
    
    * Removed line breaks and changed one method to package-private
    
    * refactoring test
    
    * Removed unused import
    
    ---------
    
    Co-authored-by: Yakov Shafranovich <ya...@amazon.com>
---
 .../compressors/CompressorStreamFactory.java       |  62 ++++++++--
 .../compress/compressors/DetectCompressorTest.java | 131 ++++++++++++++++++---
 2 files changed, 164 insertions(+), 29 deletions(-)

diff --git a/src/main/java/org/apache/commons/compress/compressors/CompressorStreamFactory.java b/src/main/java/org/apache/commons/compress/compressors/CompressorStreamFactory.java
index fc7c7fe5e..3f604cfcb 100644
--- a/src/main/java/org/apache/commons/compress/compressors/CompressorStreamFactory.java
+++ b/src/main/java/org/apache/commons/compress/compressors/CompressorStreamFactory.java
@@ -224,10 +224,29 @@ public class CompressorStreamFactory implements CompressorStreamProvider {
      * @since 1.14
      */
     public static String detect(final InputStream inputStream) throws CompressorException {
+        final Set<String> defaultCompressorNamesForDetection = Sets.newHashSet(BZIP2, GZIP, PACK200, SNAPPY_FRAMED, Z, DEFLATE, XZ, LZMA, LZ4_FRAMED, ZSTANDARD);
+        return detect(inputStream, defaultCompressorNamesForDetection);
+    }
+
+    /**
+     * Try to detect the type of compressor stream while limiting the type to the provided set of compressor names.
+     *
+     * @param inputStream input stream
+     * @param compressorNames compressor names to limit autodetection
+     * @return type of compressor stream detected
+     * @throws CompressorException if no compressor stream type was detected
+     *                             or if something else went wrong
+     * @throws IllegalArgumentException if stream is null or does not support mark
+     */
+    static String detect(final InputStream inputStream, final Set<String> compressorNames) throws CompressorException {
         if (inputStream == null) {
             throw new IllegalArgumentException("Stream must not be null.");
         }
 
+        if (compressorNames == null || compressorNames.isEmpty()) {
+            throw new IllegalArgumentException("Compressor names cannot be null or empty");
+        }
+
         if (!inputStream.markSupported()) {
             throw new IllegalArgumentException("Mark is not supported.");
         }
@@ -242,43 +261,44 @@ public class CompressorStreamFactory implements CompressorStreamProvider {
             throw new CompressorException("IOException while reading signature.", e);
         }
 
-        if (BZip2CompressorInputStream.matches(signature, signatureLength)) {
+        if (compressorNames.contains(BZIP2) && BZip2CompressorInputStream.matches(signature, signatureLength)) {
             return BZIP2;
         }
 
-        if (GzipCompressorInputStream.matches(signature, signatureLength)) {
+        if (compressorNames.contains(GZIP) && GzipCompressorInputStream.matches(signature, signatureLength)) {
             return GZIP;
         }
 
-        if (Pack200CompressorInputStream.matches(signature, signatureLength)) {
+        if (compressorNames.contains(PACK200) && Pack200CompressorInputStream.matches(signature, signatureLength)) {
             return PACK200;
         }
 
-        if (FramedSnappyCompressorInputStream.matches(signature, signatureLength)) {
+        if (compressorNames.contains(SNAPPY_FRAMED) &&
+                FramedSnappyCompressorInputStream.matches(signature, signatureLength)) {
             return SNAPPY_FRAMED;
         }
 
-        if (ZCompressorInputStream.matches(signature, signatureLength)) {
+        if (compressorNames.contains(Z) && ZCompressorInputStream.matches(signature, signatureLength)) {
             return Z;
         }
 
-        if (DeflateCompressorInputStream.matches(signature, signatureLength)) {
+        if (compressorNames.contains(DEFLATE) && DeflateCompressorInputStream.matches(signature, signatureLength)) {
             return DEFLATE;
         }
 
-        if (XZUtils.matches(signature, signatureLength)) {
+        if (compressorNames.contains(XZ) && XZUtils.matches(signature, signatureLength)) {
             return XZ;
         }
 
-        if (LZMAUtils.matches(signature, signatureLength)) {
+        if (compressorNames.contains(LZMA) && LZMAUtils.matches(signature, signatureLength)) {
             return LZMA;
         }
 
-        if (FramedLZ4CompressorInputStream.matches(signature, signatureLength)) {
+        if (compressorNames.contains(LZ4_FRAMED) && FramedLZ4CompressorInputStream.matches(signature, signatureLength)) {
             return LZ4_FRAMED;
         }
 
-        if (ZstdUtils.matches(signature, signatureLength)) {
+        if (compressorNames.contains(ZSTANDARD) && ZstdUtils.matches(signature, signatureLength)) {
             return ZSTANDARD;
         }
 
@@ -502,6 +522,7 @@ public class CompressorStreamFactory implements CompressorStreamProvider {
         this.decompressConcatenated = decompressUntilEOF;
         this.memoryLimitInKb = memoryLimitInKb;
     }
+
     /**
      * Create a compressor input stream from an input stream, auto-detecting the
      * compressor type from the first few bytes of the stream. The InputStream
@@ -520,6 +541,27 @@ public class CompressorStreamFactory implements CompressorStreamProvider {
         return createCompressorInputStream(detect(in), in);
     }
 
+    /**
+     * Create a compressor input stream from an input stream, auto-detecting the
+     * compressor type from the first few bytes of the stream while limiting the detected type
+     * to the provided set of compressor names. The InputStream must support marks, like BufferedInputStream.
+     *
+     * @param in
+     *            the input stream
+     * @param compressorNames
+     *            compressor names to limit autodetection
+     * @return the compressor input stream
+     * @throws CompressorException
+     *             if the autodetected compressor is not in the provided set of compressor names
+     * @throws IllegalArgumentException
+     *             if the stream is null or does not support mark
+     * @since 1.26
+     */
+    public CompressorInputStream createCompressorInputStream(final InputStream in, final Set<String> compressorNames)
+            throws CompressorException {
+        return createCompressorInputStream(detect(in, compressorNames), in);
+    }
+
     /**
      * Creates a compressor input stream from a compressor name and an input
      * stream.
diff --git a/src/test/java/org/apache/commons/compress/compressors/DetectCompressorTest.java b/src/test/java/org/apache/commons/compress/compressors/DetectCompressorTest.java
index 1c7e2ab81..c8970294a 100644
--- a/src/test/java/org/apache/commons/compress/compressors/DetectCompressorTest.java
+++ b/src/test/java/org/apache/commons/compress/compressors/DetectCompressorTest.java
@@ -30,6 +30,10 @@ import java.io.ByteArrayInputStream;
 import java.io.IOException;
 import java.io.InputStream;
 import java.nio.file.Files;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.stream.Stream;
 
 import org.apache.commons.compress.MemoryLimitException;
 import org.apache.commons.compress.compressors.bzip2.BZip2CompressorInputStream;
@@ -41,6 +45,9 @@ import org.apache.commons.compress.compressors.zstandard.ZstdCompressorInputStre
 import org.apache.commons.compress.utils.ByteUtils;
 import org.apache.commons.io.input.BrokenInputStream;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
 
 @SuppressWarnings("deprecation") // deliberately tests setDecompressConcatenated
 public final class DetectCompressorTest {
@@ -96,26 +103,25 @@ public final class DetectCompressorTest {
     };
 
     @SuppressWarnings("resource") // Caller closes.
-    private CompressorInputStream createStreamFor(final String resource)
-            throws CompressorException, IOException {
+    private CompressorInputStream createStreamFor(final String resource) throws CompressorException, IOException {
         return factory.createCompressorInputStream(
                    new BufferedInputStream(Files.newInputStream(
                        getFile(resource).toPath())));
     }
 
     @SuppressWarnings("resource") // Caller closes.
-    private CompressorInputStream createStreamFor(final String resource, final CompressorStreamFactory factory)
-            throws CompressorException, IOException {
-        return factory.createCompressorInputStream(
-                   new BufferedInputStream(Files.newInputStream(
-                       getFile(resource).toPath())));
+    private CompressorInputStream createStreamFor(final String resource, final Set<String> compressorNames) throws CompressorException, IOException {
+        return factory.createCompressorInputStream(new BufferedInputStream(Files.newInputStream(getFile(resource).toPath())), compressorNames);
+    }
+
+    @SuppressWarnings("resource") // Caller closes.
+    private CompressorInputStream createStreamFor(final String resource, final CompressorStreamFactory factory) throws CompressorException, IOException {
+        return factory.createCompressorInputStream(new BufferedInputStream(Files.newInputStream(getFile(resource).toPath())));
     }
 
     private InputStream createStreamFor(final String fileName, final int memoryLimitInKb) throws Exception {
-        final CompressorStreamFactory fac = new CompressorStreamFactory(true,
-                memoryLimitInKb);
-        final InputStream is = new BufferedInputStream(
-                Files.newInputStream(getFile(fileName).toPath()));
+        final CompressorStreamFactory fac = new CompressorStreamFactory(true, memoryLimitInKb);
+        final InputStream is = new BufferedInputStream(Files.newInputStream(getFile(fileName).toPath()));
         try {
             return fac.createCompressorInputStream(is);
         } catch (final CompressorException e) {
@@ -129,9 +135,12 @@ public final class DetectCompressorTest {
     }
 
     private String detect(final String testFileName) throws IOException, CompressorException {
-        try (InputStream is = new BufferedInputStream(
-                Files.newInputStream(getFile(testFileName).toPath()))) {
-            return CompressorStreamFactory.detect(is);
+        return detect(testFileName, null);
+    }
+
+    private String detect(final String testFileName, final Set<String> compressorNames) throws IOException, CompressorException {
+        try (InputStream is = new BufferedInputStream(Files.newInputStream(getFile(testFileName).toPath()))) {
+            return compressorNames != null ? CompressorStreamFactory.detect(is, compressorNames) : CompressorStreamFactory.detect(is);
         }
     }
 
@@ -154,17 +163,58 @@ public final class DetectCompressorTest {
 
         assertThrows(CompressorException.class, () -> CompressorStreamFactory.detect(new BufferedInputStream(new ByteArrayInputStream(ByteUtils.EMPTY_BYTE_ARRAY))));
 
-        final IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> CompressorStreamFactory.detect(null),
-                "shouldn't be able to detect null stream");
+        final IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> CompressorStreamFactory.detect(null), "shouldn't be able to detect null stream");
         assertEquals("Stream must not be null.", e.getMessage());
 
-        final CompressorException ce = assertThrows(CompressorException.class, () -> CompressorStreamFactory.detect(new BufferedInputStream(new BrokenInputStream())),
-                "Expected IOException");
+        final CompressorException ce = assertThrows(CompressorException.class, () -> CompressorStreamFactory.detect(new BufferedInputStream(new BrokenInputStream())), "Expected IOException");
         assertEquals("IOException while reading signature.", ce.getMessage());
     }
 
     @Test
-    public void testDetection() throws Exception {
+    public void testDetectNullOrEmptyCompressorNames() throws Exception {
+        assertThrows(IllegalArgumentException.class, () -> CompressorStreamFactory.detect(createStreamFor("bla.txt.bz2"), (Set<String>) null));
+        assertThrows(IllegalArgumentException.class, () -> CompressorStreamFactory.detect(createStreamFor("bla.tgz"), new HashSet<>()));
+    }
+
+    public static Stream<Arguments> limitedByNameData() {
+        return Stream.of(
+                Arguments.of("bla.txt.bz2", CompressorStreamFactory.BZIP2),
+                Arguments.of("bla.tgz", CompressorStreamFactory.GZIP),
+                Arguments.of("bla.pack", CompressorStreamFactory.PACK200),
+                Arguments.of("bla.tar.xz", CompressorStreamFactory.XZ),
+                Arguments.of("bla.tar.deflatez", CompressorStreamFactory.DEFLATE),
+                Arguments.of("bla.tar.lz4", CompressorStreamFactory.LZ4_FRAMED),
+                Arguments.of("bla.tar.lzma", CompressorStreamFactory.LZMA),
+                Arguments.of("bla.tar.sz", CompressorStreamFactory.SNAPPY_FRAMED),
+                Arguments.of("bla.tar.Z", CompressorStreamFactory.Z),
+                Arguments.of("bla.tar.zst", CompressorStreamFactory.ZSTANDARD)
+        );
+    }
+
+    @ParameterizedTest
+    @MethodSource("limitedByNameData")
+    public void testDetectLimitedByName(final String filename, final String compressorName) throws Exception {
+        assertEquals(compressorName, detect(filename, Collections.singleton(compressorName)));
+    }
+
+    @Test
+    public void testDetectLimitedByNameNotFound() throws Exception {
+        Set<String> compressorNames = Collections.singleton(CompressorStreamFactory.DEFLATE);
+
+        assertThrows(CompressorException.class, () -> detect("bla.txt.bz2", compressorNames));
+        assertThrows(CompressorException.class, () -> detect("bla.tgz", compressorNames));
+        assertThrows(CompressorException.class, () -> detect("bla.pack", compressorNames));
+        assertThrows(CompressorException.class, () -> detect("bla.tar.xz", compressorNames));
+        assertThrows(CompressorException.class, () -> detect("bla.tar.deflatez", Collections.singleton(CompressorStreamFactory.BZIP2)));
+        assertThrows(CompressorException.class, () -> detect("bla.tar.lz4", compressorNames));
+        assertThrows(CompressorException.class, () -> detect("bla.tar.lzma", compressorNames));
+        assertThrows(CompressorException.class, () -> detect("bla.tar.sz", compressorNames));
+        assertThrows(CompressorException.class, () -> detect("bla.tar.Z", compressorNames));
+        assertThrows(CompressorException.class, () -> detect("bla.tar.zst", compressorNames));
+    }
+
+    @Test
+    public void testCreateWithAutoDetection() throws Exception {
         try (CompressorInputStream bzip2 = createStreamFor("bla.txt.bz2")) {
             assertNotNull(bzip2);
             assertTrue(bzip2 instanceof BZip2CompressorInputStream);
@@ -198,6 +248,49 @@ public final class DetectCompressorTest {
         assertThrows(CompressorException.class, () -> factory.createCompressorInputStream(new ByteArrayInputStream(ByteUtils.EMPTY_BYTE_ARRAY)));
     }
 
+    @Test
+    public void testCreateLimitedByName() throws Exception {
+        try (CompressorInputStream bzip2 = createStreamFor("bla.txt.bz2", Collections.singleton(CompressorStreamFactory.BZIP2))) {
+            assertNotNull(bzip2);
+            assertTrue(bzip2 instanceof BZip2CompressorInputStream);
+        }
+
+        try (CompressorInputStream gzip = createStreamFor("bla.tgz", Collections.singleton(CompressorStreamFactory.GZIP))) {
+            assertNotNull(gzip);
+            assertTrue(gzip instanceof GzipCompressorInputStream);
+        }
+
+        try (CompressorInputStream pack200 = createStreamFor("bla.pack", Collections.singleton(CompressorStreamFactory.PACK200))) {
+            assertNotNull(pack200);
+            assertTrue(pack200 instanceof Pack200CompressorInputStream);
+        }
+
+        try (CompressorInputStream xz = createStreamFor("bla.tar.xz", Collections.singleton(CompressorStreamFactory.XZ))) {
+            assertNotNull(xz);
+            assertTrue(xz instanceof XZCompressorInputStream);
+        }
+
+        try (CompressorInputStream zlib = createStreamFor("bla.tar.deflatez", Collections.singleton(CompressorStreamFactory.DEFLATE))) {
+            assertNotNull(zlib);
+            assertTrue(zlib instanceof DeflateCompressorInputStream);
+        }
+
+        try (CompressorInputStream zstd = createStreamFor("bla.tar.zst", Collections.singleton(CompressorStreamFactory.ZSTANDARD))) {
+            assertNotNull(zstd);
+            assertTrue(zstd instanceof ZstdCompressorInputStream);
+        }
+    }
+
+    @Test
+    public void testCreateLimitedByNameNotFound() throws Exception {
+        assertThrows(CompressorException.class, () -> createStreamFor("bla.txt.bz2", Collections.singleton(CompressorStreamFactory.BROTLI)));
+        assertThrows(CompressorException.class, () -> createStreamFor("bla.tgz", Collections.singleton(CompressorStreamFactory.Z)));
+        assertThrows(CompressorException.class, () -> createStreamFor("bla.pack", Collections.singleton(CompressorStreamFactory.SNAPPY_FRAMED)));
+        assertThrows(CompressorException.class, () -> createStreamFor("bla.tar.xz", Collections.singleton(CompressorStreamFactory.GZIP)));
+        assertThrows(CompressorException.class, () -> createStreamFor("bla.tar.deflatez", Collections.singleton(CompressorStreamFactory.PACK200)));
+        assertThrows(CompressorException.class, () -> createStreamFor("bla.tar.zst", Collections.singleton(CompressorStreamFactory.LZ4_FRAMED)));
+    }
+
     @Test
     public void testLZMAMemoryLimit() throws Exception {
         assertThrows(MemoryLimitException.class, () -> createStreamFor("COMPRESS-382", 100));