
CODE GUIDE

Last modified · README status: 📝 Board · Created: November 2, 2020, 11:53 AM

Before You Begin

This project basically consists of two Python files: run_websocket_server(0.2.3).py and run_websocket_client(0.2.3).py. They must be downloaded to the Raspberry Pi and to the computer respectively, and they set up and run the server. For a detailed explanation of their roles and usage, refer to 'here'. For the sake of future maintenance, this document briefly explains only the code in the Python files.

run_websocket_server(0.2.3).py

Role

✔️ Runs on the Raspberry Pi and starts a server that receives commands from the central server.

✔️ Performs the prescribed operations on the preprocessed dataset according to the commands it receives.

Code

1. def : start_proc

from multiprocessing import Process

def start_proc(participant, kwargs):  # pragma: no cover
    """Helper function for spinning up a websocket participant."""

    def target():
        # instantiate the worker (e.g. WebsocketServerWorker) and start serving
        server = participant(**kwargs)
        server.start()

    # run the server in a separate process so it does not block the caller
    p = Process(target=target)
    p.start()
    return p
  • The participant parameter usually receives the syft library's WebsocketServerWorker class.
  • The kwargs parameter receives the argument values passed when the file is run.
  • The server is started from these two parameters; the Process class from the multiprocessing standard library runs it in a separate process (a usage sketch follows below).
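
A rough usage sketch (the import path follows the syft 0.2.x examples, and the id/host/port values are placeholders, not prescribed by this project):

    import torch
    import syft as sy
    from syft.workers.websocket_server import WebsocketServerWorker

    hook = sy.TorchHook(torch)
    kwargs = {"id": "alice", "host": "0.0.0.0", "port": 8777, "hook": hook, "verbose": True}
    p = start_proc(WebsocketServerWorker, kwargs)  # returns the spawned Process
    p.join()  # block until the server process exits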

2. part : parser

import argparse

parser = argparse.ArgumentParser(description="Run websocket server worker.")
parser.add_argument(
    "--port", "-p", type=int, help="port number of the websocket server worker, e.g. --port 8777"
)
parser.add_argument("--host", type=str, default="localhost", help="host for the connection")
parser.add_argument(
    "--id", type=str, help="name (id) of the websocket server worker, e.g. --id alice"
)
parser.add_argument(
    "--verbose",
    "-v",
    action="store_true",
    help="if set, websocket server worker will be started in verbose mode",
)

args = parser.parse_args()
  • Defines the options (arguments) that can be passed when running the file; an example invocation follows below.
  • There are four arguments: port, host (IP address), id, and verbose.
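
For instance, the server might be launched on a Raspberry Pi as follows (the IP, port, and id values are examples borrowed from the client code further down; substitute your own):

    python "run_websocket_server(0.2.3).py" --host 192.168.0.52 --port 10002 --id alice --verbose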

3. part : main

kwargs = {
    "id": args.id,
    "host": args.host,
    "port": args.port,
    "hook": hook,
    "verbose": args.verbose,
}
server = start_proc(WebsocketServerWorker, kwargs)
  • Runs start_proc with the arguments received at launch.

FL_ECG.py

Role

✔️ Lives on the central device (desktop or laptop) and, when run, connects to the servers running on the Raspberry Pis.

✔️ Sends the commands for FL (federated learning) to each Raspberry Pi.

✔️ Aggregates and processes the models received from the Raspberry Pis, then sends back the updated model.

Code

1. class : Net
import torch.nn as nn
import torch.nn.functional as f

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = f.relu(self.conv1(x))
        x = f.max_pool2d(x, 2, 2)
        x = f.relu(self.conv2(x))
        x = f.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = f.relu(self.fc1(x))
        x = self.fc2(x)
        return f.log_softmax(x, dim=1)
  • Defines the network for the machine-learning model (a quick shape check follows below).
  • It inherits from the abstract Module class of the nn library.
  • Layer presets such as relu and max-pooling are available from nn.functional.
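
As a quick sanity check (assuming 28x28 single-channel inputs, as in the MNIST data used later), the network maps a batch of images to 10 log-probabilities:

    import torch

    net = Net()
    out = net(torch.randn(1, 1, 28, 28))  # one fake 28x28 grayscale image
    print(out.shape)                      # torch.Size([1, 10])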

2. def : train_on_batches

Parameters

import torch.optim as optim

def train_on_batches(worker, batches, model_in, device, lr):
    """Train the model on the worker on the provided batches
    Args:
        worker(syft.workers.BaseWorker): worker on which the
        training will be executed
        batches: batches of data of this worker
        model_in: machine learning model, training will be done on a copy
        device (torch.device): where to run the training
        lr: learning rate of the training steps
    Returns:
        model, loss: obtained model and loss after training
    """
    model = model_in.copy()
    optimizer = optim.SGD(model.parameters(), lr=lr)  # TODO momentum is not supported at the moment

    model.train()
    model.send(worker)
    loss_local = False
  • Chooses the optimizer to use (SGD here), attaches the copied model's parameters to it, and sets the learning rate.
  • model.train() switches the model to training mode.
  • model.send(worker) sends the model copy to the remote worker, where the actual training steps run.
    for batch_idx, (data, target) in enumerate(batches):
        loss_local = False
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = f.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % LOG_INTERVAL == 0:
            loss = loss.get()  # <-- NEW: get the loss back
            loss_local = True
            logger.debug(
                "Train Worker {}: [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    worker.id,
                    batch_idx,
                    len(batches),
                    100.0 * batch_idx / len(batches),
                    loss.item(),
                )
            )

    if not loss_local:
        loss = loss.get()  # <-- NEW: get the loss back
    model.get()  # <-- NEW: get the model back
    return model, loss
  • Each received batch is moved to the device; a batch consists of data and a target. The loss is computed with nll_loss, backpropagated, and an optimizer step is taken.
  • loss.get() and model.get() retrieve the loss value and the trained model back from the remote worker (the send/get pattern is illustrated below).
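
The send/get pattern is the core syft mechanism here; a minimal, self-contained illustration with a virtual worker (syft 0.2.x API; the worker name is a placeholder):

    import torch
    import syft as sy

    hook = sy.TorchHook(torch)
    demo = sy.VirtualWorker(hook, id="demo")

    x = torch.tensor([1.0, 2.0])
    ptr = x.send(demo)   # the data now lives on the worker; ptr is a pointer tensor
    y = (ptr * 2).get()  # .get() moves the result back to the local machine
    print(y)             # tensor([2., 4.])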

3. def : get_next_batches

Parameters

def get_next_batches(fdataloader: sy.FederatedDataLoader, nr_batches: int):
    """retrieve next nr_batches of the federated data loader and group
    the batches by worker
    Args:
        fdataloader (sy.FederatedDataLoader): federated data loader
        over which the function will iterate
        nr_batches (int): number of batches (per worker) to retrieve
    Returns:
        Dict[syft.workers.BaseWorker, List[batches]]
    """
    batches = {}
    for worker_id in fdataloader.workers:
        worker = fdataloader.federated_dataset.datasets[worker_id].location
        batches[worker] = []
    try:
        for i in range(nr_batches):
            next_batches = next(fdataloader)
            for worker in next_batches:
                batches[worker].append(next_batches[worker])
    except StopIteration:
        pass
    return batches
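
For example, once the loader's iterators have been initialized with iter(), two batches per worker could be fetched like this (variable names as used later in this file):

    iter(federated_train_loader)  # initialize the iterators
    batches = get_next_batches(federated_train_loader, nr_batches=2)
    for worker, worker_batches in batches.items():
        print(worker.id, len(worker_batches))  # up to 2 batches per worker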

4. def : train

Parameters

def train(
    model, device, federated_train_loader, lr, federate_after_n_batches, abort_after_one=False
):
    model.train()

    nr_batches = federate_after_n_batches

    models = {}
    loss_values = {}

    iter(federated_train_loader)  # initialize iterators
    batches = get_next_batches(federated_train_loader, nr_batches)
    counter = 0

    while True:
        logger.debug(
            "Starting training round, batches [{}, {}]".format(counter, counter + nr_batches)
        )
        data_for_all_workers = True
        for worker in batches:
            curr_batches = batches[worker]
            if curr_batches:
                models[worker], loss_values[worker] = train_on_batches(
                    worker, curr_batches, model, device, lr
                )
            else:
                data_for_all_workers = False
        counter += nr_batches
        if not data_for_all_workers:
            logger.debug("At least one worker ran out of data, stopping.")
            break

        model = utils.federated_avg(models)
        batches = get_next_batches(federated_train_loader, nr_batches)
        if abort_after_one:
            break
    return model
  • This function performs the actual training rounds; it is distinct from the model.train() method, which merely switches the model into training mode.
  • The batches variable holds, per worker, the preset number of batches fetched with get_next_batches.
  • After every round, utils.federated_avg(models) averages the per-worker models into the new global model (see the sketch below).
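
federated_avg itself comes from syft's utils; conceptually it is just a parameter-wise mean of the worker models. A minimal sketch of the idea (illustrative only, not the syft implementation):

    import copy
    import torch

    def federated_avg_sketch(models):
        # models maps each worker to its locally trained model
        state_dicts = [m.state_dict() for m in models.values()]
        avg_state = {
            key: torch.stack([sd[key] for sd in state_dicts]).mean(dim=0)
            for key in state_dicts[0]
        }
        avg_model = copy.deepcopy(next(iter(models.values())))
        avg_model.load_state_dict(avg_state)
        return avg_model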

5. def : test

Parameters

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += f.nll_loss(output, target, reduction="sum").item()  # sum up batch loss
            pred = output.argmax(1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    logger.debug("\n")
    accuracy = 100.0 * correct / len(test_loader.dataset)
    logger.info(
        "Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss, correct, len(test_loader.dataset), accuracy
        )
    )

6. def : define_and_get_arguments

Parameters

def define_and_get_arguments(args=sys.argv[1:]):
    parser = argparse.ArgumentParser(
        description="Run federated learning using websocket client workers."
    )
    parser.add_argument("--batch_size", type=int, default=64, help="batch size of the training")
    parser.add_argument(
        "--test_batch_size", type=int, default=1000, help="batch size used for the test data"
    )
    parser.add_argument("--epochs", type=int, default=2, help="number of epochs to train")
    parser.add_argument(
        "--federate_after_n_batches",
        type=int,
        default=50,
        help="number of training steps performed on each remote worker before averaging",
    )
    parser.add_argument("--lr", type=float, default=0.01, help="learning rate")
    parser.add_argument("--cuda", action="store_true", help="use cuda")
    parser.add_argument("--seed", type=int, default=1, help="seed used for randomization")
    parser.add_argument("--save_model", action="store_true", help="if set, model will be saved")
    parser.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="if set, websocket client workers will " "be started in verbose mode",
    )
    parser.add_argument(
        "--use_virtual", action="store_true", help="if set, virtual workers will be used"
    )

    args = parser.parse_args(args=args)
    return args
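
A hypothetical invocation of the client script (flag values are examples; --use_virtual avoids needing live Raspberry Pis):

    python FL_ECG.py --use_virtual --epochs 2 --federate_after_n_batches 50 --lr 0.01 -v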

7. def : main

No parameters

def main():
    args = define_and_get_arguments()

    hook = sy.TorchHook(torch)

    # when using virtual workers (simulation), execution branches here
    if args.use_virtual:
        alice = VirtualWorker(id="alice", hook=hook, verbose=args.verbose)
        bob = VirtualWorker(id="bob", hook=hook, verbose=args.verbose)
        charlie = VirtualWorker(id="charlie", hook=hook, verbose=args.verbose)
    # when using websocket workers, execution branches here
    else:
        a_kwargs_websocket = {"host": "192.168.0.52", "hook": hook}
        b_kwargs_websocket = {"host": "192.168.0.53", "hook": hook}
        c_kwargs_websocket = {"host": "192.168.0.54", "hook": hook}

        baseport = 10002
        alice = WebsocketClientWorker(id="alice", port=baseport, **a_kwargs_websocket)
        bob = WebsocketClientWorker(id="bob", port=baseport, **b_kwargs_websocket)
        charlie = WebsocketClientWorker(id="charlie", port=baseport, **c_kwargs_websocket)

    # collect the worker objects in a list
    workers = [alice, bob, charlie]

    # whether to use CUDA
    use_cuda = args.cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

    # set the random seed
    torch.manual_seed(args.seed)
  • Fetches the launch options via define_and_get_arguments().
  • If the use_virtual option is set, the run is simulated with virtual workers instead of websockets. Testing with virtual workers before connecting to the real Raspberry Pis shortens test time.
  • If use_virtual is not set, the script operates over websockets. In that case each kwargs_websocket holds a Raspberry Pi's IP address and the hook, and a WebsocketClientWorker object is then created for each ID.
  • This part rarely needs modification.
    federated_train_loader = sy.FederatedDataLoader(
        datasets.MNIST(
            "../data",
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        ).federate(tuple(workers)),
        batch_size=args.batch_size,
        shuffle=True,
        iter_per_worker=True,
        **kwargs,
    )

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "../data",
            train=False,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        ),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs,
    )

    model = Net().to(device)

    for epoch in range(1, args.epochs + 1):
        # output: 2020-11-05 15:07:04,953 INFO run_websocket_client(0.2.3).py(l:268) - Starting epoch 1/2
        logger.info("Starting epoch %s/%s", epoch, args.epochs)
        model = train(model, device, federated_train_loader, args.lr, args.federate_after_n_batches)
        test(model, device, test_loader)

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
  • federated_train_loader turns the loaded dataset into an object usable for federated learning.
  • The FederatedDataLoader is the object that bundles the operations needed to run federated learning; it is used by turning it into an iterator (see the sketch below).
  • Before that, the data must be distributed to the workers, which is done with the dataset's .federate(tuple(workers)) method.
  • Training then repeats for the number of epochs given in args.epochs; these are the rounds in which the central server aggregates the model. The default is 2.
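
To make the data flow concrete: because iter_per_worker=True, each next() call yields one batch per worker, keyed by the worker object (a sketch under the syft 0.2.x API):

    loader_iter = iter(federated_train_loader)
    worker_batches = next(loader_iter)        # {worker: (data_ptr, target_ptr)}
    for worker, (data, target) in worker_batches.items():
        print(worker.id, data.shape)          # pointer tensors still report their shapes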

8. part : main

if __name__ == "__main__":
    FORMAT = "%(asctime)s %(levelname)s %(filename)s(l:%(lineno)d) - %(message)s"
    LOG_LEVEL = logging.DEBUG
    logging.basicConfig(format=FORMAT, level=LOG_LEVEL)

    websockets_logger = logging.getLogger("websockets")
    websockets_logger.setLevel(logging.DEBUG)
    websockets_logger.addHandler(logging.StreamHandler())

    main()
  • Nothing complicated here: it splits into two parts, configuring the log messages and entering the main function.
  • getLogger creates a logger named websockets. setLevel makes it emit everything at DEBUG level and above. (There are five logger levels: DEBUG, INFO, WARNING, ERROR, and CRITICAL.)
  • addHandler attaches a StreamHandler so the logs are printed to the console; handlers can also be configured to write to files, databases, sockets, queues, and so on.
  • A blog post worth consulting on logging ⬇️

Everything About Python Logging

Functions for Reference

  • FederatedDataLoader

    class FederatedDataLoader(object):
        """
        Data loader. Combines a dataset and a sampler, and provides
        single or several iterators over the dataset.
        Arguments:
            federated_dataset (FederatedDataset): dataset from which to load the data.
            batch_size (int, optional): how many samples per batch to load
                (default: ``1``).
            shuffle (bool, optional): set to ``True`` to have the data reshuffled
                at every epoch (default: ``False``).
            collate_fn (callable, optional): merges a list of samples to form a mini-batch.
            drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
                if the dataset size is not divisible by the batch size. If ``False`` and
                the size of dataset is not divisible by the batch size, then the last batch
                will be smaller. (default: ``False``)
            num_iterators (int): number of workers from which to retrieve data in parallel.
                num_iterators <= len(federated_dataset.workers) - 1
                the effect is to retrieve num_iterators epochs of data but at each step data
                from num_iterators distinct workers is returned.
            iter_per_worker (bool): if set to true, __next__() will return a dictionary
                containing one batch per worker
        """
    
        __initialized = False
    
        def __init__(
            self,
            federated_dataset,
            batch_size=8,
            shuffle=False,
            num_iterators=1,
            drop_last=False,
            collate_fn=default_collate,
            iter_per_worker=False,
            **kwargs,
        ):
            if len(kwargs) > 0:
                options = ", ".join([f"{k}: {v}" for k, v in kwargs.items()])
                logging.warning(f"The following options are not supported: {options}")
    
            try:
                self.workers = federated_dataset.workers
            except AttributeError:
                raise Exception(
                    "Your dataset is not a FederatedDataset, please use "
                    "torch.utils.data.DataLoader instead."
                )
    
            self.federated_dataset = federated_dataset
            self.batch_size = batch_size
            self.drop_last = drop_last
            self.collate_fn = collate_fn
            self.iter_class = _DataLoaderOneWorkerIter if iter_per_worker else _DataLoaderIter
    
            # Build a batch sampler per worker
            self.batch_samplers = {}
            for worker in self.workers:
                data_range = range(len(federated_dataset[worker]))
                if shuffle:
                    sampler = RandomSampler(data_range)
                else:
                    sampler = SequentialSampler(data_range)
                batch_sampler = BatchSampler(sampler, batch_size, drop_last)
                self.batch_samplers[worker] = batch_sampler
    
            if iter_per_worker:
                self.num_iterators = len(self.workers)
            else:
                # You can't have more iterators than n - 1 workers, because you always
                # need a worker idle in the worker switch process made by iterators
                if len(self.workers) == 1:
                    self.num_iterators = 1
                else:
                    self.num_iterators = min(num_iterators, len(self.workers) - 1)
    
        def __iter__(self):
            self.iterators = []
            for idx in range(self.num_iterators):
                self.iterators.append(self.iter_class(self, worker_idx=idx))
            return self
    
        def __next__(self):
            if self.num_iterators > 1:
                batches = {}
                for iterator in self.iterators:
                    data, target = next(iterator)
                    batches[data.location] = (data, target)
                return batches
            else:
                iterator = self.iterators[0]
                data, target = next(iterator)
                return data, target
    
        def __len__(self):
            length = len(self.federated_dataset) / self.batch_size
            if self.drop_last:
                return int(length)
            else:
                return math.ceil(length)
    • Takes as input a dataset that has been federated with the .federate() method (a sketch follows below).
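
    A compact illustration of building such a loader with virtual workers (sy.BaseDataset and the calls below follow the syft 0.2.x API; treat this as a sketch):

        import torch
        import syft as sy

        hook = sy.TorchHook(torch)
        alice = sy.VirtualWorker(hook, id="alice")
        bob = sy.VirtualWorker(hook, id="bob")

        base = sy.BaseDataset(torch.randn(8, 2), torch.randint(0, 2, (8,)))
        fed = base.federate((alice, bob))  # splits the data across the two workers
        loader = sy.FederatedDataLoader(fed, batch_size=4, shuffle=True)
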
  • torch.device

    CUDA Tensors: a tensor can be moved to any device using the .to method.

    # This code runs only in an environment where CUDA is available.
    # ``torch.device`` is used to move tensors onto and off the GPU.
    import torch

    x = torch.randn(4)                        # example tensor on the CPU (added so the snippet runs)
    if torch.cuda.is_available():
        device = torch.device("cuda")         # a CUDA device object
        y = torch.ones_like(x, device=device) # create a tensor directly on the GPU, or
        x = x.to(device)                      # move one there with ``.to("cuda")``
        z = x + y
        print(z)
        print(z.to("cpu", torch.double))      # ``.to`` changes the dtype as well!

To Do

  • Training that uses data owned by the workers rather than the central device