diff --git a/spring-boot-tools/spring-boot-loader/src/main/java/org/springframework/boot/loader/data/RandomAccessDataFile.java b/spring-boot-tools/spring-boot-loader/src/main/java/org/springframework/boot/loader/data/RandomAccessDataFile.java index 263fd5e76f9..299c6308d06 100644 --- a/spring-boot-tools/spring-boot-loader/src/main/java/org/springframework/boot/loader/data/RandomAccessDataFile.java +++ b/spring-boot-tools/spring-boot-loader/src/main/java/org/springframework/boot/loader/data/RandomAccessDataFile.java @@ -170,11 +170,11 @@ public class RandomAccessDataFile implements RandomAccessData { return -1; } RandomAccessFile file = this.file; - if (file == null) { - file = RandomAccessDataFile.this.filePool.acquire(); - file.seek(RandomAccessDataFile.this.offset + this.position); - } try { + if (file == null) { + file = RandomAccessDataFile.this.filePool.acquire(); + file.seek(RandomAccessDataFile.this.offset + this.position); + } if (b == null) { int rtn = file.read(); moveOn(rtn == -1 ? 0 : 1); @@ -185,7 +185,7 @@ public class RandomAccessDataFile implements RandomAccessData { } } finally { - if (this.file == null) { + if (this.file == null && file != null) { RandomAccessDataFile.this.filePool.release(file); } } @@ -229,7 +229,7 @@ public class RandomAccessDataFile implements RandomAccessData { * Manage a pool that can be used to perform concurrent reads on the underlying * {@link RandomAccessFile}. */ - private class FilePool { + class FilePool { private final int size; diff --git a/spring-boot-tools/spring-boot-loader/src/test/java/org/springframework/boot/loader/data/RandomAccessDataFileTests.java b/spring-boot-tools/spring-boot-loader/src/test/java/org/springframework/boot/loader/data/RandomAccessDataFileTests.java index b295230426d..c156ccf048a 100644 --- a/spring-boot-tools/spring-boot-loader/src/test/java/org/springframework/boot/loader/data/RandomAccessDataFileTests.java +++ b/spring-boot-tools/spring-boot-loader/src/test/java/org/springframework/boot/loader/data/RandomAccessDataFileTests.java @@ -18,7 +18,9 @@ package org.springframework.boot.loader.data; import java.io.File; import java.io.FileOutputStream; +import java.io.IOException; import java.io.InputStream; +import java.io.RandomAccessFile; import java.lang.reflect.Field; import java.util.ArrayList; import java.util.Arrays; @@ -35,10 +37,20 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.junit.rules.TemporaryFolder; +import org.mockito.internal.util.MockUtil; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; import org.springframework.boot.loader.data.RandomAccessData.ResourceAccess; +import org.springframework.boot.loader.data.RandomAccessDataFile.FilePool; +import org.springframework.test.util.ReflectionTestUtils; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.BDDMockito.willAnswer; +import static org.mockito.BDDMockito.willThrow; +import static org.mockito.Matchers.anyLong; +import static org.mockito.Mockito.spy; /** * Tests for {@link RandomAccessDataFile}. @@ -309,4 +321,38 @@ public class RandomAccessDataFileTests { assertThat(queue.size()).isEqualTo(0); } + @Test + public void seekFailuresDoNotPreventSubsequentReads() throws Exception { + FilePool filePool = (FilePool) ReflectionTestUtils.getField(this.file, + "filePool"); + FilePool spiedPool = spy(filePool); + ReflectionTestUtils.setField(this.file, "filePool", spiedPool); + willAnswer(new Answer() { + + @Override + public RandomAccessFile answer(InvocationOnMock invocation) throws Throwable { + RandomAccessFile originalFile = (RandomAccessFile) invocation + .callRealMethod(); + if (new MockUtil().isSpy(originalFile)) { + return originalFile; + } + RandomAccessFile spiedFile = spy(originalFile); + willThrow(new IOException("Seek failed")).given(spiedFile) + .seek(anyLong()); + return spiedFile; + } + + }).given(spiedPool).acquire(); + + for (int i = 0; i < 5; i++) { + try { + this.file.getInputStream(ResourceAccess.PER_READ).read(); + fail("Read should fail due to exception from seek"); + } + catch (IOException ex) { + + } + } + } + }