You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airavata.apache.org by di...@apache.org on 2022/12/31 10:53:28 UTC

[airavata-mft] branch master updated: Auto deployment of EC2 transfer agents

This is an automated email from the ASF dual-hosted git repository.

dimuthuupe pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airavata-mft.git


The following commit(s) were added to refs/heads/master by this push:
     new a1f5432  Auto deployment of EC2 transfer agents
a1f5432 is described below

commit a1f5432d2ed72da205aedcf210dc1cd83df06f13
Author: Dimuthu Wannipurage <di...@gmail.com>
AuthorDate: Sat Dec 31 05:53:05 2022 -0500

    Auto deployment of EC2 transfer agents
---
 .../airavata/mft/api/handler/MFTApiHandler.java    |   1 +
 .../mft/controller/AgentTransferDispatcher.java    | 149 +++++++++++++++-----
 .../apache/airavata/mft/controller/AppConfig.java  |   4 +-
 .../mft/controller/spawner/CloudAgentSpawner.java  |   7 +-
 .../{AwsAgentSpawner.java => EC2AgentSpawner.java} | 150 +++++++++++++++++----
 .../mft/controller/spawner/SSHProvider.java        | 134 ++++++++++++++++++
 .../mft/controller/spawner/SpawnerSelector.java    |   2 +-
 python-cli/mft_cli/mft_cli/main.py                 |   3 +-
 8 files changed, 389 insertions(+), 61 deletions(-)

diff --git a/api/service/src/main/java/org/apache/airavata/mft/api/handler/MFTApiHandler.java b/api/service/src/main/java/org/apache/airavata/mft/api/handler/MFTApiHandler.java
index fd932c0..cc00835 100644
--- a/api/service/src/main/java/org/apache/airavata/mft/api/handler/MFTApiHandler.java
+++ b/api/service/src/main/java/org/apache/airavata/mft/api/handler/MFTApiHandler.java
@@ -265,6 +265,7 @@ public class MFTApiHandler extends MFTTransferServiceGrpc.MFTTransferServiceImpl
             } else if (!mainTransferStatus.isEmpty()){
                 stateBuilder.setState(mainTransferStatus.get(0).getState());
                 stateBuilder.setPercentage(0);
+                stateBuilder.setDescription(mainTransferStatus.get(0).getDescription());
 
                 responseObserver.onNext(stateBuilder.build());
                 responseObserver.onCompleted();
diff --git a/controller/src/main/java/org/apache/airavata/mft/controller/AgentTransferDispatcher.java b/controller/src/main/java/org/apache/airavata/mft/controller/AgentTransferDispatcher.java
index b5a5f1d..21cf35d 100644
--- a/controller/src/main/java/org/apache/airavata/mft/controller/AgentTransferDispatcher.java
+++ b/controller/src/main/java/org/apache/airavata/mft/controller/AgentTransferDispatcher.java
@@ -24,60 +24,134 @@ import org.apache.airavata.mft.api.service.EndpointPaths;
 import org.apache.airavata.mft.api.service.TransferApiRequest;
 import org.apache.airavata.mft.controller.spawner.CloudAgentSpawner;
 import org.apache.airavata.mft.controller.spawner.SpawnerSelector;
+import org.apache.commons.lang3.exception.ExceptionUtils;
 import org.apache.commons.lang3.tuple.Pair;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.springframework.beans.factory.annotation.Autowired;
 
 import java.util.*;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.Future;
+import java.util.concurrent.*;
 import java.util.stream.Collectors;
 
 public class AgentTransferDispatcher {
     private static final Logger logger = LoggerFactory.getLogger(AgentTransferDispatcher.class);
+
+    //getId(transferRequest):Pair<TransferApiRequest, AgentTransferRequest.Builder>
+
     private final Map<String, Pair<TransferApiRequest, AgentTransferRequest.Builder>> pendingTransferRequests = new ConcurrentHashMap<>();
+    private final Map<String, String> pendingTransferIds = new ConcurrentHashMap<>();
+    //getId(transferRequest):consulKey
+
     private final Map<String, String> pendingTransferConsulKeys = new ConcurrentHashMap<>();
-    private final Map<String, Future<String>> pendingAgentSpawners = new ConcurrentHashMap<>();
+
+    //getId(transferRequest):CloudAgentSpawner
+    private final Map<String, CloudAgentSpawner> pendingAgentSpawners = new ConcurrentHashMap<>();
+
+    // getId(transferRequest):Set(TransferId)
     private final Map<String, Set<String>> runningAgentCache = new ConcurrentHashMap<>();
 
+    // AgentID:Spawner - Use this to keep track of agent spawners. This is required to terminate agent
+    private final Map<String, CloudAgentSpawner> agentSpawners = new ConcurrentHashMap<>();
+
+    private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
+
+    // Temporarily store consul key until the optimizer spins up Agents. This will block the same pending transfer
+    // being handled twice
+    private final Set<String> optimizingConsulKeys = new ConcurrentSkipListSet<>();
+
     @Autowired
     private MFTConsulClient mftConsulClient;
 
+    public void init() {
+        scheduler.scheduleWithFixedDelay(() -> {
+            pendingAgentSpawners.forEach((key, spawner) -> {
+                if (spawner.getLaunchState().isDone()) {
+                    String transferId = pendingTransferIds.get(key);
+                    Pair<TransferApiRequest, AgentTransferRequest.Builder> transferRequests = pendingTransferRequests.get(key);
+                    String consulKey = pendingTransferConsulKeys.get(key);
+
+                    try {
+                        String agentId = spawner.getLaunchState().get();
+                        List<String> liveAgentIds = mftConsulClient.getLiveAgentIds();
+                        if (liveAgentIds.stream().noneMatch(id -> id.equals(agentId))) {
+                            throw new Exception("Agent was not registered even though the agent is up");
+                        }
+
+                        submitTransferToAgent(Collections.singletonList(agentId), transferId,
+                                transferRequests.getLeft(), transferRequests.getRight(), consulKey);
+
+                        // Use this to terminate agent in future
+                        agentSpawners.put(agentId, spawner);
+
+                    } catch (Exception e) {
+                        logger.error("Failed to launch agent for key {}", key, e);
+                        try {
+                            mftConsulClient.saveTransferState(transferId, new TransferState()
+                                    .setUpdateTimeMils(System.currentTimeMillis())
+                                    .setState("FAILED").setPercentage(0)
+                                    .setPublisher("controller")
+                                    .setDescription("Failed to launch the agent. " + ExceptionUtils.getRootCauseMessage(e)));
+                        } catch (Exception e2) {
+                            logger.error("Failed to submit transfer fail error for transfer id {}", transferId, e2);
+                        }
+
+                        logger.info("Removing consul key {}", consulKey);
+                        mftConsulClient.getKvClient().deleteKey(consulKey);
+                        logger.info("Terminating the spawner");
+                        spawner.terminate();
+
+                    } finally {
+                        pendingTransferIds.remove(key);
+                        pendingTransferRequests.remove(key);
+                        pendingAgentSpawners.remove(key);
+                        pendingTransferConsulKeys.remove(key);
+                        optimizingConsulKeys.remove(consulKey);
+                    }
+                }
+            });
+        }, 3, 5, TimeUnit.SECONDS);
+    }
+
+
     public void submitTransferToAgent(List<String> filteredAgents, String transferId,
                                       TransferApiRequest transferRequest,
                                       AgentTransferRequest.Builder agentTransferRequestTemplate, String consulKey)
             throws Exception {
 
-        if (filteredAgents.isEmpty()) {
+        try {
+            if (filteredAgents.isEmpty()) {
+                mftConsulClient.saveTransferState(transferId, new TransferState()
+                        .setUpdateTimeMils(System.currentTimeMillis())
+                        .setState("FAILED").setPercentage(0)
+                        .setPublisher("controller")
+                        .setDescription("No qualifying agent was found to orchestrate the transfer"));
+                return;
+            }
+
             mftConsulClient.saveTransferState(transferId, new TransferState()
+                    .setState("STARTING")
+                    .setPercentage(0)
                     .setUpdateTimeMils(System.currentTimeMillis())
-                    .setState("FAILED").setPercentage(0)
                     .setPublisher("controller")
-                    .setDescription("No qualifying agent was found to orchestrate the transfer"));
-            return;
-        }
+                    .setDescription("Initializing the transfer"));
 
-        mftConsulClient.saveTransferState(transferId, new TransferState()
-                .setState("STARTING")
-                .setPercentage(0)
-                .setUpdateTimeMils(System.currentTimeMillis())
-                .setPublisher("controller")
-                .setDescription("Initializing the transfer"));
+            AgentTransferRequest.Builder agentTransferRequest = agentTransferRequestTemplate.clone();
 
-        AgentTransferRequest.Builder agentTransferRequest = agentTransferRequestTemplate.clone();
+            agentTransferRequest.setRequestId(UUID.randomUUID().toString());
+            for (EndpointPaths ep : transferRequest.getEndpointPathsList()) {
+                agentTransferRequest.addEndpointPaths(org.apache.airavata.mft.agent.stub.EndpointPaths.newBuilder()
+                        .setSourcePath(ep.getSourcePath())
+                        .setDestinationPath(ep.getDestinationPath()).buildPartial());
+            }
 
-        agentTransferRequest.setRequestId(UUID.randomUUID().toString());
-        for (EndpointPaths ep : transferRequest.getEndpointPathsList()) {
-            agentTransferRequest.addEndpointPaths(org.apache.airavata.mft.agent.stub.EndpointPaths.newBuilder()
-                    .setSourcePath(ep.getSourcePath())
-                    .setDestinationPath(ep.getDestinationPath()).buildPartial());
+            // TODO use a better way to select the right agent
+            mftConsulClient.commandTransferToAgent(filteredAgents.get(0), transferId, agentTransferRequest.build());
+            mftConsulClient.markTransferAsProcessed(transferId, transferRequest);
+            logger.info("Marked transfer {} as processed", transferId);
+        } finally {
+            mftConsulClient.getKvClient().deleteKey(consulKey);
         }
-
-        // TODO use a better way to select the right agent
-        mftConsulClient.commandTransferToAgent(filteredAgents.get(0), transferId, agentTransferRequest.build());
-        mftConsulClient.markTransferAsProcessed(transferId, transferRequest);
-        mftConsulClient.getKvClient().deleteKey(consulKey);
     }
 
     public void handleTransferRequest(String transferId,
@@ -85,6 +159,12 @@ public class AgentTransferDispatcher {
                                       AgentTransferRequest.Builder agentTransferRequestTemplate,
                                       String consulKey) throws Exception{
 
+        if (optimizingConsulKeys.contains(consulKey)) {
+            logger.info("Ignoring handling transfer id {} as it is already in optimizing stage", transferId);
+            return;
+        }
+
+        logger.info("Handling transfer id {} with consul key {}", transferId, consulKey);
         List<String> liveAgentIds = mftConsulClient.getLiveAgentIds();
 
         Map<String, Integer> targetAgentsMap = transferRequest.getTargetAgentsMap();
@@ -115,22 +195,27 @@ public class AgentTransferDispatcher {
                 if (sourceSpawner.isPresent()) {
                     logger.info("Launching {} spawner in source side for transfer {}",
                             sourceSpawner.get().getClass().getName(), transferId);
-                    Future<String> launchFuture = sourceSpawner.get().launch();
-                    pendingAgentSpawners.put(getId(transferRequest, true), launchFuture);
+
+                    sourceSpawner.get().launch();
+                    pendingAgentSpawners.put(getId(transferRequest, true), sourceSpawner.get());
                     pendingTransferRequests.put(getId(transferRequest, true),
                             Pair.of(transferRequest, agentTransferRequestTemplate));
+                    pendingTransferIds.put(getId(transferRequest, true), transferId);
                     pendingTransferConsulKeys.put(getId(transferRequest, true), consulKey);
-
+                    optimizingConsulKeys.add(consulKey);
+                    return;
                 } else if (destSpawner.isPresent()) {
                     logger.info("Launching {} spawner in destination side for transfer {}",
                             destSpawner.get().getClass().getName(), transferId);
 
-                    Future<String> launchFuture = destSpawner.get().launch();
-                    pendingAgentSpawners.put(getId(transferRequest, false), launchFuture);
+                    destSpawner.get().launch();
+                    pendingAgentSpawners.put(getId(transferRequest, false), destSpawner.get());
                     pendingTransferRequests.put(getId(transferRequest, false),
                             Pair.of(transferRequest, agentTransferRequestTemplate));
+                    pendingTransferIds.put(getId(transferRequest, false), transferId);
                     pendingTransferConsulKeys.put(getId(transferRequest, false), consulKey);
-
+                    optimizingConsulKeys.add(consulKey);
+                    return;
                 } else {
                     logger.warn("No optimizing path is available. Moving user provided agents");
                     submitTransferToAgent(userProvidedAgents, transferId,
@@ -165,8 +250,6 @@ public class AgentTransferDispatcher {
                     agentTransferRequestTemplate,
                     consulKey);
         }
-
-        logger.info("Marked transfer {} as processed", transferId);
     }
 
     private String getId(TransferApiRequest transferRequest, boolean isSource) {
diff --git a/controller/src/main/java/org/apache/airavata/mft/controller/AppConfig.java b/controller/src/main/java/org/apache/airavata/mft/controller/AppConfig.java
index c962bbb..ca31060 100644
--- a/controller/src/main/java/org/apache/airavata/mft/controller/AppConfig.java
+++ b/controller/src/main/java/org/apache/airavata/mft/controller/AppConfig.java
@@ -67,6 +67,8 @@ public class AppConfig {
 
     @Bean
     public AgentTransferDispatcher pathOptimizer() {
-        return new AgentTransferDispatcher();
+        AgentTransferDispatcher agentTransferDispatcher = new AgentTransferDispatcher();
+        agentTransferDispatcher.init();
+        return agentTransferDispatcher;
     }
 }
diff --git a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/CloudAgentSpawner.java b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/CloudAgentSpawner.java
index 8d73a73..8b435ab 100644
--- a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/CloudAgentSpawner.java
+++ b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/CloudAgentSpawner.java
@@ -31,6 +31,9 @@ public abstract class CloudAgentSpawner {
         this.storageWrapper = storageWrapper;
     }
 
-    public abstract Future<String> launch();
-    public abstract Future<Boolean> terminate();
+    public abstract void launch();
+    public abstract Future<String> getLaunchState();
+    public abstract void terminate();
+
+    public abstract Future<Boolean> getTerminateState();
 }
diff --git a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/AwsAgentSpawner.java b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/EC2AgentSpawner.java
similarity index 52%
rename from controller/src/main/java/org/apache/airavata/mft/controller/spawner/AwsAgentSpawner.java
rename to controller/src/main/java/org/apache/airavata/mft/controller/spawner/EC2AgentSpawner.java
index c992e01..890e395 100644
--- a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/AwsAgentSpawner.java
+++ b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/EC2AgentSpawner.java
@@ -25,54 +25,50 @@ import com.amazonaws.services.ec2.AmazonEC2ClientBuilder;
 import com.amazonaws.services.ec2.model.*;
 import org.apache.airavata.mft.agent.stub.SecretWrapper;
 import org.apache.airavata.mft.agent.stub.StorageWrapper;
-import org.apache.airavata.mft.controller.AgentTransferDispatcher;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.File;
+import java.io.IOException;
 import java.nio.charset.StandardCharsets;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.util.*;
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.stream.Stream;
 
-public class AwsAgentSpawner extends CloudAgentSpawner {
+public class EC2AgentSpawner extends CloudAgentSpawner {
 
-    private static final Logger logger = LoggerFactory.getLogger(AwsAgentSpawner.class);
+    private static final Logger logger = LoggerFactory.getLogger(EC2AgentSpawner.class);
 
     private final ExecutorService executor = Executors.newSingleThreadExecutor();
 
-    public AwsAgentSpawner(StorageWrapper storageWrapper, SecretWrapper secretWrapper) {
+    private String instanceId;
+    private CountDownLatch portForwardLock;
+    Future<String> launchFuture;
+    Future<Boolean> terminateFuture;
+
+    public EC2AgentSpawner(StorageWrapper storageWrapper, SecretWrapper secretWrapper) {
         super(storageWrapper, secretWrapper);
     }
 
     @Override
-    public Future<String> launch() {
+    public void launch() {
 
-        return executor.submit( () -> {
-            String imageId = "ami-0574da719dca65348";
+        launchFuture = executor.submit( () -> {
+            String imageId = "ami-0ecc74eca1d66d8a6"; // Ubuntu base image
             String keyNamePrefix = "mft-aws-agent-key-";
-            String region = storageWrapper.getS3().getRegion();
             String secGroupName = "MFTAgentSecurityGroup";
             String agentId = UUID.randomUUID().toString();
+            String systemUser = "ubuntu";
 
             String mftKeyDir = System.getProperty("user.home") + File.separator + ".mft" + File.separator + "keys";
             String accessKey = secretWrapper.getS3().getAccessKey();
             String secretKey = secretWrapper.getS3().getSecretKey();
-
-            String cloudInit =
-                    "#cloud-config\n" +
-                            "\n" +
-                            "runcmd:\n" +
-                            " - apt install -y openjdk-11-jre-headless\n" +
-                            " - apt install -y unzip\n" +
-                            " - wget https://github.com/apache/airavata-mft/releases/download/v0.0.1/MFT-Agent-0.01-bin.zip\n" +
-                            " - unzip MFT-Agent-0.01-bin.zip -d /home/ubuntu/\n" +
-                            " - sed -ir \"s/^[#]*\\s*agent.id=.*/agent.id=" + agentId + "/\" /home/ubuntu/MFT-Agent-0.01/conf/application.properties\n" +
-                            " - chown -R ubuntu:ubuntu /home/ubuntu/MFT-Agent-0.01/\n";
+            String region = storageWrapper.getS3().getRegion();
 
             try {
                 BasicAWSCredentials awsCreds = new BasicAWSCredentials(accessKey, secretKey);
@@ -159,22 +155,130 @@ public class AwsAgentSpawner extends CloudAgentSpawner {
                                 new TagSpecification().withResourceType(ResourceType.Instance)
                                         .withTags(new Tag().withKey("Type").withValue("MFT-Agent"),
                                                 new Tag().withKey("AgentId").withValue(agentId)))
-                        .withUserData(Base64.getEncoder().encodeToString(cloudInit.getBytes()))
                         .withSecurityGroups(secGroupName);
 
 
                 logger.info("Launching the EC2 VM to start Agent {}", agentId);
                 RunInstancesResult result = amazonEC2.runInstances(runInstancesRequest);
 
+                instanceId = result.getReservation().getInstances().get(0).getInstanceId();
+
+                try {
+                    DescribeInstancesRequest describeInstancesRequest = new DescribeInstancesRequest();
+                    describeInstancesRequest.setInstanceIds(Collections.singletonList(instanceId));
+
+                    InstanceState instanceState = null;
+                    String publicIpAddress = null;
+
+                    logger.info("Waiting until instance {} is ready", instanceId);
+
+                    for (int i = 0; i < 30; i++) {
+                        DescribeInstancesResult describeInstancesResult = amazonEC2.describeInstances(describeInstancesRequest);
+                        Instance instance = describeInstancesResult.getReservations().get(0).getInstances().get(0);
+                        instanceState = instance.getState();
+                        publicIpAddress = instance.getPublicIpAddress();
+
+                        logger.info("Instance state {}, public ip {}", instanceState.getName(), publicIpAddress);
+
+                        if (instanceState.getName().equals("running") && publicIpAddress != null) {
+                            break;
+                        }
+                        Thread.sleep(2000);
+                    }
+
+                    logger.info("Waiting 30 seconds until the ssh interface comes up in instance {}", instanceId);
+                    Thread.sleep(30000);
+                    if ("running".equals(instanceState.getName()) && publicIpAddress != null) {
+                        SSHProvider portForwardAgent = new SSHProvider();
+                        portForwardAgent.initConnection(publicIpAddress, 22,
+                                Path.of(mftKeyDir, keyName).toAbsolutePath().toString(), systemUser);
+                        logger.info("Created SSH Connection. Installing dependencies...");
+
+                        int exeCode = portForwardAgent.runCommand("sudo apt install -y openjdk-11-jre-headless");
+                        if (exeCode != 0)
+                            throw new IOException("Failed to install jdk on new VM");
+                        exeCode = portForwardAgent.runCommand("sudo apt install -y unzip");
+                        if (exeCode != 0)
+                            throw new IOException("Failed to install unzip on new VM");
+                        exeCode = portForwardAgent.runCommand("wget https://github.com/apache/airavata-mft/releases/download/v0.0.1/MFT-Agent-0.01-bin.zip");
+                        if (exeCode != 0)
+                            throw new IOException("Failed to download mft distribution");
+                        exeCode = portForwardAgent.runCommand("unzip MFT-Agent-0.01-bin.zip");
+                        if (exeCode != 0)
+                            throw new IOException("Failed to unzip mft distribution");
+
+                        exeCode = portForwardAgent.runCommand("sed -ir \"s/^[#]*\\s*agent.id=.*/agent.id=" + agentId + "/\" /home/ubuntu/MFT-Agent-0.01/conf/application.properties");
+                        if (exeCode != 0)
+                            throw new IOException("Failed to update agent id in config file");
+
+                        portForwardLock = new CountDownLatch(1);
+                        CountDownLatch portForwardPendingLock = portForwardAgent.createSshPortForward(8500, portForwardLock);
+
+                        logger.info("Waiting until the port forward is setup");
+                        portForwardPendingLock.await();
+
+                        exeCode = portForwardAgent.runCommand("sh MFT-Agent-0.01/bin/agent-daemon.sh start");
+                        if (exeCode != 0)
+                            throw new IOException("Failed to start the MFT Agent");
+
+                        // Waiting 10 seconds to start the Agent
+                        Thread.sleep(10000);
+
+                    } else {
+                        logger.info("Instance {} was not setup properly", instanceId);
+                        throw new Exception("Instance " + instanceId + " was not setup properly");
+                    }
+                } catch (Exception e) {
+                    logger.error("Failed preparing instance {}. Deleting the instance", instanceId);
+                    TerminateInstancesRequest terminateInstancesRequest = new TerminateInstancesRequest();
+                    terminateInstancesRequest.setInstanceIds(Collections.singleton(instanceId));
+                    amazonEC2.terminateInstances(terminateInstancesRequest);
+                    throw e;
+                }
+
                 return agentId;
             } catch (Exception e) {
-                throw new RuntimeException("Failed to spin up the AWS Agent", e);
+                logger.error("Failed to spin up the EC2 Agent", e);
+                throw new RuntimeException("Failed to spin up the EC2 Agent", e);
             }
         });
     }
 
     @Override
-    public Future<Boolean> terminate() {
-        return null;
+    public void terminate() {
+
+        terminateFuture =  executor.submit(() -> {
+            if (instanceId != null) {
+                String accessKey = secretWrapper.getS3().getAccessKey();
+                String secretKey = secretWrapper.getS3().getSecretKey();
+                String region = storageWrapper.getS3().getRegion();
+
+                BasicAWSCredentials awsCreds = new BasicAWSCredentials(accessKey, secretKey);
+
+                AmazonEC2 amazonEC2 = AmazonEC2ClientBuilder.standard().withEndpointConfiguration(new AwsClientBuilder.EndpointConfiguration(
+                                "https://ec2." + region + ".amazonaws.com", region))
+                        .withCredentials(new AWSStaticCredentialsProvider(awsCreds))
+                        .build();
+
+                if (portForwardLock != null) {
+                    portForwardLock.countDown();
+                }
+
+                TerminateInstancesRequest terminateInstancesRequest = new TerminateInstancesRequest();
+                terminateInstancesRequest.setInstanceIds(Collections.singleton(instanceId));
+                amazonEC2.terminateInstances(terminateInstancesRequest);
+            }
+            return true;
+        });
+    }
+
+    @Override
+    public Future<String> getLaunchState() {
+        return launchFuture;
+    }
+
+    @Override
+    public Future<Boolean> getTerminateState() {
+        return terminateFuture;
     }
 }
diff --git a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/SSHProvider.java b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/SSHProvider.java
new file mode 100644
index 0000000..bbd0166
--- /dev/null
+++ b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/SSHProvider.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.airavata.mft.controller.spawner;
+
+import net.schmizz.keepalive.KeepAliveProvider;
+import net.schmizz.sshj.DefaultConfig;
+import net.schmizz.sshj.SSHClient;
+import net.schmizz.sshj.connection.channel.direct.Session;
+import net.schmizz.sshj.connection.channel.forwarded.RemotePortForwarder;
+import net.schmizz.sshj.connection.channel.forwarded.SocketForwardingConnectListener;
+import net.schmizz.sshj.transport.verification.HostKeyVerifier;
+import net.schmizz.sshj.userauth.keyprovider.KeyProvider;
+import net.schmizz.sshj.userauth.method.AuthKeyboardInteractive;
+import net.schmizz.sshj.userauth.method.AuthMethod;
+import net.schmizz.sshj.userauth.method.AuthPublickey;
+import net.schmizz.sshj.userauth.method.ChallengeResponseProvider;
+import net.schmizz.sshj.userauth.password.Resource;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.security.PublicKey;
+import java.util.ArrayList;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.concurrent.CountDownLatch;
+
+public class SSHProvider {
+
+    private static final Logger logger = LoggerFactory.getLogger(SSHProvider.class);
+
+    private SSHClient client;
+
+    public void initConnection(String hostName, int port, String keyPath, String user) throws IOException {
+        DefaultConfig defaultConfig = new DefaultConfig();
+        defaultConfig.setKeepAliveProvider(KeepAliveProvider.KEEP_ALIVE);
+
+        client = new SSHClient(defaultConfig);
+        client.addHostKeyVerifier(new HostKeyVerifier() {
+            @Override
+            public boolean verify(String s, int i, PublicKey publicKey) {
+                return true;
+            }
+
+            @Override
+            public List<String> findExistingAlgorithms(String s, int i) {
+                return null;
+            }
+        });
+
+        KeyProvider keyProvider = client.loadKeys(keyPath);
+
+        final List<AuthMethod> am = new LinkedList<>();
+        am.add(new AuthPublickey(keyProvider));
+
+        am.add(new AuthKeyboardInteractive(new ChallengeResponseProvider() {
+            @Override
+            public List<String> getSubmethods() {
+                return new ArrayList<>();
+            }
+
+            @Override
+            public void init(Resource resource, String name, String instruction) {
+            }
+
+            @Override
+            public char[] getResponse(String prompt, boolean echo) {
+                return new char[0];
+            }
+
+            @Override
+            public boolean shouldRetry() {
+                return false;
+            }
+        }));
+
+        client.connect(hostName, port);
+        client.auth(user, am);
+    }
+
+    public int runCommand(String command) throws IOException {
+        Session session = this.client.startSession();
+        logger.info("Running command {}", command);
+        Session.Command execResult = session.exec(command);
+        String stdOut = new String(execResult.getInputStream().readAllBytes());
+        String stdErr = new String(execResult.getErrorStream().readAllBytes());
+        logger.info("Std out: {}", stdOut);
+        logger.info("Std err: {}", stdErr);
+        logger.info("Exit code: {}", execResult.getExitStatus());
+        session.close();
+        return execResult.getExitStatus();
+    }
+
+    public CountDownLatch createSshPortForward(int localPort, CountDownLatch portForwardHoldLock) throws IOException, InterruptedException {
+
+        CountDownLatch portForwardCompleteLock = new CountDownLatch(1);
+        new Thread(() -> {
+            String consulHost = "localhost";
+
+            try {
+                client.getRemotePortForwarder().bind(
+                        new RemotePortForwarder.Forward(localPort),
+                        new SocketForwardingConnectListener(new InetSocketAddress(consulHost, localPort)));
+
+                portForwardCompleteLock.countDown();
+                logger.info("Created port forward to port " + localPort);
+                portForwardHoldLock.await();
+
+                logger.info("Releasing the remote port forward");
+                client.getRemotePortForwarder().cancel(new RemotePortForwarder.Forward(localPort));
+
+            } catch (Exception e) {
+                logger.error("Failed to create the remote port forward for port {}", localPort, e);
+            }
+        }).start();
+        return portForwardCompleteLock;
+    }
+}
diff --git a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/SpawnerSelector.java b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/SpawnerSelector.java
index 72cfad3..1c86446 100644
--- a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/SpawnerSelector.java
+++ b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/SpawnerSelector.java
@@ -28,7 +28,7 @@ public class SpawnerSelector {
         switch (storageWrapper.getStorageCase()) {
             case S3:
                 if (storageWrapper.getS3().getEndpoint().endsWith("amazonaws.com")) {
-                    return Optional.of(new AwsAgentSpawner(storageWrapper, secretWrapper));
+                    return Optional.of(new EC2AgentSpawner(storageWrapper, secretWrapper));
                 }
                 break;
         }
diff --git a/python-cli/mft_cli/mft_cli/main.py b/python-cli/mft_cli/mft_cli/main.py
index baf59d3..1accfd6 100644
--- a/python-cli/mft_cli/mft_cli/main.py
+++ b/python-cli/mft_cli/mft_cli/main.py
@@ -101,7 +101,8 @@ def copy(source, destination):
     transfer_request = MFTTransferApi_pb2.TransferApiRequest(sourceStorageId = source_storage_id,
                                                              sourceSecretId = source_secret_id,
                                                              destinationStorageId = dest_storage_id,
-                                                             destinationSecretId = dest_secret_id)
+                                                             destinationSecretId = dest_secret_id,
+                                                             optimizeTransferPath = True)
 
     if (source_metadata.WhichOneof('metadata') == 'directory') :
         if (destination[-1] != "/"):