#40 Guess repo name for pull request based on tracking branch
Merged 5 years ago by lsedlar. Opened 5 years ago by lsedlar.

file modified
+37 -7
@@ -10,6 +10,9 @@ 

      in_git_repo,

      get_default_upstream_branch,

      get_current_local_branch,

+     get_tracking_branch,

+     get_remote_url,

+     repo_url,

      run,

      die,

  )
@@ -33,6 +36,27 @@ 

      return repo, branch

  

  

+ def guess_repo_name(default_repo, current_branch, username):

+     """Guess which remote repo and branch the current branch is pushed to.

+ 

+     The current branch must track a remote branch for the guess to work.

+     The URL of the tracked remote is compared with the URL that repo_url()

+     builds for the main repo and for '<username>/<default_repo>' (the

+     user's fork).

+ 

+     Returns a (repo name, remote branch name) tuple. When there is no

+     tracking branch, or the tracked remote matches neither candidate, the

+     result is (default_repo, current_branch).

+     """

+     fallback = (default_repo, current_branch)

+ 

+     tracking = get_tracking_branch()

+     if not tracking:

+         return fallback

+ 

+     remote, remote_branch = tracking

+     tracked_url = get_remote_url(remote)

+ 

+     # The main repo is tried first, then the user's fork of it.

+     candidates = (default_repo, '%s/%s' % (username, default_repo))

+     for candidate in candidates:

+         if tracked_url == repo_url(candidate, ssh=True, git=True):

+             return candidate, remote_branch

+ 

+     return fallback

+ 

+ 

  @app.command('pull-request')

  @assert_local_repo

  @click.option('-b', '--base', help='Branch to merge the changes in')
@@ -44,11 +68,15 @@ 

      current branch to default upstream branch (usually 'master' or 'develop').

  

      The '--head' option can be used to specify other branch than the current

-     one. You can open a pull request from a fork using

-     'YOUR_USERNAME:BRANCH_NAME' as argument to '--head'.

+     one. Note that the name in the remote repo is needed here. You can open a

+     pull request from a fork using 'YOUR_USERNAME:BRANCH_NAME' as argument to

+     '--head'. Alternatively you can push the branch with `-u` to make it track

+     the remote branch. In such case pag will automatically know to open the PR

+     from your fork.

      """

  

      name = in_git_repo()

+     username = conf['username']

  

      if base is None:

          try:
@@ -61,13 +89,16 @@ 

          name, base = split_input(base, name)

  

      if head is None:

-         head = get_current_local_branch()

+         local_head = get_current_local_branch()

+         name, head = guess_repo_name(name, head, username)

      else:

          name, head = split_input(head, name)

-         if '/' in name:

-             name = 'fork/' + name

+         local_head = head

  

-     cmd = ['git', 'log', '{base}..{head}'.format(base=base, head=head)]

+     if '/' in name:

+         name = 'fork/' + name

+ 

+     cmd = ['git', 'log', '{base}..{head}'.format(base=base, head=local_head)]

      _, log = run(cmd, echo=False)

  

      def modify(line):
@@ -92,7 +123,6 @@ 

      comment = comment.split(MARKER)[0]

      comment = comment.strip()

  

-     username = conf['username']

      if not client.is_logged_in:

          password = getpass.getpass("FAS password for %r" % username)

          client.login(username=username, password=password)

file modified
+20
@@ -83,6 +83,26 @@ 

              return real_ref[len(remote) + 1:]

  

  

+ def get_tracking_branch():

+     """Get the remote and branch that the current branch is tracking.

+ 

+     Parses the first line printed by `git status -sb --porcelain`, which

+     has the form '## <local>...<remote>/<branch>', optionally followed by

+     ahead/behind information.

+ 

+     Returns a (remote, branch) tuple, or None when the line carries no

+     tracking information or the upstream has no '<remote>/' part.

+     """

+     _, output = run(['git', 'status', '-sb', '--porcelain'], silent=True)

+     try:

+         # Get first line, ... separates local and remote names

+         upstream = (output.splitlines()[0]

+                     .split(' ')[1]

+                     .split('...')[1])

+     except IndexError:

+         return None

+ 

+     remote, slash, branch = upstream.partition('/')

+     if not slash:

+         # No '<remote>/' prefix (presumably an upstream that is a local

+         # branch). Callers unpack the result into exactly two names, so

+         # report "not tracking" rather than hand back a 1-tuple.

+         return None

+     return remote, branch

+ 

+ 

+ def get_remote_url(remote):

+     """Return the URL that `git remote get-url` prints for `remote`.

+ 

+     Returns None when the git command exits with a non-zero status.

+     """

+     status, stdout = run(['git', 'remote', 'get-url', remote], silent=True)

+     if status == 0:

+         return stdout.strip()

+     return None

+ 

+ 

  def get_current_local_branch():

      _, stdout = run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])

      branch = stdout.strip()

file modified
+44 -14
@@ -36,11 +36,21 @@ 

          # and chdir into it.

          os.chdir(self.repo)

  

+         # Prepare a directory for creating a clone

+         self.clone = tempfile.mkdtemp(prefix='cloned_repo_')

+ 

      def tearDown(self):

          shutil.rmtree(self.repo)

+         shutil.rmtree(self.clone)

          # Change back to original working directory

          os.chdir(self.orig_path)

  

+     def _clone_to(self, remote, branch):

+         """Clone self.repo into self.clone and chdir into the clone.

+ 

+         `remote` is the name given to the cloned-from remote (`-o`). Unless

+         `branch` is 'master', it is created and checked out before cloning.

+         """

+         wants_new_branch = branch != 'master'

+         if wants_new_branch:

+             self.cmd(['git', 'checkout', '-b', branch])

+ 

+         clone_cmd = ['git', 'clone', self.repo, self.clone, '-o', remote]

+         self.cmd(clone_cmd)

+         os.chdir(self.clone)

+ 

  

  class TestGetDefaultBranch(GitTestCase):

      """
@@ -50,20 +60,6 @@ 

      default branch.

      """

  

-     def setUp(self):

-         super().setUp()

-         self.clone = tempfile.mkdtemp(prefix='cloned_repo_')

- 

-     def tearDown(self):

-         super().tearDown()

-         shutil.rmtree(self.clone)

- 

-     def _clone_to(self, remote, branch):

-         if branch != 'master':

-             self.cmd(['git', 'checkout', '-b', branch])

-         self.cmd(['git', 'clone', self.repo, self.clone, '-o', remote])

-         os.chdir(self.clone)

- 

      def test_origin_master(self):

          self._clone_to('origin', 'master')

  
@@ -111,3 +107,37 @@ 

  

          with self.assertRaises(RuntimeError):

              utils.get_current_local_branch()

+ 

+ 

+ class TestGetTrackingBranch(GitTestCase):

+     """Tests for utils.get_tracking_branch, run from inside a clone."""

+ 

+     def setUp(self):

+         super(TestGetTrackingBranch, self).setUp()

+         self._clone_to('origin', 'master')

+ 

+     def test_not_tracking(self):

+         # A freshly created local branch has no upstream configured.

+         self.cmd(['git', 'checkout', '-b', 'foo'], cwd=self.clone)

+         self.assertEqual(utils.get_tracking_branch(), None)

+ 

+     def test_with_tracking_branch(self):

+         # Push local 'foo' as remote 'bar' so the two names differ.

+         self.cmd(['git', 'checkout', '-b', 'foo'], cwd=self.clone)

+         self.cmd(['git', 'push', '-u', 'origin', 'foo:bar'], cwd=self.clone)

+         self.assertEqual(utils.get_tracking_branch(), ('origin', 'bar'))

+ 

+     def test_with_tracking_branch_ahead(self):

+         # An extra local commit puts the branch ahead of its upstream.

+         self.cmd(['git', 'checkout', '-b', 'foo'], cwd=self.clone)

+         self.cmd(['git', 'push', '-u', 'origin', 'foo:bar'], cwd=self.clone)

+         self.cmd(['git', 'commit', '--allow-empty', '-m', 'Dummy commit'],

+                  cwd=self.clone)

+         self.assertEqual(utils.get_tracking_branch(), ('origin', 'bar'))

+ 

+ 

+ class TestGetRemoteUrl(GitTestCase):

+     """Tests for utils.get_remote_url."""

+ 

+     def test_no_remote(self):

+         # No clone is made here, so 'origin' is looked up in the base repo.

+         url = utils.get_remote_url('origin')

+         self.assertEqual(url, None)

+ 

+     def test_with_remote(self):

+         # The clone names its source remote 'upstream'; its URL is the

+         # path of the repo it was cloned from.

+         self._clone_to('upstream', 'master')

+         os.chdir(self.clone)

+         url = utils.get_remote_url('upstream')

+         self.assertEqual(url, self.repo)

When the current branch is tracking a remote branch, use that to guess the name of the repo from which to open the pull request.

Without the remote tracking info there is no change.

rebased onto f234e22

5 years ago

Pull-Request has been merged by lsedlar

5 years ago