diff --git a/.classpath b/.classpath index 149cb3c9..822f9dca 100644 --- a/.classpath +++ b/.classpath @@ -6,7 +6,7 @@ - + diff --git a/src/com/machinepublishers/jbrowserdriver/StatusMonitor.java b/src/com/machinepublishers/jbrowserdriver/StatusMonitor.java index d081e369..6bc5c464 100644 --- a/src/com/machinepublishers/jbrowserdriver/StatusMonitor.java +++ b/src/com/machinepublishers/jbrowserdriver/StatusMonitor.java @@ -123,6 +123,7 @@ void clearStatusMonitor() { for (StreamConnection conn : connections.values()) { Util.close(conn); } + StreamConnection.cleanUp(); connections.clear(); primaryDocuments.clear(); discarded.clear(); diff --git a/src/com/machinepublishers/jbrowserdriver/StreamConnection.java b/src/com/machinepublishers/jbrowserdriver/StreamConnection.java index 5bc36b86..8b6f1326 100644 --- a/src/com/machinepublishers/jbrowserdriver/StreamConnection.java +++ b/src/com/machinepublishers/jbrowserdriver/StreamConnection.java @@ -59,6 +59,7 @@ import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.regex.Matcher; @@ -97,6 +98,7 @@ import org.apache.http.protocol.HttpContext; import org.apache.http.ssl.SSLContexts; import org.apache.http.ssl.TrustStrategy; +import org.apache.http.util.EntityUtils; class StreamConnection extends HttpURLConnection implements Closeable { private static final Pattern invalidUrlChar = Pattern.compile("[^-A-Za-z0-9._~:/?#\\[\\]@!$&'()*+,;=]"); @@ -137,28 +139,29 @@ class StreamConnection extends HttpURLConnection implements Closeable { .setDefaultCredentialsProvider(ProxyAuth.instance()) .setConnectionReuseStrategy(DefaultConnectionReuseStrategy.INSTANCE) .build(); - private static boolean cacheByDefault; + private static final AtomicBoolean cacheByDefault = new AtomicBoolean(); private final Map> reqHeaders = new LinkedHashMap>(); - private final RequestConfig.Builder config = RequestConfig.custom(); + private final AtomicReference config = new AtomicReference(RequestConfig.custom()); private final URL url; private final String urlString; private final AtomicBoolean skip = new AtomicBoolean(); private final AtomicLong settingsId = new AtomicLong(); - private int connectTimeout; - private int readTimeout; - private String method; - private boolean cache = cacheByDefault; - private boolean connected; - private boolean exec; - private CloseableHttpResponse response; - private HttpEntity entity; - private boolean consumed; - private HttpClientContext context = HttpClientContext.create(); - private HttpRequestBase req; - private boolean contentEncodingRemoved; - private long contentLength = -1; - private ByteArrayOutputStream reqData = new ByteArrayOutputStream(); + private final AtomicInteger connectTimeout = new AtomicInteger(); + private final AtomicInteger readTimeout = new AtomicInteger(); + private final AtomicReference method = new AtomicReference(); + private final AtomicBoolean cache = new AtomicBoolean(cacheByDefault.get()); + private final AtomicBoolean connected = new AtomicBoolean(); + private final AtomicBoolean exec = new AtomicBoolean(); + private final AtomicReference response = new AtomicReference(); + private final AtomicReference entity = new AtomicReference(); + private final AtomicBoolean consumed = new AtomicBoolean(); + private final AtomicBoolean closed = new AtomicBoolean(); + private final AtomicReference context = new AtomicReference(HttpClientContext.create()); + private final AtomicReference req = new AtomicReference(); + private final AtomicBoolean contentEncodingRemoved = new AtomicBoolean(); + private final AtomicLong contentLength = new AtomicLong(-1); + private final AtomicReference reqData = new AtomicReference(new ByteArrayOutputStream()); static { if (!"false".equals(System.getProperty("jbd.blockads"))) { @@ -349,16 +352,15 @@ private void processHeaders(AtomicReference settings, HttpRequestBase @Override public void connect() throws IOException { try { - if (!connected) { + if (connected.compareAndSet(false, true)) { if (StatusMonitor.get(settingsId.get()).isDiscarded(urlString) || isBlocked(url.getHost())) { skip.set(true); } else if (SettingsManager.get(settingsId.get()) != null) { - connected = true; - config + config.get() .setCookieSpec(CookieSpecs.STANDARD) - .setConnectTimeout(connectTimeout) - .setConnectionRequestTimeout(readTimeout); + .setConnectTimeout(connectTimeout.get()) + .setConnectionRequestTimeout(readTimeout.get()); URI uri = null; try { uri = url.toURI(); @@ -375,34 +377,34 @@ public void connect() throws IOException { builder.append(urlString.substring(left)); uri = new URI(builder.toString()); } - if ("OPTIONS".equals(method)) { - req = new HttpOptions(uri); - } else if ("GET".equals(method)) { - req = new HttpGet(uri); - } else if ("HEAD".equals(method)) { - req = new HttpHead(uri); - } else if ("POST".equals(method)) { - req = new HttpPost(uri); - } else if ("PUT".equals(method)) { - req = new HttpPut(uri); - } else if ("DELETE".equals(method)) { - req = new HttpDelete(uri); - } else if ("TRACE".equals(method)) { - req = new HttpTrace(uri); + if ("OPTIONS".equals(method.get())) { + req.set(new HttpOptions(uri)); + } else if ("GET".equals(method.get())) { + req.set(new HttpGet(uri)); + } else if ("HEAD".equals(method.get())) { + req.set(new HttpHead(uri)); + } else if ("POST".equals(method.get())) { + req.set(new HttpPost(uri)); + } else if ("PUT".equals(method.get())) { + req.set(new HttpPut(uri)); + } else if ("DELETE".equals(method.get())) { + req.set(new HttpDelete(uri)); + } else if ("TRACE".equals(method.get())) { + req.set(new HttpTrace(uri)); } - processHeaders(SettingsManager.get(settingsId.get()), req, url.getHost()); + processHeaders(SettingsManager.get(settingsId.get()), req.get(), url.getHost()); ProxyConfig proxy = SettingsManager.get(settingsId.get()).get().proxy(); if (proxy != null && !proxy.directConnection()) { - config.setExpectContinueEnabled(proxy.expectContinue()); + config.get().setExpectContinueEnabled(proxy.expectContinue()); InetSocketAddress proxyAddress = new InetSocketAddress(proxy.host(), proxy.port()); if (proxy.type() == ProxyConfig.Type.SOCKS) { - context.setAttribute("proxy.socks.address", proxyAddress); + context.get().setAttribute("proxy.socks.address", proxyAddress); } else { - config.setProxy(new HttpHost(proxy.host(), proxy.port())); + config.get().setProxy(new HttpHost(proxy.host(), proxy.port())); } } - context.setCookieStore(SettingsManager.get(settingsId.get()).get().cookieStore()); - context.setRequestConfig(config.build()); + context.get().setCookieStore(SettingsManager.get(settingsId.get()).get().cookieStore()); + context.get().setRequestConfig(config.get().build()); StatusMonitor.get(settingsId.get()).addStatusMonitor(url, this); } } @@ -413,19 +415,18 @@ public void connect() throws IOException { private void exec() throws IOException { try { - if (!exec) { - exec = true; + if (exec.compareAndSet(false, true)) { connect(); - if (req != null) { - if ("POST".equals(method)) { - ((HttpPost) req).setEntity(new ByteArrayEntity(reqData.toByteArray())); - } else if ("PUT".equals(method)) { - ((HttpPut) req).setEntity(new ByteArrayEntity(reqData.toByteArray())); + if (req.get() != null) { + if ("POST".equals(method.get())) { + ((HttpPost) req.get()).setEntity(new ByteArrayEntity(reqData.get().toByteArray())); + } else if ("PUT".equals(method.get())) { + ((HttpPut) req.get()).setEntity(new ByteArrayEntity(reqData.get().toByteArray())); } - response = cache ? cachingClient.execute(req, context) : client.execute(req, context); - if (response != null && response.getEntity() != null) { - entity = response.getEntity(); - response.setHeader("cache-control", "no-store"); + response.set(cache.get() ? cachingClient.execute(req.get(), context.get()) : client.execute(req.get(), context.get())); + if (response.get() != null && response.get().getEntity() != null) { + entity.set(response.get().getEntity()); + response.get().setHeader("cache-control", "no-store"); } } } @@ -441,13 +442,26 @@ public void disconnect() { @Override public void close() throws IOException { - try { - if (response != null) { - response.close(); + if (closed.compareAndSet(false, true)) { + if (entity.get() != null) { + try { + EntityUtils.consume(entity.get()); + } catch (Throwable t) {} + } + if (req.get() != null) { + try { + req.get().reset(); + } catch (Throwable t) {} + } + if (response.get() != null) { + try { + response.get().close(); + } catch (Throwable t) {} } - } catch (Throwable t) { - Logs.logsFor(settingsId.get()).exception(t); } + } + + static void cleanUp() { manager.closeExpiredConnections(); manager.closeIdleConnections(30, TimeUnit.SECONDS); } @@ -455,27 +469,33 @@ public void close() throws IOException { @Override public InputStream getInputStream() throws IOException { exec(); - if (!consumed) { - consumed = true; - if (entity != null && entity.getContent() != null && !skip.get()) { - String header = getHeaderField("content-disposition"); - if (header != null && !header.isEmpty()) { - Matcher matcher = downloadHeader.matcher(header); - if (matcher.matches()) { - AtomicReference settings = SettingsManager.get(settingsId.get()); - if (settings != null) { - File downloadFile = new File(settings.get().downloadDir(), - matcher.group(1) == null || matcher.group(1).isEmpty() - ? Long.toString(System.nanoTime()) : matcher.group(1)); - downloadFile.deleteOnExit(); - Files.write(downloadFile.toPath(), Util.toBytes(entity.getContent())); + if (consumed.compareAndSet(false, true)) { + if (entity.get() != null) { + try { + InputStream entityStream = entity.get().getContent(); + if (entityStream != null && !skip.get()) { + String header = getHeaderField("content-disposition"); + if (header != null && !header.isEmpty()) { + Matcher matcher = downloadHeader.matcher(header); + if (matcher.matches()) { + AtomicReference settings = SettingsManager.get(settingsId.get()); + if (settings != null) { + File downloadFile = new File(settings.get().downloadDir(), + matcher.group(1) == null || matcher.group(1).isEmpty() + ? Long.toString(System.nanoTime()) : matcher.group(1)); + downloadFile.deleteOnExit(); + Files.write(downloadFile.toPath(), Util.toBytes(entityStream)); + } + skip.set(true); + } + } + if (!skip.get()) { + return StreamInjectors.injectedStream( + this, entityStream, urlString, settingsId.get()); } - skip.set(true); } - } - if (!skip.get()) { - return StreamInjectors.injectedStream( - this, entity.getContent(), urlString, settingsId.get()); + } finally { + close(); } } } @@ -497,7 +517,7 @@ public InputStream getErrorStream() { @Override public String getResponseMessage() throws IOException { exec(); - return response == null || response.getStatusLine() == null ? null : response.getStatusLine().getReasonPhrase(); + return response.get() == null || response.get().getStatusLine() == null ? null : response.get().getStatusLine().getReasonPhrase(); } @Override @@ -507,7 +527,7 @@ public int getResponseCode() throws IOException { if (skip.get()) { return 204; } - return response == null || response.getStatusLine() == null ? 499 : response.getStatusLine().getStatusCode(); + return response.get() == null || response.get().getStatusLine() == null ? 499 : response.get().getStatusLine().getStatusCode(); } @Override @@ -526,36 +546,36 @@ public Object getContent(Class[] classes) throws IOException { @Override public String getContentEncoding() { - if (contentEncodingRemoved) { + if (contentEncodingRemoved.get()) { return null; } - return entity == null || entity.getContentEncoding() == null || skip.get() ? null : entity.getContentEncoding().getValue(); + return entity.get() == null || entity.get().getContentEncoding() == null || skip.get() ? null : entity.get().getContentEncoding().getValue(); } public void removeContentEncoding() { - response.removeHeaders("content-encoding"); - contentEncodingRemoved = true; + response.get().removeHeaders("content-encoding"); + contentEncodingRemoved.set(true); } @Override public int getContentLength() { - if (contentLength != -1) { - return (int) contentLength; + if (contentLength.get() != -1) { + return (int) contentLength.get(); } - return entity == null || skip.get() ? 0 : (int) entity.getContentLength(); + return entity.get() == null || skip.get() ? 0 : (int) entity.get().getContentLength(); } @Override public long getContentLengthLong() { - if (contentLength != -1) { - return contentLength; + if (contentLength.get() != -1) { + return contentLength.get(); } - return entity == null || skip.get() ? 0 : entity.getContentLength(); + return entity.get() == null || skip.get() ? 0 : entity.get().getContentLength(); } public void setContentLength(long contentLength) { - this.contentLength = contentLength; - response.setHeader("content-length", Long.toString(contentLength)); + this.contentLength.set(contentLength); + response.get().setHeader("content-length", Long.toString(contentLength)); } @Override @@ -566,7 +586,7 @@ public Permission getPermission() throws IOException { @Override public String getContentType() { - return entity == null || entity.getContentType() == null || skip.get() ? null : entity.getContentType().getValue(); + return entity.get() == null || entity.get().getContentType() == null || skip.get() ? null : entity.get().getContentType().getValue(); } @Override @@ -587,8 +607,8 @@ public long getLastModified() { @Override public Map> getHeaderFields() { Map> map = new HashMap>(); - if (response != null) { - Header[] headers = response.getAllHeaders(); + if (response.get() != null) { + Header[] headers = response.get().getAllHeaders(); for (int i = 0; headers != null && i < headers.length; i++) { String name = headers[i].getName(); if (!map.containsKey(name)) { @@ -602,8 +622,8 @@ public Map> getHeaderFields() { @Override public String getHeaderField(String name) { - if (response != null) { - Header[] headers = response.getHeaders(name); + if (response.get() != null) { + Header[] headers = response.get().getHeaders(name); if (headers != null && headers.length > 0) { return headers[headers.length - 1].getValue(); } @@ -613,8 +633,8 @@ public String getHeaderField(String name) { @Override public int getHeaderFieldInt(String name, int defaultValue) { - if (response != null) { - Header[] headers = response.getHeaders(name); + if (response.get() != null) { + Header[] headers = response.get().getHeaders(name); if (headers != null && headers.length > 0) { return Integer.parseInt(headers[headers.length - 1].getValue()); } @@ -624,8 +644,8 @@ public int getHeaderFieldInt(String name, int defaultValue) { @Override public long getHeaderFieldLong(String name, long defaultValue) { - if (response != null) { - Header[] headers = response.getHeaders(name); + if (response.get() != null) { + Header[] headers = response.get().getHeaders(name); if (headers != null && headers.length > 0) { return Long.parseLong(headers[headers.length - 1].getValue()); } @@ -640,18 +660,18 @@ public long getHeaderFieldDate(String name, long defaultValue) { @Override public String getHeaderFieldKey(int n) { - return response == null - || response.getAllHeaders() == null - || n >= response.getAllHeaders().length - || response.getAllHeaders()[n] == null ? null : response.getAllHeaders()[n].getName(); + return response.get() == null + || response.get().getAllHeaders() == null + || n >= response.get().getAllHeaders().length + || response.get().getAllHeaders()[n] == null ? null : response.get().getAllHeaders()[n].getName(); } @Override public String getHeaderField(int n) { - return response == null - || response.getAllHeaders() == null - || n >= response.getAllHeaders().length - || response.getAllHeaders()[n] == null ? null : response.getAllHeaders()[n].getValue(); + return response.get() == null + || response.get().getAllHeaders() == null + || n >= response.get().getAllHeaders().length + || response.get().getAllHeaders()[n] == null ? null : response.get().getAllHeaders()[n].getValue(); } /////////////////////////////////////////////////////////// @@ -660,7 +680,7 @@ public String getHeaderField(int n) { @Override public OutputStream getOutputStream() throws IOException { - return skip.get() ? new ByteArrayOutputStream() : reqData; + return skip.get() ? new ByteArrayOutputStream() : reqData.get(); } @Override @@ -670,32 +690,32 @@ public URL getURL() { @Override public String getRequestMethod() { - return method; + return method.get(); } @Override public void setRequestMethod(String method) throws ProtocolException { - this.method = method.toUpperCase(); + this.method.set(method.toUpperCase()); } @Override public int getConnectTimeout() { - return connectTimeout; + return connectTimeout.get(); } @Override public void setConnectTimeout(int timeout) { - this.connectTimeout = timeout; + this.connectTimeout.set(timeout); } @Override public int getReadTimeout() { - return readTimeout; + return readTimeout.get(); } @Override public void setReadTimeout(int timeout) { - this.readTimeout = timeout; + this.readTimeout.set(timeout); } @Override diff --git a/src/com/machinepublishers/jbrowserdriver/Util.java b/src/com/machinepublishers/jbrowserdriver/Util.java index 104f7251..a9dc5174 100644 --- a/src/com/machinepublishers/jbrowserdriver/Util.java +++ b/src/com/machinepublishers/jbrowserdriver/Util.java @@ -29,6 +29,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.net.SocketException; import java.net.URLConnection; import java.nio.charset.Charset; import java.util.Random; @@ -40,6 +41,8 @@ import javax.net.ssl.SSLProtocolException; +import org.apache.http.ConnectionClosedException; + import com.machinepublishers.browser.Browser; class Util { @@ -178,7 +181,7 @@ static String toString(InputStream inputStream, String charset) { BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, charset), chars.length); try { for (int len; -1 != (len = reader.read(chars, 0, chars.length)); builder.append(chars, 0, len)); - } catch (EOFException | SSLProtocolException e) {} + } catch (EOFException | SSLProtocolException | ConnectionClosedException | SocketException e) {} return builder.toString(); } catch (Throwable t) { return null; @@ -193,7 +196,7 @@ static byte[] toBytes(InputStream inputStream) throws IOException { ByteArrayOutputStream out = new ByteArrayOutputStream(bytes.length); try { for (int len = 0; -1 != (len = inputStream.read(bytes, 0, bytes.length)); out.write(bytes, 0, len)); - } catch (EOFException | SSLProtocolException e) {} + } catch (EOFException | SSLProtocolException | ConnectionClosedException | SocketException e) {} return out.toByteArray(); } finally { close(inputStream);