From def207eef16607643dcdaf02d2dc01b439ccfc8a Mon Sep 17 00:00:00 2001
From: Alex Rogozhnikov <arogozhnikov@users.noreply.github.com>
Date: Mon, 30 Sep 2024 06:37:01 -0700
Subject: [PATCH] Add support for Cloudflare's R2 storage (#888)
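
Cloudflare R2 requires every part of a multipart upload except the
last one to have the same size, while the previous _upload_chunk()
implementation could merge or split buffered data into parts of
uneven size. This patch adds a fixed_upload_size option to
S3FileSystem: when enabled, each uploaded part is exactly blocksize
bytes, and any remainder stays in the write buffer until the next
flush or is sent as the smaller final part on close. The CI workflow
is also switched from setup-micromamba to conda-incubator/setup-miniconda.

A minimal usage sketch against R2 (the account id, credentials and
bucket name below are placeholders, not part of this patch):

    import s3fs

    # R2 rejects multipart uploads with unevenly sized parts, so
    # fixed_upload_size=True keeps every part at exactly `blocksize`
    # (only the last part may be smaller).
    fs = s3fs.S3FileSystem(
        key="<access_key_id>",
        secret="<secret_access_key>",
        fixed_upload_size=True,
        client_kwargs={
            "endpoint_url": "https://<account_id>.r2.cloudflarestorage.com"
        },
    )

    with fs.open("my-bucket/big.bin", "wb", block_size=16 * 2**20) as f:
        # write enough data to trigger a multipart upload
        f.write(b"0" * (64 * 2**20))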

---
 .github/workflows/ci.yml |  5 ++-
 s3fs/core.py             | 73 +++++++++++++++++++++++-----------------
 s3fs/tests/test_s3fs.py  | 50 +++++++++++++++++++++++++++
 3 files changed, 95 insertions(+), 33 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index d5d28eee..99d11beb 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -24,11 +24,10 @@ jobs:
           fetch-depth: 0
 
       - name: Setup conda
-        uses: mamba-org/setup-micromamba@v1
+        uses: conda-incubator/setup-miniconda@v3
         with:
           environment-file: ci/env.yaml
-          create-args: >-
-            python=${{ matrix.PY }}
+          python-version: ${{ matrix.PY }}
 
       - name: Install
         shell: bash -l {0}
diff --git a/s3fs/core.py b/s3fs/core.py
index 2da6f0bd..43acd603 100644
--- a/s3fs/core.py
+++ b/s3fs/core.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 import asyncio
 import errno
+import io
 import logging
 import mimetypes
 import os
@@ -248,6 +249,9 @@ class S3FileSystem(AsyncFileSystem):
         this parameter to affect ``pipe()``, ``cat()`` and ``get()``. Increasing this
         value will result in higher memory usage during multipart upload operations (by
         ``max_concurrency * chunksize`` bytes per file).
+    fixed_upload_size : bool (False)
+        Use the same chunk size for all parts of a multipart upload (only the
+        last part may be smaller). Cloudflare R2 storage requires
+        ``fixed_upload_size=True`` for multipart uploads.
 
     The following parameters are passed on to fsspec:
 
@@ -296,6 +300,7 @@ def __init__(
         asynchronous=False,
         loop=None,
         max_concurrency=1,
+        fixed_upload_size: bool = False,
         **kwargs,
     ):
         if key and username:
@@ -333,6 +338,7 @@ def __init__(
         self.cache_regions = cache_regions
         self._s3 = None
         self.session = session
+        self.fixed_upload_size = fixed_upload_size
         if max_concurrency < 1:
             raise ValueError("max_concurrency must be >= 1")
         self.max_concurrency = max_concurrency
@@ -2330,46 +2336,53 @@ def _upload_chunk(self, final=False):
             and final
             and self.tell() < self.blocksize
         ):
-            # only happens when closing small file, use on-shot PUT
-            data1 = False
+            # only happens when closing a small file; use a one-shot PUT
+            pass
         else:
             self.buffer.seek(0)
-            (data0, data1) = (None, self.buffer.read(self.blocksize))
 
-        while data1:
-            (data0, data1) = (data1, self.buffer.read(self.blocksize))
-            data1_size = len(data1)
+            def upload_part(part_data: bytes):
+                if len(part_data) == 0:
+                    return
+                part = len(self.parts) + 1
+                logger.debug("Upload chunk %s, %s" % (self, part))
 
-            if 0 < data1_size < self.blocksize:
-                remainder = data0 + data1
-                remainder_size = self.blocksize + data1_size
+                out = self._call_s3(
+                    "upload_part",
+                    Bucket=bucket,
+                    PartNumber=part,
+                    UploadId=self.mpu["UploadId"],
+                    Body=part_data,
+                    Key=key,
+                )
 
-                if remainder_size <= self.part_max:
-                    (data0, data1) = (remainder, None)
-                else:
-                    partition = remainder_size // 2
-                    (data0, data1) = (remainder[:partition], remainder[partition:])
+                part_header = {"PartNumber": part, "ETag": out["ETag"]}
+                if "ChecksumSHA256" in out:
+                    part_header["ChecksumSHA256"] = out["ChecksumSHA256"]
+                self.parts.append(part_header)
 
-            part = len(self.parts) + 1
-            logger.debug("Upload chunk %s, %s" % (self, part))
+            def n_bytes_left() -> int:
+                return len(self.buffer.getbuffer()) - self.buffer.tell()
 
-            out = self._call_s3(
-                "upload_part",
-                Bucket=bucket,
-                PartNumber=part,
-                UploadId=self.mpu["UploadId"],
-                Body=data0,
-                Key=key,
-            )
-
-            part_header = {"PartNumber": part, "ETag": out["ETag"]}
-            if "ChecksumSHA256" in out:
-                part_header["ChecksumSHA256"] = out["ChecksumSHA256"]
-            self.parts.append(part_header)
+            min_chunk = 1 if final else self.blocksize
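+            # on the final flush, remaining bytes are uploaded even if fewer than blocksize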
+            if self.fs.fixed_upload_size:
+                # all parts have the same fixed size; only the last one may be smaller
+                while n_bytes_left() >= min_chunk:
+                    upload_part(self.buffer.read(self.blocksize))
+            else:
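+                # part sizes may vary: each part reads up to part_max bytes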
+                while n_bytes_left() >= min_chunk:
+                    upload_part(self.buffer.read(self.part_max))
 
         if self.autocommit and final:
             self.commit()
-        return not final
+        else:
+            # update 'upload offset'
+            self.offset += self.buffer.tell()
+            # shrink the buffer to the unflushed remainder and seek to its end
+            self.buffer = io.BytesIO(self.buffer.read())
+            self.buffer.seek(0, 2)
+
+        return False  # instruct fsspec.flush to NOT clear self.buffer
 
     def commit(self):
         logger.debug("Commit %s" % self)
diff --git a/s3fs/tests/test_s3fs.py b/s3fs/tests/test_s3fs.py
index d3d90899..2275519a 100644
--- a/s3fs/tests/test_s3fs.py
+++ b/s3fs/tests/test_s3fs.py
@@ -884,6 +884,9 @@ def test_seek(s3):
     with s3.open(a, "wb") as f:
         f.write(b"123")
 
+    with s3.open(a) as f:
+        assert f.read() == b"123"
+
     with s3.open(a) as f:
         f.seek(1000)
         with pytest.raises(ValueError):
@@ -2749,3 +2752,50 @@ def test_bucket_versioning(s3):
     assert s3.is_bucket_versioned("maybe_versioned")
     s3.make_bucket_versioned("maybe_versioned", False)
     assert not s3.is_bucket_versioned("maybe_versioned")
+
+
+@pytest.fixture()
+def s3_fixed_upload_size(s3):
+    s3_fixed = S3FileSystem(
+        anon=False,
+        client_kwargs={"endpoint_url": endpoint_uri},
+        fixed_upload_size=True,
+    )
+    s3_fixed.invalidate_cache()
+    yield s3_fixed
+
+
+def test_upload_parts(s3_fixed_upload_size):
+    with s3_fixed_upload_size.open(a, "wb", block_size=6_000_000) as f:
+        f.write(b" " * 6_001_000)
+        assert len(f.buffer.getbuffer()) == 1000
+        # check we are at the right position
+        assert f.tell() == 6_001_000
+        # offset is introduced in fsspec.core, but never used there;
+        # it appears intended to track how many bytes have already been uploaded
+        assert f.offset == 6_000_000
+        f.write(b" " * 6_001_000)
+        assert len(f.buffer.getbuffer()) == 2000
+        assert f.tell() == 2 * 6_001_000
+        assert f.offset == 2 * 6_000_000
+
+    with s3_fixed_upload_size.open(a, "r") as f:
+        assert len(f.read()) == 6_001_000 * 2
+
+
+def test_upload_part_with_prime_pads(s3_fixed_upload_size):
+    block = 6_000_000
+    pad1, pad2 = 1013, 1019  # prime pads, so sizes are never exact multiples of the block size
+    with s3_fixed_upload_size.open(a, "wb", block_size=block) as f:
+        f.write(b" " * (block + pad1))
+        assert len(f.buffer.getbuffer()) == pad1
+        # check we are at the right position
+        assert f.tell() == block + pad1
+        assert f.offset == block
+        f.write(b" " * (block + pad2))
+        assert len(f.buffer.getbuffer()) == pad1 + pad2
+        assert f.tell() == 2 * block + pad1 + pad2
+        assert f.offset == 2 * block
+
+    with s3_fixed_upload_size.open(a, "r") as f:
+        assert len(f.read()) == 2 * block + pad1 + pad2